From d0568fb679ffa0bcd30625d5b9393be5000a997e Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 15:54:20 +0200 Subject: [PATCH 01/23] feat(router): unify chat adapters on LiteLLM (#379) Route every chat-completion AgentType through `litellm.completion`. The LiteLLMAgent base now owns provider-prefix routing, the unified `thinking` knob, tool-calls, and reasoning-content extraction. OpenAI and Ollama agents collapse to thin subclasses that pin the provider prefix (`openai`, `ollama_chat`) and translate `thinking` to `reasoning_effort` or Ollama's `think` field. ADK still requires HTTP transport LiteLLM doesn't speak natively, so the adapter registers a per-instance `litellm.CustomLLM` handler that implements the `POST /run` + sessions + events protocol and packages the result into a `ModelResponse`. From the router's perspective ADK is now just another LiteLLM provider. Router config branches collapse: all chat AgentTypes share the same metadata-merge path; ADK only adds `user_id`. Unit tests for OpenAI, Ollama, and ADK rewritten against `litellm.completion` patching, and the OpenAI integration test no longer asserts on a non-existent SDK client. LITELLM_ROUTER_REFACTOR_PLAN.md captures the follow-up plan to move the call site into `router.py` and shrink the adapters folder further. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- LITELLM_ROUTER_REFACTOR_PLAN.md | 304 +++++++ hackagent/router/adapters/google_adk.py | 957 +++++++++------------- hackagent/router/adapters/litellm.py | 430 ++++++---- hackagent/router/adapters/ollama.py | 539 ++---------- hackagent/router/adapters/openai.py | 548 ++----------- hackagent/router/router.py | 178 ++-- tests/integration/adapters/test_openai.py | 5 +- tests/unit/adapters/test_google_adk.py | 413 ++++++---- tests/unit/adapters/test_ollama.py | 624 ++++---------- tests/unit/adapters/test_openai.py | 860 ++++--------------- 10 files changed, 1742 insertions(+), 3116 deletions(-) create mode 100644 LITELLM_ROUTER_REFACTOR_PLAN.md diff --git a/LITELLM_ROUTER_REFACTOR_PLAN.md b/LITELLM_ROUTER_REFACTOR_PLAN.md new file mode 100644 index 00000000..c8e9822c --- /dev/null +++ b/LITELLM_ROUTER_REFACTOR_PLAN.md @@ -0,0 +1,304 @@ +# Plan — Collapse adapters into `router.py` on top of LiteLLM + +**Tracks:** Issue [#379](https://github.com/AISecurityLab/hackagent/issues/379) and follow-up. +**Status:** Draft. +**Author:** generated 2026-05-23. + +This doc proposes the second step of #379: now that every chat-completion +adapter already routes through LiteLLM (PR on branch +`feat/litellm-unified-adapters-379`), make the most of LiteLLM by moving +the remaining adapter responsibilities into `hackagent/router/router.py` +and a thin set of helpers, and reduce `hackagent/router/adapters/` to the +irreducible minimum (gap-fillers for non-chat protocols like Google ADK). + +--- + +## 1. What LiteLLM gives us + +Verified against the upstream docs while drafting this plan: + +| Capability | LiteLLM API | Notes | +|---|---|---| +| Unified call across 140+ providers | `litellm.completion(model="<provider>/<model>", ...)` | OpenAI Chat Completion schema is the lingua franca; provider-specific fields are translated automatically where possible. 
| +| Standardized response | `ModelResponse` — `choices[0].message.{content, role, tool_calls, reasoning_content}`, `usage`, `finish_reason`, `model`, `id`, `created` | Provider-specific extras surface in `provider_specific_fields`. | +| Call identifier for correlation | `x-litellm-call-id` response header; `litellm_call_id` in callback `kwargs` | Lets us join input / output / cost without our own UUID. | +| Per-call metadata that flows to callbacks | `litellm.completion(..., metadata={...})` → `kwargs["litellm_params"]["metadata"]` | This is how we will attach `agent_id` / `registration_key` to every call. | +| Lifecycle hooks | `CustomLogger` (`log_pre_api_call`, `log_post_api_call`, `log_success_event`, `log_failure_event`, async variants) | Registered globally via `litellm.callbacks = [handler]`. | +| Cost tracking | `kwargs["response_cost"]`, `litellm.completion_cost(...)` | Free metric for HackAgent traces. | +| Custom provider extension | Subclass `litellm.CustomLLM`, register in `litellm.custom_provider_map` | Already used by `ADKAgent` after #379. | +| Multi-deployment routing | `litellm.Router(model_list=[...])` | Optional; out of scope for #379 but worth keeping in mind for HA. | +| Streaming | `stream=True` returns a `CustomStreamWrapper` | Not used today; design must not foreclose it. | +| Standardized knobs | `tools`, `tool_choice`, `response_format`, `stream`, `seed`, `stop`, `logprobs`, `presence_penalty`, `frequency_penalty`, `user`, `reasoning_effort`, `thinking` | The reasoning knobs are translated per provider. | + +Sources: +- +- +- +- +- +- + +--- + +## 2. 
Where we are after this PR + +``` +hackagent/router/ +├── router.py # AgentRouter: registers backend agent, dispatches +├── types.py # AgentTypeEnum +└── adapters/ + ├── base.py # Agent, ChatCompletionsAgent, envelope helpers + ├── litellm.py # LiteLLMAgent — unified LiteLLM wrapper (handles + │ # thinking, tools, prefix, response shaping) + ├── openai.py # OpenAIAgent(LiteLLMAgent) — `openai/` prefix + + │ # `reasoning_effort` translation + ├── ollama.py # OllamaAgent(LiteLLMAgent) — `ollama_chat/` prefix + │ # + `think` translation + Ollama extras + └── google_adk.py # ADKAgent(LiteLLMAgent) registers a CustomLLM + # that speaks ADK's /run + sessions protocol +``` + +Issue: `OpenAIAgent` and `OllamaAgent` are now nearly trivial — each one +exists to (a) set a provider prefix and (b) translate `thinking`. The +`base.ChatCompletionsAgent` template is overkill once everything goes +through `litellm.completion(...)`. The envelope-building code in +`base.py` is the only meaningful logic that's still adapter-shaped. + +## 3. Target architecture + +Move from "one adapter class per AgentType" to "one entry point in +`router.py` that calls LiteLLM, with a small per-AgentType config +table and a CustomLogger that fills the HackAgent envelope". + +``` +hackagent/router/ +├── router.py # AgentRouter.route_request() → litellm.completion() +│ # + envelope building from ModelResponse + +│ # registers a CustomLogger for I/O capture. +├── types.py # AgentTypeEnum (unchanged) +├── provider_config.py # AgentType → ProviderConfig table: +│ # - litellm_prefix (e.g. "openai", "ollama_chat", None) +│ # - thinking_translator (callable) +│ # - extra_param_keys (allow-list) +│ # - custom_llm_factory (for ADK / MCP / A2A) +├── envelope.py # Pure functions: build_success / build_error +│ # from a ModelResponse → HackAgent dict. +├── tracking_logger.py # CustomLogger that emits StepTracker events +│ # for log_pre_api_call / log_success_event / +│ # log_failure_event. 
+└── providers/ + └── adk_custom_llm.py # The ADK CustomLLM (and future MCP/A2A). +``` + +Notes: + +- **No more chat-adapter classes.** Everything for AgentType in + `{LITELLM, OPENAI_SDK, OLLAMA, LANGCHAIN}` is handled by the same + code path in `router.py`, parameterised by `ProviderConfig`. +- **ADK / MCP / A2A** stay as `CustomLLM` providers — that's the + irreducible exception for protocols LiteLLM doesn't speak natively. +- **The envelope shape doesn't change.** Downstream consumers + (`StepTracker`, `advprefix`, `pair`, evaluators, dashboard) keep + the dict shape they get today; we just build it in one place + instead of in N adapter subclasses. + +### Request flow, end to end + +``` +AgentRouter.route_request(registration_key, request_data) + └─ resolve ProviderConfig for the stored AgentType + └─ build litellm kwargs: + model = provider.litellm_model(name) + messages = request_data["messages"] or [{role:"user", content:prompt}] + max_tokens / temperature / top_p / tools / tool_choice + thinking = provider.translate_thinking(request_data.get("thinking")) + metadata = { + "hackagent_agent_id": registration_key, + "hackagent_adapter_type": provider.adapter_label, + "hackagent_org_id": ..., + } + api_base, api_key, extra_body, ... + └─ try: response = litellm.completion(**kwargs) + (LiteLLM dispatches via Router or CustomLLM as appropriate) + (CustomLogger.log_success_event fires → StepTracker is updated) + └─ envelope.build_success(response, request_data, metadata) + └─ return envelope dict (same shape as today) +``` + +For errors the path is symmetric: `envelope.build_error` is fed by +either the exception itself or by `log_failure_event` data depending on +whether the error originates pre-call or post-call. + +--- + +## 4. What stays, what goes, what's new + +### Stays +- `AgentRouter` (its lifecycle responsibilities — backend agent + creation, registration-key mapping — are independent of LLM transport + and unchanged). +- `AgentTypeEnum`. 
+- The CustomLLM extension point — that's still how ADK is wired in. +- The current external envelope shape (downstream code expects it). +- `LiteLLMConfigurationError` / `AdapterConfigurationError` style + exceptions, possibly relocated under `router/`. + +### Goes +- `adapters/base.py` (`Agent`, `ChatCompletionsAgent`, the abstract + `handle_request` dance). Its logic becomes free functions in + `router/envelope.py` and `router/router.py`. +- `adapters/litellm.py` as a class — its `_prepare_litellm_params`, + `_extract_raw_response_content`, `_extract_tool_calls` become + functions in `router/envelope.py`. +- `adapters/openai.py` and `adapters/ollama.py` — collapsed into + entries in `router/provider_config.py`. +- The hard-coded thinking translation in each subclass — replaced by + one `translate_thinking` callable per AgentType. + +### New +- `router/provider_config.py` — single source of truth for + AgentType → behaviour mapping. +- `router/envelope.py` — pure helpers from `ModelResponse` to the + HackAgent dict. +- `router/tracking_logger.py` — `CustomLogger` subclass that: + - On `log_pre_api_call`: emits a `StepTracker` "request_started" + event with `metadata["hackagent_agent_id"]` and the prompt + preview. + - On `log_success_event`: emits a `StepTracker` "request_finished" + event with response text, usage, cost, `x-litellm-call-id`. + - On `log_failure_event`: emits a "request_failed" event with the + exception type + message. +- `router/providers/adk_custom_llm.py` — the ADK custom provider, + moved out of `adapters/`. 
+ +### Migration map + +| Today | Tomorrow | +|---|---| +| `adapters/litellm.py::LiteLLMAgent._prepare_litellm_params` | `router/envelope.py::build_litellm_kwargs` | +| `adapters/litellm.py::LiteLLMAgent._extract_raw_response_content` | `router/envelope.py::extract_text` | +| `adapters/litellm.py::LiteLLMAgent._extract_tool_calls` | `router/envelope.py::extract_tool_calls` | +| `adapters/base.py::Agent._build_success_response` | `router/envelope.py::build_success` | +| `adapters/base.py::Agent._build_error_response` | `router/envelope.py::build_error` | +| `adapters/base.py::ChatCompletionsAgent.handle_request` | inlined in `AgentRouter.route_request` | +| `adapters/openai.py::OpenAIAgent` | `ProviderConfig(prefix="openai", translate_thinking=openai_thinking)` | +| `adapters/ollama.py::OllamaAgent` | `ProviderConfig(prefix="ollama_chat", translate_thinking=ollama_thinking, extra_keys={"top_k","num_ctx","stream"})` | +| `adapters/google_adk.py::ADKAgent + _ADKCustomLLM` | `router/providers/adk_custom_llm.py::ADKCustomLLM` registered when AgentType=GOOGLE_ADK | +| `adapters/base.py::Agent._strip_think_prefix` | `router/envelope.py::strip_think_prefix` (called inside `build_success`) | + +--- + +## 5. Phased execution + +Each phase is independently shippable; each ends with `pytest tests/unit` +green. + +### Phase A — Extract pure helpers (no behaviour change) +1. Create `router/envelope.py`. Move `_strip_think_prefix`, + `_extract_raw_response_content`, `_extract_tool_calls`, + `_build_success_response`, `_build_error_response`, + `_prepare_litellm_params` out of the adapters as **free functions** + that take a `ProviderConfig`-like argument. +2. Have the current adapter classes delegate to those functions so + their public behaviour is identical. +3. Tests untouched; they still pass. + +### Phase B — Introduce `ProviderConfig` table +1. Create `router/provider_config.py` with one entry per AgentType. +2. 
Have `LiteLLMAgent`, `OpenAIAgent`, `OllamaAgent` initialise + themselves from their corresponding `ProviderConfig` (instead of + class-level `PROVIDER_PREFIX` and method overrides). +3. The class structure still exists but the only difference between + the three is the config they look up. Tests still pass. + +### Phase C — Hoist call path into `AgentRouter` +1. Add `AgentRouter._dispatch_via_litellm(registration_key, + request_data)` that builds litellm kwargs from the + `ProviderConfig` and calls `litellm.completion(...)`. +2. Make `AgentRouter.route_request` use this path for every AgentType + in `{LITELLM, OPENAI_SDK, OLLAMA, LANGCHAIN}`, bypassing the + adapter classes. +3. Keep `GOOGLE_ADK` on the adapter path (or already-registered + CustomLLM — even simpler). +4. Mark the chat adapter classes as deprecated. + +### Phase D — Wire `CustomLogger` for I/O capture +1. Implement `router/tracking_logger.py::HackAgentTrackingLogger`. +2. On `AgentRouter.__init__`, register a single instance on + `litellm.callbacks` (idempotent — guard against double-registration). +3. Pass `metadata={"hackagent_agent_id": ..., ...}` on every + `litellm.completion(...)` call. +4. Move the `🌐 Querying model …` / `✅ Model responded …` logging + from the adapters into the logger. + +### Phase E — Delete the chat adapter classes +1. Remove `adapters/litellm.py`, `adapters/openai.py`, + `adapters/ollama.py`, the `ChatCompletionsAgent` parts of + `adapters/base.py`. +2. Move `adapters/google_adk.py` to `router/providers/adk_custom_llm.py`. +3. Rename `adapters/__init__.py` exports to point at the new + locations (keep import aliases for one release for backwards + compatibility). +4. Update tests to match the new layout — most existing tests can be + reused with import-path edits since they already patch + `litellm.completion` after this PR. 
+ +### Phase F — Optional follow-ups +- Adopt `litellm.Router` for built-in load balancing / fallback / + rate-limit awareness when an org configures multiple endpoints + per agent. +- Standardise on `metadata` for richer downstream filtering + (`org_id`, `attack_id`, `evaluator_id`). +- Surface `response_cost` and `x-litellm-call-id` in the envelope + so attack reports can include cost-per-attempt. +- Streaming support (`stream=True`) — needs a separate envelope + path that yields incrementally. + +--- + +## 6. Risks and how to mitigate them + +| Risk | Likelihood | Mitigation | +|---|---|---| +| Downstream code depends on the envelope dict shape. | High | Keep the shape byte-identical in Phase A; only the building code moves. Add a dedicated test that snapshots the dict for a known input. | +| Global `litellm.callbacks` may interfere with user-supplied callbacks. | Medium | Register our logger only when an `AgentRouter` is constructed; tag with `metadata["hackagent_owned"] = True` so we ignore other apps' calls. | +| CustomLogger doesn't fire for exceptions raised *before* the API call (e.g. bad config). | Medium | Handle pre-call errors directly in `AgentRouter._dispatch_via_litellm`; the logger only covers the post-call path. | +| LiteLLM bumps the `kwargs` schema in callbacks. | Low | Pin LiteLLM in `pyproject.toml`; add a smoke test that imports and triggers the callback against the pinned version. | +| ADK CustomLLM registration leaks across tests. | Low | Already mitigated in this PR (`custom_provider_map` is filtered before append). Add a fixture that snapshots / restores `litellm.custom_provider_map` between tests. | +| LangChain / MCP / A2A AgentTypes need their own gap-fillers eventually. | Medium | Reserve `ProviderConfig.custom_llm_factory` from day one so adding them later is one entry plus one file. | + +--- + +## 7. Open questions + +1. 
Do we want `AgentRouter` to own a single global `litellm.callbacks` + registration, or one logger instance per `AgentRouter`? A single + global is simpler and matches LiteLLM's design. +2. Should the envelope grow new fields now that LiteLLM gives us + `response_cost` and `litellm_call_id` for free? (Recommendation: yes, + add them as optional fields without removing anything.) +3. Should we keep an `Agent` abstract class for type hints elsewhere in + the codebase, or fully delete it? (Recommendation: delete; replace + with a `Protocol` if any caller actually needs it — most don't.) +4. Do we want to expose `litellm.Router` semantics in + `AgentRouter`, or is `AgentRouter` strictly about the HackAgent + side, with `litellm.Router` configured independently? (Lean: + keep them separate; `AgentRouter` can *use* `litellm.Router` under + the hood when an agent has multiple deployments.) + +--- + +## 8. Definition of done + +- `hackagent/router/adapters/` contains at most `__init__.py` and + `base.py` (exception classes only). +- `hackagent/router/providers/adk_custom_llm.py` is the only file that + knows about the ADK protocol. +- `router.py` calls `litellm.completion` directly. +- A `HackAgentTrackingLogger` subclass of `CustomLogger` is responsible + for emitting `StepTracker` events. +- The router-level test confirms the envelope dict matches the + pre-refactor shape for at least: prompt-only request, messages-only + request, error path, and ADK request. +- All existing example scripts under `hackagent/examples/` keep working + without code edits. diff --git a/hackagent/router/adapters/google_adk.py b/hackagent/router/adapters/google_adk.py index 7ba16dc3..66d881cd 100644 --- a/hackagent/router/adapters/google_adk.py +++ b/hackagent/router/adapters/google_adk.py @@ -1,671 +1,444 @@ # Copyright 2026 - AI4I. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +""" +Google ADK (Agent Development Kit) adapter built on top of LiteLLM. + +LiteLLM has no built-in provider for the ADK server protocol (POST /run +with sessions and events), so issue #379 routes ADK through LiteLLM by +registering a per-instance :class:`litellm.CustomLLM` handler under a +unique provider name. The HTTP transport against the deployed ADK server +lives in the lazily-defined ``_ADKCustomLLM`` class, while +:class:`ADKAgent` itself is a thin :class:`LiteLLMAgent` subclass that +registers the handler and asks LiteLLM to route through it. +""" import json +import uuid from hackagent.logger import get_logger -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional import requests -from requests.structures import CaseInsensitiveDict -from hackagent.router.adapters.base import ( - Agent, +from .base import ( AdapterConfigurationError, AdapterInteractionError, AdapterResponseParsingError, ) +from .litellm import LiteLLMAgent, _get_litellm + -# Global logger for this module, can be used by utility functions too logger = get_logger(__name__) -# --- Custom Exceptions (subclass from base) --- +# --- Custom exceptions (kept for backwards compatibility) --- class AgentConfigurationError(AdapterConfigurationError): - """Custom exception for agent configuration issues.""" + """ADK adapter configuration issues.""" pass class AgentInteractionError(AdapterInteractionError): - """Custom exception for errors during interaction with the agent API.""" + """Errors interacting with the ADK agent server.""" pass class ResponseParsingError(AdapterResponseParsingError): - """Custom exception for errors parsing the agent's response.""" + """Errors parsing the ADK server's event-list response.""" pass -class ADKAgent(Agent): - """ - Adapter for interacting with ADK (Agent Development Kit) based agents. - - This class implements the common `Agent` interface. 
It translates requests - and responses between the router's standard format and the specific format - required by ADK agents. It encapsulates all logic for ADK communication, - including session management (optional), request formatting, execution, - response parsing, and error handling. - - Attributes: - name (str): The name of the ADK application (used for router registration AND as ADK app identifier). - endpoint (str): The base API endpoint for the ADK agent. - user_id (str): The user identifier for ADK sessions. - timeout (int): Timeout in seconds for requests to the ADK agent. - logger (logging.Logger): Logger instance for this adapter. - """ - - ADAPTER_TYPE = "ADKAgent" - - def __init__(self, id: str, config: Dict[str, Any]): - """ - Initializes the ADKAgent. - - Args: - id: The unique identifier for this ADK agent instance. - config: Configuration dictionary for the ADK agent. - Expected keys include: - - 'name': Name of the ADK application (e.g., 'multi_tool_agent'). - - 'endpoint': Base URL of the ADK agent. - - 'user_id': User ID for the ADK session. - - 'timeout' (optional): Request timeout in seconds - (defaults to 120). - - Raises: - AgentConfigurationError: If any required configuration key (name, endpoint, user_id) is missing. - """ - super().__init__(id, config) - - # Validate required configuration keys using base class helper - self.name: str = self._require_config_key("name", AgentConfigurationError) - endpoint_raw: str = self._require_config_key( - "endpoint", AgentConfigurationError - ) - self.user_id: str = self._require_config_key("user_id", AgentConfigurationError) +_ADK_PROVIDER_PREFIX = "hackagent_adk" + + +def _last_user_text(messages: List[Dict[str, Any]]) -> Optional[str]: + """Return the text of the last user message in ``messages``.""" + for msg in reversed(messages or []): + if (msg or {}).get("role") != "user": + continue + content = msg.get("content") + if isinstance(content, str): + return content + # OpenAI-style content lists. 
+ if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text = part.get("text") + if isinstance(text, str): + return text + return None + + +def _extract_final_text(events: List[Dict[str, Any]]) -> Optional[str]: + """Walk ``events`` newest-first and return the agent's final reply.""" + for event in reversed(events): + actions = event.get("actions") + if actions and isinstance(actions, dict) and actions.get("escalate"): + error_msg = event.get( + "error_message", + "No specific message provided by agent for escalation.", + ) + return f"Agent escalated: {error_msg}" - self.endpoint: str = endpoint_raw.strip("/") - self.timeout: int = self._get_config_key("timeout", 120) + content = event.get("content") + if not isinstance(content, dict): + continue + parts = content.get("parts") + if not isinstance(parts, list) or not parts: + continue + first = parts[0] + if not isinstance(first, dict): + continue + text = first.get("text") + if isinstance(text, str) and text.strip(): + return text + return None - # Option to use a fresh session for each request (useful for attack scenarios - # where session state pollution can cause issues) - self.fresh_session_per_request: bool = self._get_config_key( - "fresh_session_per_request", True - ) - # Generate a unique session ID for this adapter instance - # This keeps session state persistent across multiple requests to the same agent - import uuid +_ADK_CUSTOM_LLM_CLASS = None - self.session_id: str = self._get_config_key("session_id", str(uuid.uuid4())) - self.logger.info( - f"ADKAgent initialized with session_id: {self.session_id}, " - f"fresh_session_per_request: {self.fresh_session_per_request}" - ) +def _get_adk_custom_llm_class(): + """Lazily build the CustomLLM subclass once litellm is importable. 
- def _initialize_session( - self, session_id_to_init: str, initial_state: Optional[dict] = None - ) -> bool: - """ - (Optional) Creates or ensures a specific ADK session exists. - - Args: - session_id_to_init: The specific session ID to initialize. - initial_state: An optional dictionary to provide initial state when - creating the ADK session. - Returns: - True if the session was created successfully or already existed. - Raises: - AgentInteractionError: If there's an issue. - """ - self.logger.info(f"Explicitly initializing ADK session: {session_id_to_init}.") - try: - return self._create_session_internal( - session_id=session_id_to_init, initial_state=initial_state - ) - except AgentInteractionError as e: - self.logger.error( - f"Failed to initialize ADK session {session_id_to_init}: {e}" - ) - raise - - def _create_session_internal( - self, session_id: str, initial_state: Optional[dict] = None - ) -> bool: - """ - Internal helper to create a session on the ADK server. - - Sends a POST request to the ADK session creation endpoint. - - Args: - session_id: The specific session ID to create. - initial_state: An optional dictionary to provide initial state for the session. - - Returns: - True if the session was successfully created or if it already existed (HTTP 409, or HTTP 400 with specific message). - - Raises: - AgentInteractionError: If the HTTP request fails or the server returns - an unexpected error status. 
- """ - target_url = f"{self.endpoint}/apps/{self.name}/users/{self.user_id}/sessions/{session_id}" - headers = {"Content-Type": "application/json", "Accept": "application/json"} - payload = initial_state or {} - self.logger.info(f"Attempting to create ADK session: {target_url}") - try: - response = requests.post( - target_url, headers=headers, json=payload, timeout=30 + Defined as a function instead of a module-level class so this module + keeps loading even when litellm is missing — the LiteLLMAgent base + will raise a clear error before anyone tries to actually use ADK. + """ + global _ADK_CUSTOM_LLM_CLASS + if _ADK_CUSTOM_LLM_CLASS is not None: + return _ADK_CUSTOM_LLM_CLASS + + from litellm import CustomLLM + from litellm.types.utils import ModelResponse + + class _ADKCustomLLM(CustomLLM): + """LiteLLM CustomLLM handler that proxies to an ADK server.""" + + def __init__( + self, + *, + endpoint: str, + app_name: str, + user_id: str, + default_session_id: str, + fresh_session_per_request: bool, + timeout: int, + log, + ): + super().__init__() + self.endpoint = endpoint.rstrip("/") + self.app_name = app_name + self.user_id = user_id + self.default_session_id = default_session_id + self.fresh_session_per_request = fresh_session_per_request + self.timeout = timeout + self.logger = log + + # ---- ADK transport (kept close to the previous implementation) --- + + def _create_session( + self, session_id: str, initial_state: Optional[dict] = None + ) -> None: + url = ( + f"{self.endpoint}/apps/{self.app_name}/users/" + f"{self.user_id}/sessions/{session_id}" ) - response.raise_for_status() - self.logger.info(f"Successfully created ADK session {session_id}") - return True - except requests.exceptions.HTTPError as http_err: - if http_err.response is not None: - status_code = http_err.response.status_code - response_text_lower = "" - original_response_text = "[Could not read body]" - try: - original_response_text = http_err.response.text - response_text_lower = 
original_response_text.lower() - except Exception as e_text: - self.logger.warning( - f"Could not get text from error response (status {status_code}) for session {session_id}: {e_text}" - ) - - # Condition 1: HTTP 409 Conflict (standard "already exists") + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + payload = initial_state or {} + try: + response = requests.post(url, headers=headers, json=payload, timeout=30) + response.raise_for_status() + return + except requests.exceptions.HTTPError as http_err: + response_text = "" + status_code = None + if http_err.response is not None: + status_code = http_err.response.status_code + try: + response_text = http_err.response.text or "" + except Exception: + response_text = "" if status_code == 409: - self.logger.warning( - f"ADK session {session_id} already exists (HTTP 409). Proceeding." - ) - return True - - # Condition 2: HTTP 400 Bad Request + specific message (ADK server's current behavior) + return if ( status_code == 400 - and "session already exists" in response_text_lower + and "session already exists" in response_text.lower() ): - self.logger.warning( - f"ADK session {session_id} already exists (HTTP 400 with 'session already exists' in body). " - f"Proceeding. 
Body: {original_response_text}" - ) - return True - - # If neither of the above conditions met, then it's a genuine error - err_msg_detail_base = f"HTTP error creating ADK session {session_id}" - err_msg_detail_extended = "" - current_status_for_exc = "Unknown" - - if http_err.response is not None: - try: - current_status_for_exc = http_err.response.status_code - # Ensure response_text is defined for logging if it wasn't fetched successfully above - body_for_log = ( - original_response_text - if "original_response_text" in locals() - else "[Could not read body during logging]" - ) - err_msg_detail_extended = f": {http_err} - Status {current_status_for_exc} - Body: {body_for_log}" - except Exception as e_resp_attrs: - self.logger.warning( - f"Could not get all attributes from error response for session {session_id}: {e_resp_attrs}" - ) - err_msg_detail_extended = ( - f": {http_err} (Error response attributes inaccessible)" - ) - else: # http_err.response is None - err_msg_detail_extended = f": {http_err}" - - self.logger.error(f"{err_msg_detail_base}{err_msg_detail_extended}") - raise AgentInteractionError( - f"HTTP Error {current_status_for_exc} creating session {session_id}" - ) from http_err - except requests.exceptions.RequestException as e: - self.logger.error( - f"Request exception creating ADK session {session_id}: {e}" - ) - raise AgentInteractionError( - f"Request failed creating session {session_id}: {e}" - ) from e - - def _prepare_request_payload( - self, prompt_text: str, session_id: str - ) -> Tuple[dict, dict]: - """ - Prepares the HTTP headers and JSON payload for an ADK agent request. - - Args: - prompt_text: The user's prompt text to be sent to the agent. - session_id: The session identifier for this specific ADK interaction. - - Returns: - A tuple containing two dictionaries: the headers and the payload. 
- """ - payload = { - "app_name": self.name, - "user_id": self.user_id, - "session_id": session_id, - "new_message": {"role": "user", "parts": [{"text": prompt_text}]}, - } - headers = {"Content-Type": "application/json", "Accept": "application/json"} - return headers, payload - - def _execute_http_post( - self, url: str, headers: dict, payload: dict - ) -> requests.Response: - """ - Executes an HTTP POST request. - - Args: - url: The URL to send the POST request to. - headers: A dictionary of HTTP headers. - payload: A dictionary to be sent as the JSON payload. - - Returns: - A `requests.Response` object. - - Raises: - AgentInteractionError: If the request times out or another request-related - exception occurs. - """ - try: - # Log agent interaction for TUI visibility - self.logger.info(f"🌐 Sending request to agent endpoint: {url}") - if "message" in payload: - msg_preview = str(payload["message"])[:100] - self.logger.debug(f" Message preview: {msg_preview}...") - - response = requests.post( - url, headers=headers, json=payload, timeout=self.timeout - ) - - # Log response status - self.logger.info(f"✅ Agent responded with status {response.status_code}") - self.logger.debug( - f"Request to {url} completed with status {response.status_code}" - ) - return response - except requests.exceptions.Timeout as e: - self.logger.warning(f"Request timed out accessing {url}: {e}") - raise AgentInteractionError(f"Request timed out: {e}") from e - except requests.exceptions.RequestException as e: - self.logger.error(f"Request exception accessing {url}: {e}") - raise AgentInteractionError(f"Request failed: {e}") from e - - def _parse_response_json( - self, response: requests.Response - ) -> Tuple[Optional[str], Optional[list], str, Optional[CaseInsensitiveDict]]: - """ - Parses the JSON response from an ADK agent. - - It checks for HTTP errors first. Then, it attempts to parse the JSON body, - expecting a list of events. 
It iterates through these events (in reverse) - to find the agent's final text response or an escalation message. - - Args: - response: The `requests.Response` object from the ADK agent. - - Returns: - A tuple containing: - - final_response_text (Optional[str]): The extracted text response. - - events (Optional[list]): The full list of ADK events. - - response_body_str (str): The raw response body as a string. - - http_headers (Optional[CaseInsensitiveDict]): The response headers. - - Raises: - AgentInteractionError: If an HTTP error status (4xx or 5xx) is encountered. - ResponseParsingError: If the response body is not valid JSON, not in the - expected list format, or if a non-event detail message - is returned instead of events. - """ - response_body_str = response.text - http_headers = response.headers - self.logger.debug( - f"ADK Response Body for parsing: {response_body_str[:1000]}" - ) # Log more of the body + return + raise AgentInteractionError( + f"HTTP Error {status_code} creating session " + f"{session_id}: {response_text[:200]}" + ) from http_err + except requests.exceptions.RequestException as e: + raise AgentInteractionError( + f"Request failed creating session {session_id}: {e}" + ) from e + + def _run(self, prompt_text: str, session_id: str) -> Dict[str, Any]: + url = f"{self.endpoint}/run" + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + payload = { + "app_name": self.app_name, + "user_id": self.user_id, + "session_id": session_id, + "new_message": { + "role": "user", + "parts": [{"text": prompt_text}], + }, + } - try: - response.raise_for_status() - except requests.exceptions.HTTPError as http_err: - # Use response.status_code directly since we have the response object - status = response.status_code - self.logger.error( - f"HTTP error {status} from {response.url}: {response_body_str}" - ) - raise AgentInteractionError(f"HTTP Error: {status}") from http_err + try: + response = requests.post( + url, 
headers=headers, json=payload, timeout=self.timeout + ) + except requests.exceptions.Timeout as e: + raise AgentInteractionError(f"Request timed out: {e}") from e + except requests.exceptions.RequestException as e: + raise AgentInteractionError(f"Request failed: {e}") from e + + response_body = response.text + try: + response.raise_for_status() + except requests.exceptions.HTTPError as http_err: + raise AgentInteractionError( + f"HTTP Error: {response.status_code}" + ) from http_err + + try: + events = response.json() + except (json.JSONDecodeError, ValueError) as parse_err: + raise ResponseParsingError( + f"JSON parse failed: {parse_err}. Body: {response_body[:200]}" + ) from parse_err - final_response_text = None - events = None - try: - events = response.json() if not isinstance(events, list): - self.logger.warning( - f"ADK response was not a JSON list. Type: {type(events)}. Body: {response_body_str[:500]}" - ) if isinstance(events, dict) and "detail" in events: - detail_message = events["detail"] - self.logger.warning( - f"ADK returned non-event detail message: {detail_message}" - ) raise ResponseParsingError( - f"ADK returned detail message: {detail_message}" + f"ADK returned detail message: {events['detail']}" ) - self.logger.warning( - f"ADK response not a JSON list or recognized detail. Body: {response_body_str[:500]}" - ) raise ResponseParsingError( "ADK response format unrecognized (not a list)." 
) - self.logger.debug(f"Received {len(events)} events from ADK for parsing.") + return { + "events": events, + "raw_request": payload, + "raw_response_body": response_body, + "raw_response_headers": dict(response.headers), + "status_code": response.status_code, + "final_text": _extract_final_text(events), + } + + # ---- LiteLLM CustomLLM API --------------------------------------- + + def completion(self, *args, **kwargs): + """Translate a LiteLLM completion call into an ADK /run request.""" + messages = kwargs.get("messages") or [] + optional_params = kwargs.get("optional_params") or {} + model_response: ModelResponse = ( + kwargs.get("model_response") or ModelResponse() + ) - for i, event in enumerate(reversed(events)): - self.logger.debug( - f"Parsing event {len(events) - 1 - i} (reversed index {i}): {str(event)[:200]}..." + prompt_text = _last_user_text(messages) + if not prompt_text: + raise AgentInteractionError( + "ADK adapter requires at least one user message with text content." ) - actions = event.get("actions") - if actions and isinstance(actions, dict) and actions.get("escalate"): - error_msg = event.get( - "error_message", - "No specific message provided by agent for escalation.", - ) - final_response_text = f"Agent escalated: {error_msg}" - self.logger.debug( - f"Escalation event found as final response: {final_response_text}" - ) - break # Found escalation, stop parsing - content = event.get("content") - if not content or not isinstance(content, dict): - self.logger.debug( - f"Event {len(events) - 1 - i} has no content or content is not a dict. Skipping for text." 
- ) - continue + session_id = optional_params.get("session_id") + if not session_id: + session_id = ( + str(uuid.uuid4()) + if self.fresh_session_per_request + else self.default_session_id + ) + initial_state = optional_params.get("initial_session_state") - parts = content.get("parts") - if not parts or not isinstance(parts, list) or len(parts) == 0: - self.logger.debug( - f"Event {len(events) - 1 - i} content has no parts, parts is not a list, or parts is empty. Skipping for text." - ) - continue + self.logger.info( + f"🌐 ADK run for app '{self.app_name}' (session {session_id})" + ) + self._create_session(session_id=session_id, initial_state=initial_state) + result = self._run(prompt_text=prompt_text, session_id=session_id) + + final_text = result["final_text"] or "" + model_response.choices[0].message.content = final_text # type: ignore[attr-defined] + try: + model_response.choices[0].finish_reason = "stop" # type: ignore[attr-defined] + except Exception: + pass + model_response.model = ( + kwargs.get("model") or f"{_ADK_PROVIDER_PREFIX}/{self.app_name}" + ) - # Check the first part for text - first_part = parts[0] - if not isinstance(first_part, dict): - self.logger.debug( - f"Event {len(events) - 1 - i} first part is not a dict: {type(first_part)}. Skipping for text." - ) - continue + # Stash ADK-specific bits where the outer adapter can find them. 
+ try: + model_response.choices[0].message.provider_specific_fields = { # type: ignore[attr-defined] + "adk_events_list": result["events"], + "adk_session_id": session_id, + "adk_raw_response_body": result["raw_response_body"], + "adk_raw_request": result["raw_request"], + "adk_status_code": result["status_code"], + } + except Exception: + pass + return model_response + + async def acompletion(self, *args, **kwargs): + """Async wrapper — run the sync ADK transport in a worker thread.""" + import asyncio + + return await asyncio.get_event_loop().run_in_executor( + None, lambda: self.completion(*args, **kwargs) + ) - part_text = first_part.get("text") - if ( - part_text is None - ): # Explicitly check for None, as empty string is fine if stripped later - self.logger.debug( - f"Event {len(events) - 1 - i} first part has no 'text' key. Skipping for text." - ) - continue + _ADK_CUSTOM_LLM_CLASS = _ADKCustomLLM + return _ADKCustomLLM - if not isinstance(part_text, str): - self.logger.debug( - f"Event {len(events) - 1 - i} first part 'text' is not a string: {type(part_text)}. Skipping for text." - ) - continue - - # At this point, part_text is a string (could be empty) - # The original code also checks `part_text.strip()` to ensure it's not just whitespace. - # Let's keep that check. - if part_text.strip(): - final_response_text = part_text # Store the original text, stripping is for check only - self.logger.debug( - f"Found text in event {len(events) - 1 - i}, part 0, as final response: '{final_response_text[:100]}...'" - ) - break # Found usable text, stop parsing - else: - self.logger.debug( - f"Event {len(events) - 1 - i} first part text is empty or whitespace after strip. Skipping for text." - ) - if final_response_text is None: - self.logger.warning( - f"No final response text could be extracted from any of the {len(events)} ADK events from {response.url}." 
- ) - return final_response_text, events, response_body_str, http_headers - except ( - json.JSONDecodeError, - ValueError, - ) as parse_err: # Catch ValueError too for broader JSON issues - self.logger.warning( - f"Failed to parse ADK JSON from {response.url}: {parse_err}. Body: {response_body_str[:500]}" - ) - raise ResponseParsingError(f"JSON parse failed: {parse_err}") from parse_err - - def _process_agent_interaction(self, prompt_text: str, session_id: str) -> dict: - """ - Manages a single interaction (turn) with the ADK agent for a given prompt. - - This involves preparing the payload, executing the HTTP POST request to the - correct ADK :runTurn endpoint, and parsing the response. - - Args: - prompt_text: The prompt text to send to the ADK agent. - session_id: The ADK session ID for this interaction. - - Returns: - A dictionary containing detailed results of the interaction, including: - - 'generated_text': The agent's final response text. - - 'adapter_specific_events': Full list of ADK events. - - 'raw_request': The payload sent to the agent. - - 'raw_response_status': HTTP status code of the agent's response. - - 'raw_response_headers': HTTP headers from the agent's response. - - 'raw_response_body': Raw body of the agent's response. - - 'error_message': Any error message if an issue occurred. - """ - interaction_result: Dict[str, Any] = { - "generated_text": None, - "adapter_specific_events": None, - "raw_request": None, - "raw_response_status": None, - "raw_response_headers": None, - "raw_response_body": None, - "error_message": None, - } +class ADKAgent(LiteLLMAgent): + """ + Adapter for a deployed Google ADK agent server. + + The request travels through LiteLLM via a per-instance + :class:`CustomLLM` handler registered as + ``hackagent_adk_/``. From the router's perspective this + is just another LiteLLM agent. + + Required config: + - ``name``: ADK app name (used as both the model string and the + ``app_name`` in the request payload). 
+ - ``endpoint``: ADK server base URL. + - ``user_id``: User ID for ADK sessions. + + Optional config: + - ``timeout`` (seconds, default 120). + - ``session_id``: sticky session ID; if unset a UUID is generated. + - ``fresh_session_per_request`` (default True): if True, every + request gets a brand-new session unless the caller supplies one. + """ - try: - headers, payload = self._prepare_request_payload(prompt_text, session_id) - interaction_result["raw_request"] = payload + ADAPTER_TYPE = "ADKAgent" - # Reverting to the simple /run endpoint as per general ADK docs - run_turn_url = f"{self.endpoint}/run" - self.logger.debug( - f"Sending ADK request to: {run_turn_url} with payload app_name: {payload.get('app_name')}" - ) + def __init__(self, id: str, config: Dict[str, Any]): + for key in ("name", "endpoint", "user_id"): + if key not in config: + raise AgentConfigurationError( + f"Missing required configuration key '{key}' for ADKAgent: {id}" + ) - response = self._execute_http_post(run_turn_url, headers, payload) - interaction_result["raw_response_status"] = response.status_code - # interaction_result["raw_response_headers"] is set in _parse_response_json - # interaction_result["raw_response_body"] is set in _parse_response_json - - ( - final_text, - events, - response_body_str, - resp_headers, - ) = self._parse_response_json(response) - - interaction_result["generated_text"] = final_text - interaction_result["adapter_specific_events"] = events - interaction_result["raw_response_body"] = response_body_str - interaction_result["raw_response_headers"] = ( - dict(resp_headers) if resp_headers else None - ) + # Provider name is per-instance so each ADKAgent gets its own handler. + # Set on self before super().__init__ runs so that the base's call to + # _resolve_litellm_model (overridden below) sees the right value. 
+ self._provider_name = f"{_ADK_PROVIDER_PREFIX}_{id}" - except AgentInteractionError as aie: - self.logger.error(f"AgentInteractionError processing prompt: {aie}") - interaction_result["error_message"] = f"ADK Error: {aie}" - except ResponseParsingError as rpe: - self.logger.error(f"ResponseParsingError processing prompt: {rpe}") - interaction_result["error_message"] = f"ADK Response Parse Error: {rpe}" - except Exception as e: - self.logger.exception(f"Unexpected error during ADK agent interaction: {e}") - interaction_result["error_message"] = f"Unexpected ADK Adapter Error: {e}" + adk_endpoint = str(config["endpoint"]).strip("/") + adk_user_id = config["user_id"] + adk_app_name = config["name"] + adk_timeout = int(config.get("timeout", 120)) + fresh = bool(config.get("fresh_session_per_request", True)) + session_id = config.get("session_id") or str(uuid.uuid4()) - return interaction_result + # The base passes ``endpoint`` along to LiteLLM as ``api_base``; we + # don't want that since our custom provider hits ADK directly. + base_config = {k: v for k, v in config.items() if k != "endpoint"} + super().__init__(id, base_config) - def _build_error_response( - self, - error_message: str, - status_code: Optional[int], - raw_request: Optional[Dict[str, Any]] = None, - interaction_details: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - """ - Constructs a standardized error response dictionary for the adapter. - - Args: - error_message: The primary error message string. - status_code: The HTTP status code associated with the error, if applicable. - raw_request: The original request data that led to the error. - interaction_details: A dictionary containing details from - `_process_agent_interaction` if the error occurred - during ADK processing. - - Returns: - A dictionary representing a standardized error response. 
- """ - raw_response_headers = None - raw_response_body = None - actual_status_code = status_code - adk_events = None - - if interaction_details: - raw_response_headers = interaction_details.get("response_headers") - raw_response_body = interaction_details.get("response_body_raw") - adk_events = interaction_details.get("adk_events_list") - if interaction_details.get("response_status_code") is not None: - actual_status_code = interaction_details.get("response_status_code") - if raw_request is None: - raw_request = interaction_details.get("request_payload") - - # Use base class method with ADK-specific data - return super()._build_error_response( - error_message=error_message, - status_code=actual_status_code, - raw_request=raw_request, - raw_response_body=raw_response_body, - raw_response_headers=raw_response_headers, - agent_specific_data={"adk_events_list": adk_events}, - ) + self.endpoint = adk_endpoint + self.user_id = adk_user_id + self.name = adk_app_name + self.timeout = adk_timeout + self.fresh_session_per_request = fresh + self.session_id = session_id - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """ - Handles an incoming request by creating an ADK session (if not existing) - and then processing the request through the ADK agent. - - Args: - request_data: A dictionary containing the request data. Must include - a 'prompt' key with the text to send to the agent. - Optional keys: - - 'session_id': Override the adapter's default session_id (advanced usage) - - 'initial_session_state': Initial state dict for new sessions - - 'adk_session_id': Deprecated, use 'session_id' instead - - 'adk_user_id': Deprecated, adapter manages user_id - - Returns: - A dictionary representing the agent's response or an error. 
- """ - prompt_text = request_data.get("prompt") - - # Support both new 'session_id' and legacy 'adk_session_id' for backward compatibility - session_id_from_request = request_data.get( - "session_id", request_data.get("adk_session_id") - ) + self._register_custom_provider() - # Use adapter's instance session_id if not provided in request - # If fresh_session_per_request is enabled, generate a new UUID for each request - import uuid + self.logger.info( + f"ADKAgent '{self.id}' registered as LiteLLM provider " + f"'{self._provider_name}' targeting {self.endpoint} " + f"(app={self.name}, session={self.session_id}, " + f"fresh_session_per_request={self.fresh_session_per_request})" + ) - if session_id_from_request: - session_id_to_use = session_id_from_request - elif self.fresh_session_per_request: - session_id_to_use = str(uuid.uuid4()) - self.logger.debug( - f"Using fresh session ID for request: {session_id_to_use}" + def _register_custom_provider(self) -> None: + litellm, available = _get_litellm() + if not available: + raise AgentConfigurationError( + "litellm is required for ADKAgent but is not installed." ) - else: - session_id_to_use = self.session_id - - initial_session_state = request_data.get("initial_session_state") # Optional - if not prompt_text: - self.logger.warning("No 'prompt' found in request_data.") - return self._build_error_response( - error_message="Request data must include a 'prompt' field.", - status_code=400, - raw_request=request_data, - ) + handler_cls = _get_adk_custom_llm_class() + handler = handler_cls( + endpoint=self.endpoint, + app_name=self.name, + user_id=self.user_id, + default_session_id=self.session_id, + fresh_session_per_request=self.fresh_session_per_request, + timeout=self.timeout, + log=self.logger, + ) - self.logger.info( - f"Handling request for agent {self.id} with prompt: '{prompt_text[:75]}...' (Session: {session_id_to_use})" + provider = self._provider_name + # Replace any stale entry for this provider name (e.g. 
when an + # ADKAgent with the same id is re-created during tests). + litellm.custom_provider_map = [ + entry + for entry in litellm.custom_provider_map + if entry.get("provider") != provider + ] + litellm.custom_provider_map.append( + {"provider": provider, "custom_handler": handler} ) + if provider not in litellm._custom_providers: + litellm._custom_providers.append(provider) - try: - # Step 1: Ensure ADK session exists - self.logger.info( - f"Ensuring ADK session '{session_id_to_use}' exists before running turn." - ) - self._create_session_internal( - session_id=session_id_to_use, initial_state=initial_session_state - ) - # If _create_session_internal raises, it will be caught by the outer try-except - self.logger.info(f"Session '{session_id_to_use}' confirmed/created.") + self._custom_handler = handler - # Step 2: Process the agent interaction (send to /run) - interaction_details = self._process_agent_interaction( - prompt_text, session_id=session_id_to_use - ) + def _resolve_litellm_model(self, raw_model: str) -> str: + return f"{self._provider_name}/{raw_model}" - if interaction_details.get("error_message"): - self.logger.warning( - f"ADK interaction for agent {self.id} (session {session_id_to_use}) processed with error: " - f"{interaction_details['error_message']}" - ) - # Pass full interaction_details to enrich the error response - return self._build_error_response( - error_message=interaction_details["error_message"], - status_code=interaction_details.get("raw_response_status"), - interaction_details=interaction_details, - ) + # ---- forward ADK-specific request fields ---------------------------- - # Success case - return self._build_success_response( - processed_response=interaction_details.get("generated_text"), - raw_request=interaction_details.get("raw_request"), - raw_response_body=interaction_details.get("raw_response_body"), - raw_response_headers=interaction_details.get("raw_response_headers"), - agent_specific_data={ - "adk_events_list": 
interaction_details.get( - "adapter_specific_events" - ) - }, - status_code=interaction_details.get("raw_response_status") or 200, - ) - except AgentInteractionError as aie_session: # Specific catch for session errors from _create_session_internal - self.logger.error( - f"Failed to ensure ADK session '{session_id_to_use}': {aie_session}" - ) - return self._build_error_response( - error_message=f"Failed to create/verify ADK session '{session_id_to_use}': {aie_session}", - status_code=500, # Or a more specific code if available from aie_session - raw_request=request_data, - ) - except Exception as e: - self.logger.exception( - f"Unexpected error in handle_request for agent {self.id} (session {session_id_to_use}): {e}" - ) - return self._build_error_response( - error_message=f"Unexpected adapter error: {type(e).__name__} - {str(e)}", - status_code=500, - raw_request=request_data, + def _get_completion_parameters( + self, request_data: Dict[str, Any] + ) -> Dict[str, Any]: + params = super()._get_completion_parameters(request_data) + session_id = request_data.get("session_id", request_data.get("adk_session_id")) + if session_id: + params["session_id"] = session_id + if "initial_session_state" in request_data: + params["initial_session_state"] = request_data["initial_session_state"] + return params + + def _get_excluded_request_keys(self) -> set: + base = super()._get_excluded_request_keys() + return base | {"session_id", "adk_session_id", "initial_session_state"} + + def _build_agent_specific_data( + self, + completion_result: Dict[str, Any], + parameters: Dict[str, Any], + ) -> Dict[str, Any]: + data = super()._build_agent_specific_data(completion_result, parameters) + raw = completion_result.get("raw_response") + adk_fields: Dict[str, Any] = {} + try: + adk_fields = ( + getattr(raw.choices[0].message, "provider_specific_fields", None) or {} ) + except (AttributeError, IndexError, TypeError): + adk_fields = {} + events = adk_fields.get("adk_events_list") + if events 
is not None: + data["adk_events_list"] = events + if "adk_session_id" in adk_fields: + data["adk_session_id"] = adk_fields["adk_session_id"] + return data diff --git a/hackagent/router/adapters/litellm.py b/hackagent/router/adapters/litellm.py index 3621fde4..947eb518 100644 --- a/hackagent/router/adapters/litellm.py +++ b/hackagent/router/adapters/litellm.py @@ -87,41 +87,77 @@ class LiteLLMConfigurationError(AdapterConfigurationError): logger = get_logger(__name__) # Module-level logger -class LiteLLMAgent(ChatCompletionsAgent): - """ - Adapter for interacting with LLMs via the LiteLLM library. +# Provider prefixes that LiteLLM recognises natively. When a model string +# already starts with one of these, we leave it alone instead of prepending +# our own provider prefix. +_KNOWN_LITELLM_PROVIDER_PREFIXES = ( + "openai/", + "anthropic/", + "azure/", + "bedrock/", + "vertex_ai/", + "huggingface/", + "replicate/", + "together_ai/", + "anyscale/", + "ollama/", + "ollama_chat/", + "groq/", + "mistral/", + "cohere/", + "gemini/", + "deepseek/", +) - This adapter supports multiple LLM providers through LiteLLM's unified interface. - For custom/self-hosted endpoints, the endpoint URL must be provided correctly: - OpenAI-Compatible Endpoints: - - Provide the base URL ending with /v1 (e.g., "http://localhost:8000/v1") - - The OpenAI client will automatically append /chat/completions - - Example: endpoint="http://localhost:8000/v1" → requests to http://localhost:8000/v1/chat/completions - - Non-OpenAI Protocols: - - Use the appropriate agent type (LANGCHAIN, MCP, A2A) instead of routing through LiteLLM - - LANGCHAIN: Use LangServe endpoints (e.g., "http://localhost:8000/invoke") - - MCP: Use Model Context Protocol adapter (not LiteLLM) - - A2A: Use Agent-to-Agent protocol adapter (not LiteLLM) +class LiteLLMAgent(ChatCompletionsAgent): + """ + Unified adapter that routes every chat-completion request through LiteLLM. 
+ + All chat-style adapters (OpenAI-SDK, Ollama, LangChain, plain LiteLLM) + subclass this class. Each subclass sets ``PROVIDER_PREFIX`` to declare + which LiteLLM provider the AgentType maps to (e.g. ``"openai"`` for an + OpenAI-compatible endpoint, ``"ollama_chat"`` for Ollama). The base class + handles model-string normalisation, endpoint plumbing, generation + parameters, tool calls, and the unified ``thinking`` knob. + + Thinking knob: + Any subclass can be asked to enable or disable provider reasoning by + setting ``thinking`` in the adapter config or per-request payload. + Accepted values: + - bool: ``True`` enables thinking with provider defaults, + ``False`` disables it explicitly. + - dict: passed through verbatim (e.g. + ``{"type": "enabled", "budget_tokens": 1024}`` for Anthropic). + - str: a reasoning effort level (``"low"``, ``"medium"``, + ``"high"``) — translated by the subclass as appropriate. + - int: budget tokens for providers that accept a budget. + Subclasses override ``_apply_thinking`` to translate the value into + the provider-specific request fields. """ ADAPTER_TYPE = "LiteLLMAgent" + # When set, the model string passed to LiteLLM is prefixed with + # ``"{PROVIDER_PREFIX}/"`` unless it already starts with a known + # LiteLLM provider prefix. ``None`` means "let LiteLLM auto-detect". + PROVIDER_PREFIX: Optional[str] = None def __init__(self, id: str, config: Dict[str, Any]): """ - Initializes the LiteLLMAgent. + Initialise the adapter from configuration. Args: - id: The unique identifier for this LiteLLM agent instance. - config: Configuration dictionary for the LiteLLM agent. - Expected keys: - - 'name': Model string for LiteLLM (e.g., "ollama/llama3"). - - 'endpoint' (optional): Base URL for the API. - - 'api_key' (optional): Name of the environment variable holding the API key. - - 'max_tokens' (optional): Default max tokens for generation (defaults to 100). - - 'temperature' (optional): Default temperature (defaults to 0.8). 
- - 'top_p' (optional): Default top_p (defaults to 0.95). + id: Unique identifier for this adapter instance. + config: Configuration dict. Supported keys: + - ``name``: model string (e.g. ``"llama3"`` or + ``"gpt-4"``). Required. + - ``endpoint`` (optional): API base URL. + - ``api_key`` (optional): API key or environment variable name. + - ``max_tokens`` / ``temperature`` / ``top_p`` (optional). + - ``tools`` / ``tool_choice`` (optional): function-calling + definitions, passed through to LiteLLM. + - ``thinking`` (optional): see class docstring. + - ``extra_body`` (optional): provider-specific request body. """ super().__init__(id, config) @@ -129,40 +165,94 @@ def __init__(self, id: str, config: Dict[str, Any]): self.model_name = self._require_config_key("name", LiteLLMConfigurationError) self.api_base_url: Optional[str] = self._get_config_key("endpoint") - # Handle API key configuration using base class helper - self.actual_api_key: Optional[str] = None - - # Determine appropriate fallback env var based on model name - env_var_fallback = None - if not self.api_base_url: - # No custom endpoint - try standard env vars for public APIs - if self.model_name.startswith("openai/") or self.model_name.startswith( - "gpt-" - ): - env_var_fallback = "OPENAI_API_KEY" - elif self.model_name.startswith("anthropic/") or self.model_name.startswith( - "claude-" - ): - env_var_fallback = "ANTHROPIC_API_KEY" + # Determine the effective LiteLLM model string (with provider prefix). + self.litellm_model = self._resolve_litellm_model(self.model_name) + # Handle API key configuration using base class helper + env_var_fallback = self._default_api_key_env_var() self.actual_api_key = self._resolve_api_key( config_key="api_key", env_var_fallback=env_var_fallback ) - # When using custom endpoint without credentials, rely on endpoint-side auth. + # When using a custom endpoint without credentials, rely on + # endpoint-side auth (common for local model servers). 
if self.api_base_url and not self.actual_api_key: self.logger.debug( - f"Using custom endpoint '{self.api_base_url}' without api_key - endpoint handles its own auth" + f"Using custom endpoint '{self.api_base_url}' without api_key - " + "endpoint handles its own auth" ) self.logger.info( - f"LiteLLMAgent '{self.id}' initialized for model: '{self.model_name}'" + f"{self.ADAPTER_TYPE} '{self.id}' initialised for LiteLLM model: " + f"'{self.litellm_model}'" + (f" API Base: '{self.api_base_url}'" if self.api_base_url else "") ) - # Initialize default generation parameters using base class method + # Default generation parameters (max_tokens, temperature, top_p). self._init_generation_params() + # Pass-through fields commonly supplied via config. + self.default_tools = self._get_config_key("tools") + self.default_tool_choice = self._get_config_key("tool_choice") + self.default_extra_body = self._get_config_key("extra_body") + self.default_thinking = self._get_config_key("thinking") + + # ---- subclass extension points --------------------------------------- + + def _resolve_litellm_model(self, raw_model: str) -> str: + """Return the model string to pass to ``litellm.completion``. + + Honors the subclass ``PROVIDER_PREFIX`` while leaving names that + already carry an explicit LiteLLM provider prefix untouched. 
+ """ + if self.PROVIDER_PREFIX is None: + return raw_model + if raw_model.startswith(_KNOWN_LITELLM_PROVIDER_PREFIXES): + return raw_model + return f"{self.PROVIDER_PREFIX}/{raw_model}" + + def _default_api_key_env_var(self) -> Optional[str]: + """Return the env var used as a fallback when no API key is configured.""" + if self.api_base_url: + return None + if self.litellm_model.startswith(("openai/", "gpt-")): + return "OPENAI_API_KEY" + if self.litellm_model.startswith(("anthropic/", "claude-")): + return "ANTHROPIC_API_KEY" + return None + + def _apply_thinking(self, litellm_params: Dict[str, Any], thinking: Any) -> None: + """Translate the unified ``thinking`` value into LiteLLM params. + + The default implementation mirrors LiteLLM's own conventions: + - dict: forwarded verbatim as ``thinking=...`` + - str: forwarded as ``reasoning_effort=...`` + - int: forwarded as ``thinking={"type": "enabled", + "budget_tokens": int}`` + - True/False: forwarded as + ``thinking={"type": "enabled" | "disabled"}`` + Subclasses override this method when their provider needs different + field names (e.g. Ollama's ``think``). + """ + if thinking is None: + return + if isinstance(thinking, dict): + litellm_params["thinking"] = dict(thinking) + elif isinstance(thinking, str): + litellm_params["reasoning_effort"] = thinking + elif isinstance(thinking, bool): + litellm_params["thinking"] = {"type": "enabled" if thinking else "disabled"} + elif isinstance(thinking, int): + litellm_params["thinking"] = { + "type": "enabled", + "budget_tokens": int(thinking), + } + else: + # Best-effort passthrough for unknown shapes. 
+ litellm_params["thinking"] = thinking + + # ---- request preparation -------------------------------------------- + def _prepare_litellm_params( self, messages: List[Dict[str, str]], @@ -171,62 +261,70 @@ def _prepare_litellm_params( top_p: float, **kwargs, ) -> Dict[str, Any]: - """Prepare parameters for litellm.completion call.""" - litellm_params = { - "model": self.model_name, + """Build the kwargs dict for ``litellm.completion``.""" + litellm_params: Dict[str, Any] = { + "model": self.litellm_model, "messages": messages, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, } - # Only include api_base and api_key if they are set if self.api_base_url: litellm_params["api_base"] = self.api_base_url if self.actual_api_key: litellm_params["api_key"] = self.actual_api_key - # Handle custom endpoint scenarios (LangChain, custom agents, etc.) - if self.api_base_url: - # For custom endpoints, treat as OpenAI-compatible unless model has a known provider prefix - if not any( - self.model_name.startswith(prefix) - for prefix in [ - "openai/", - "anthropic/", - "azure/", - "bedrock/", - "vertex_ai/", - "huggingface/", - "replicate/", - "together_ai/", - "anyscale/", - "ollama/", - ] - ): - # Model name without provider prefix - treat as OpenAI-compatible custom endpoint - litellm_params["custom_llm_provider"] = "openai" - # Use the endpoint exactly as provided - user specifies the complete URL - # For OpenAI-compatible endpoints, this should be the base URL (e.g., http://host:port/v1) - # and the OpenAI client will append /chat/completions automatically - litellm_params["api_base"] = self.api_base_url + # When the caller provides a custom endpoint without a recognised + # LiteLLM provider prefix, treat it as OpenAI-compatible. This + # preserves the previous behaviour for plain LiteLLM users and gives + # us a sensible default for LangChain-style endpoints. 
+ if self.api_base_url and not self.litellm_model.startswith( + _KNOWN_LITELLM_PROVIDER_PREFIXES + ): + litellm_params["custom_llm_provider"] = "openai" litellm_params["extra_headers"] = {"User-Agent": "HackAgent/0.1.0"} + elif self.api_base_url: + # Keep the User-Agent for outbound requests even when a provider + # prefix is supplied — useful for self-hosted proxies. + litellm_params["extra_headers"] = {"User-Agent": "HackAgent/0.1.0"} + + # Thinking handling — config default merged with per-request override. + thinking = kwargs.pop("thinking", self.default_thinking) + self._apply_thinking(litellm_params, thinking) + + # Tool calls. + tools = kwargs.pop("tools", self.default_tools) + tool_choice = kwargs.pop("tool_choice", self.default_tool_choice) + if tools: + litellm_params["tools"] = tools + if tool_choice is not None: + litellm_params["tool_choice"] = tool_choice + + # Provider-specific extra body (e.g. OpenRouter ``reasoning``). + extra_body = kwargs.pop("extra_body", self.default_extra_body) + if extra_body is not None: + litellm_params["extra_body"] = ( + dict(extra_body) if isinstance(extra_body, dict) else extra_body + ) litellm_params.update(kwargs) return litellm_params def _extract_raw_response_content(self, response: Any, context: str = "") -> str: - """Extract content from raw litellm response object, handling various response formats.""" + """Extract content from a litellm response object.""" if not (response and response.choices and response.choices[0].message): self.logger.warning( - f"LiteLLM received unexpected response structure for model '{self.model_name}'{context}. Response: {response}" + f"LiteLLM received unexpected response structure for model " + f"'{self.litellm_model}'{context}. 
Response: {response}" ) return "[GENERATION_ERROR: UNEXPECTED_RESPONSE]" message = response.choices[0].message content = message.content if message.content else "" - # Try to extract reasoning content from various possible locations + # Reasoning models surface their output in a dedicated field; fall + # back to it when the regular content is empty. reasoning_content = None if hasattr(message, "reasoning_content") and message.reasoning_content: reasoning_content = message.reasoning_content @@ -240,51 +338,78 @@ def _extract_raw_response_content(self, response: Any, context: str = "") -> str "reasoning_content" ) or message.provider_specific_fields.get("reasoning") - # Use content if available, otherwise fall back to reasoning content if content: return content - elif reasoning_content: + if reasoning_content: self.logger.debug( - f"LiteLLM using reasoning content for model '{self.model_name}' (content field was empty)" + f"LiteLLM using reasoning content for model " + f"'{self.litellm_model}' (content field was empty)" ) return reasoning_content - else: - self.logger.warning( - f"LiteLLM received empty content and no reasoning field for model '{self.model_name}'{context}. Message: {message}" - ) - return "[GENERATION_ERROR: EMPTY_RESPONSE]" + + self.logger.warning( + f"LiteLLM received empty content and no reasoning field for model " + f"'{self.litellm_model}'{context}. 
Message: {message}" + ) + return "[GENERATION_ERROR: EMPTY_RESPONSE]" + + def _extract_tool_calls(self, response: Any) -> Optional[List[Dict[str, Any]]]: + """Extract OpenAI-style tool_calls from a LiteLLM response, if any.""" + try: + message = response.choices[0].message + except (AttributeError, IndexError, TypeError): + return None + tool_calls = getattr(message, "tool_calls", None) + if not tool_calls: + return None + result = [] + for tc in tool_calls: + try: + result.append( + { + "id": getattr(tc, "id", None), + "type": getattr(tc, "type", "function"), + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + ) + except AttributeError: + continue + return result or None def _get_excluded_request_keys(self) -> set: - """Return keys to exclude when passing additional kwargs.""" + """Return keys handled explicitly so they aren't re-passed as kwargs.""" return { "prompt", "messages", "max_tokens", - "max_tokens", "temperature", "top_p", + "tools", + "tool_choice", + "thinking", + "extra_body", } - def _execute_completion( - self, messages: List[Dict[str, str]], **parameters + def _get_completion_parameters( + self, request_data: Dict[str, Any] ) -> Dict[str, Any]: - """ - Execute a completion using litellm.completion. + """Extract completion parameters with provider-agnostic passthroughs.""" + params = super()._get_completion_parameters(request_data) + # Carry passthrough fields when present in the request. + for key in ("tools", "tool_choice", "thinking", "extra_body"): + if key in request_data: + params[key] = request_data[key] + return params - This implements the abstract method from ChatCompletionsAgent. + # ---- execution ------------------------------------------------------- - Args: - messages: List of message dictionaries with 'role' and 'content'. - **parameters: Completion parameters including max_tokens, temperature, top_p. 
- - Returns: - Dictionary with: - - success: Boolean indicating if completion succeeded - - content: The generated text (if successful) - - error_type: Type of error (if failed) - - error_message: Error description (if failed) - - raw_response: The raw API response (if available) - """ + def _execute_completion( + self, messages: List[Dict[str, str]], **parameters + ) -> Dict[str, Any]: + """Execute a completion via ``litellm.completion``.""" litellm, is_available = _get_litellm() if not is_available: return { @@ -297,47 +422,59 @@ def _execute_completion( AuthenticationError = exceptions["AuthenticationError"] try: - # Log agent interaction for TUI visibility if messages: msg_preview = str(messages[-1].get("content", ""))[:100] - self.logger.info(f"🌐 Querying model {self.model_name}") + self.logger.info(f"🌐 Querying model {self.litellm_model}") self.logger.debug(f" Message preview: {msg_preview}...") - # Extract parameters - max_tokens = parameters.get("max_tokens", self.default_max_tokens) - temperature = parameters.get("temperature", self.default_temperature) - top_p = parameters.get("top_p", self.default_top_p) - - # Remove these from kwargs to avoid duplication - kwargs = { - k: v - for k, v in parameters.items() - if k not in {"max_tokens", "temperature", "top_p"} - } + max_tokens = parameters.pop("max_tokens", self.default_max_tokens) + temperature = parameters.pop("temperature", self.default_temperature) + top_p = parameters.pop("top_p", self.default_top_p) litellm_params = self._prepare_litellm_params( - messages, max_tokens, temperature, top_p, **kwargs + messages, max_tokens, temperature, top_p, **parameters ) response = litellm.completion(**litellm_params) content = self._extract_raw_response_content(response) + tool_calls = self._extract_tool_calls(response) + self.logger.info(f"✅ Model responded ({len(content)} chars)") - return { + result: Dict[str, Any] = { "success": True, "content": content, "raw_response": response, } + if tool_calls is not None: 
+ result["tool_calls"] = tool_calls + # Surface useful diagnostics when available. + try: + result["finish_reason"] = response.choices[0].finish_reason + except (AttributeError, IndexError, TypeError): + pass + try: + if response.usage is not None: + result["usage"] = response.usage.model_dump() + except AttributeError: + pass + try: + result["provider_model"] = response.model + except AttributeError: + pass + + return result except AuthenticationError as e: - error_msg = f"Authentication failed for model '{self.model_name}': {str(e)}" + error_msg = f"Authentication failed for model '{self.litellm_model}': {e}" self.logger.error(error_msg) - # Re-raise authentication errors so they can be handled specially llm_provider = e.llm_provider if hasattr(e, "llm_provider") else "unknown" - raise AuthenticationError(error_msg, llm_provider, self.model_name) from e + raise AuthenticationError( + error_msg, llm_provider, self.litellm_model + ) from e except Exception as e: self.logger.error( - f"LiteLLM completion call failed for model '{self.model_name}': {e}", + f"LiteLLM completion call failed for model '{self.litellm_model}': {e}", exc_info=True, ) return { @@ -346,6 +483,25 @@ def _execute_completion( "error_message": str(e), } + # ---- response shaping ------------------------------------------------ + + def _build_agent_specific_data( + self, + completion_result: Dict[str, Any], + parameters: Dict[str, Any], + ) -> Dict[str, Any]: + """Include common LiteLLM metadata (finish_reason, usage, tools).""" + data = super()._build_agent_specific_data(completion_result, parameters) + for key in ("finish_reason", "usage", "provider_model"): + value = completion_result.get(key) + if value is not None and key not in data: + data[key] = value + if completion_result.get("tool_calls"): + data["tool_calls"] = completion_result["tool_calls"] + return data + + # ---- legacy convenience helpers ------------------------------------- + def _execute_litellm_completion_with_messages( self, 
messages: List[Dict[str, str]], @@ -354,10 +510,7 @@ def _execute_litellm_completion_with_messages( top_p: float, **kwargs, ) -> str: - """Execute a single completion using litellm.completion with messages format. - - This is a convenience method that wraps _execute_completion for backwards compatibility. - """ + """Single completion call returning the generated text only.""" result = self._execute_completion( messages, max_tokens=max_tokens, @@ -365,11 +518,9 @@ def _execute_litellm_completion_with_messages( top_p=top_p, **kwargs, ) - if result.get("success"): return result.get("content", "") - else: - return f"[GENERATION_ERROR: {result.get('error_type', 'UNKNOWN')}]" + return f"[GENERATION_ERROR: {result.get('error_type', 'UNKNOWN')}]" def _execute_litellm_completion( self, @@ -379,7 +530,7 @@ def _execute_litellm_completion( top_p: float, **kwargs, ) -> List[str]: - """Generate completions for multiple text prompts using litellm.completion.""" + """Generate completions for a batch of prompt strings.""" if not texts: return [] @@ -390,14 +541,14 @@ def _execute_litellm_completion( exceptions = _get_litellm_exceptions() AuthenticationError = exceptions["AuthenticationError"] - completions = [] + completions: List[str] = [] self.logger.info( - f"Sending {len(texts)} requests via LiteLLM to model '{self.model_name}'..." + f"Sending {len(texts)} requests via LiteLLM to model " + f"'{self.litellm_model}'..." 
) for text_prompt in texts: messages = [{"role": "user", "content": text_prompt}] - try: litellm_params = self._prepare_litellm_params( messages, max_tokens, temperature, top_p, **kwargs @@ -406,29 +557,30 @@ def _execute_litellm_completion( completion_text = self._extract_raw_response_content( response, context=f" for prompt '{text_prompt[:50]}...'" ) - except AuthenticationError as e: error_msg = ( - f"Authentication failed for model '{self.model_name}': {str(e)}" + f"Authentication failed for model '{self.litellm_model}': {e}" ) self.logger.error(error_msg) llm_provider = ( e.llm_provider if hasattr(e, "llm_provider") else "unknown" ) raise AuthenticationError( - error_msg, llm_provider, self.model_name + error_msg, llm_provider, self.litellm_model ) from e except Exception as e: self.logger.error( - f"LiteLLM completion call failed for model '{self.model_name}' for prompt '{text_prompt[:50]}...': {e}", + f"LiteLLM completion call failed for model " + f"'{self.litellm_model}' for prompt " + f"'{text_prompt[:50]}...': {e}", exc_info=True, ) completion_text = f" [GENERATION_ERROR: {type(e).__name__}]" - full_text = text_prompt + completion_text - completions.append(full_text) + completions.append(text_prompt + completion_text) self.logger.info( - f"Finished LiteLLM requests for model '{self.model_name}'. Generated {len(completions)} responses." + f"Finished LiteLLM requests for model '{self.litellm_model}'. " + f"Generated {len(completions)} responses." ) return completions diff --git a/hackagent/router/adapters/ollama.py b/hackagent/router/adapters/ollama.py index f27098bb..77b3fe49 100644 --- a/hackagent/router/adapters/ollama.py +++ b/hackagent/router/adapters/ollama.py @@ -2,498 +2,148 @@ # SPDX-License-Identifier: Apache-2.0 """ -Ollama Agent Adapter +Ollama adapter built on top of LiteLLM. -This adapter provides direct integration with Ollama for running local LLMs. -It uses Ollama's native HTTP API for efficient communication. 
+LiteLLM ships with an ``ollama_chat`` provider that targets a local or +remote Ollama server's ``/api/chat`` endpoint, so we no longer have to +hand-roll the HTTP calls. This adapter just pins the provider prefix, +normalises the endpoint URL the way the previous direct adapter did, and +translates the unified ``thinking`` knob into Ollama's ``think`` parameter. """ -from hackagent.logger import get_logger import os +from hackagent.logger import get_logger from typing import Any, Dict, List, Optional -import requests +from .base import AdapterConfigurationError +from .litellm import LiteLLMAgent -from .base import Agent, AdapterConfigurationError, AdapterInteractionError - -# --- Custom Exceptions (subclass from base) --- class OllamaConfigurationError(AdapterConfigurationError): """Custom exception for Ollama adapter configuration issues.""" pass -class OllamaConnectionError(AdapterInteractionError): - """Custom exception for Ollama connection issues.""" - - pass - - logger = get_logger(__name__) -class OllamaAgent(Agent): +class OllamaAgent(LiteLLMAgent): """ - Adapter for interacting with Ollama's native HTTP API. - - This adapter provides direct integration with Ollama for running local LLMs, - bypassing LiteLLM for more efficient and direct communication. - - Ollama API Endpoints: - - /api/generate: Generate completions (used for text generation) - - /api/chat: Chat completions (used for chat-based models) - - /api/tags: List available models - - /api/show: Show model information + Adapter for an Ollama server. 
Configuration: - - 'name': Model name (e.g., "llama3", "mistral", "codellama") - - 'endpoint': Ollama API base URL (default: "http://localhost:11434") - - 'max_tokens': Maximum tokens to generate (default: 100) - - 'temperature': Sampling temperature (default: 0.8) - - 'top_p': Top-p sampling parameter (default: 0.95) - - 'top_k': Top-k sampling parameter (optional) - - 'num_ctx': Context window size (optional) - - 'stream': Whether to stream responses (default: False) - - 'thinking': Optional bool/level controlling Ollama think traces + - ``name``: Ollama model tag (e.g. ``"llama3"``, ``"mistral"``). + - ``endpoint`` (optional): Ollama base URL. Defaults to + ``$OLLAMA_BASE_URL`` if set, otherwise ``http://localhost:11434``. + API-path suffixes such as ``/api/chat`` are stripped automatically + so users can paste their browser URL. + - ``thinking`` (optional): see :class:`LiteLLMAgent` for the + accepted shapes. Translated into Ollama's native ``think`` field. + - ``top_k`` / ``num_ctx`` / ``stream`` (optional): forwarded as + Ollama generation options. """ ADAPTER_TYPE = "OllamaAgent" + PROVIDER_PREFIX = "ollama_chat" DEFAULT_ENDPOINT = "http://localhost:11434" def __init__(self, id: str, config: Dict[str, Any]): - """ - Initializes the OllamaAgent. - - Args: - id: The unique identifier for this Ollama agent instance. - config: Configuration dictionary for the Ollama agent. 
- Expected keys: - - 'name': Model name (required, e.g., "llama3", "mistral") - - 'endpoint' (optional): Ollama API base URL (default: http://localhost:11434) - - 'max_tokens' (optional): Default max tokens for generation (default: 100) - - 'temperature' (optional): Default temperature (default: 0.8) - - 'top_p' (optional): Default top_p (default: 0.95) - - 'top_k' (optional): Default top_k sampling - - 'num_ctx' (optional): Context window size - - 'stream' (optional): Enable streaming (default: False) - - 'thinking' (optional): Forwarded as Ollama `think` - """ - super().__init__(id, config) - - # Require model name using base class helper - self.model_name = self._require_config_key("name", OllamaConfigurationError) - - # Handle endpoint configuration - # Priority: config['endpoint'] > OLLAMA_BASE_URL env var > default - self.api_base_url: str = self._get_config_key("endpoint") - if not self.api_base_url: - self.api_base_url = os.environ.get("OLLAMA_BASE_URL", self.DEFAULT_ENDPOINT) - - # Normalize endpoint: remove trailing slash and /api/* suffixes - self.api_base_url = self._normalize_endpoint(self.api_base_url) + # Resolve and normalise the endpoint before delegating to the base. + effective_endpoint = config.get("endpoint") or os.environ.get( + "OLLAMA_BASE_URL", self.DEFAULT_ENDPOINT + ) + effective_endpoint = self._normalize_endpoint(effective_endpoint) + config = {**config, "endpoint": effective_endpoint} - # Initialize default generation parameters using base class method - self._init_generation_params() + try: + super().__init__(id, config) + except AdapterConfigurationError as e: + raise OllamaConfigurationError(str(e)) from e - # Additional Ollama-specific parameters + # Ollama-specific generation options that LiteLLM forwards via + # ``optional_params`` (any extra kwarg passed to + # ``litellm.completion`` for the ``ollama_chat`` provider). 
self.default_top_k = self._get_config_key("top_k") self.default_num_ctx = self._get_config_key("num_ctx") self.default_stream = self._get_config_key("stream", False) - self.default_thinking = self._get_config_key("thinking") - # Request timeout - self.timeout = self._get_config_key( - "timeout", self._get_config_key("request_timeout", 120) - ) - - self.logger.info( - f"OllamaAgent '{self.id}' initialized for model: '{self.model_name}' " - f"at endpoint: '{self.api_base_url}'" - ) - - def _normalize_endpoint(self, endpoint: str) -> str: - """ - Normalize the Ollama endpoint URL. - - Strips trailing slashes and common API path suffixes that users might - mistakenly include (/api/generate, /api/chat, /api/tags, etc.). - - Args: - endpoint: The raw endpoint URL from configuration - - Returns: - Normalized base URL for Ollama API - """ + @staticmethod + def _normalize_endpoint(endpoint: str) -> str: + """Strip trailing slash and Ollama API path suffixes from ``endpoint``.""" endpoint = endpoint.rstrip("/") - - # Common suffixes users mistakenly include - api_suffixes = ["/api/generate", "/api/chat", "/api/tags", "/api/show", "/api"] - for suffix in api_suffixes: + for suffix in ("/api/generate", "/api/chat", "/api/tags", "/api/show", "/api"): if endpoint.endswith(suffix): - original = endpoint endpoint = endpoint[: -len(suffix)] - self.logger.info( - f"Normalized Ollama endpoint from '{original}' to '{endpoint}' " - f"(removed '{suffix}' suffix)" - ) break - return endpoint - def _build_options( - self, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - num_ctx: Optional[int] = None, - **kwargs, - ) -> Dict[str, Any]: - """ - Build Ollama options dictionary from parameters. 
- - Args: - max_tokens: Maximum tokens to generate - temperature: Sampling temperature - top_p: Top-p sampling parameter - top_k: Top-k sampling parameter - num_ctx: Context window size - **kwargs: Additional options to pass - - Returns: - Dictionary of Ollama options - """ - options = {} - - # Use provided values or fall back to defaults - if max_tokens is not None: - options["num_predict"] = max_tokens - elif self.default_max_tokens is not None: - options["num_predict"] = self.default_max_tokens - - if temperature is not None: - options["temperature"] = temperature - elif self.default_temperature is not None: - options["temperature"] = self.default_temperature - - if top_p is not None: - options["top_p"] = top_p - elif self.default_top_p is not None: - options["top_p"] = self.default_top_p - - if top_k is not None: - options["top_k"] = top_k - elif self.default_top_k is not None: - options["top_k"] = self.default_top_k - - if num_ctx is not None: - options["num_ctx"] = num_ctx - elif self.default_num_ctx is not None: - options["num_ctx"] = self.default_num_ctx - - # Add any additional kwargs that are valid Ollama options - valid_ollama_options = [ - "seed", - "repeat_penalty", - "presence_penalty", - "frequency_penalty", - "mirostat", - "mirostat_tau", - "mirostat_eta", - "stop", - ] - for key in valid_ollama_options: - if key in kwargs and kwargs[key] is not None: - options[key] = kwargs[key] - - return options - - def _execute_generate( - self, - prompt: str, - options: Dict[str, Any], - stream: bool = False, - system: Optional[str] = None, - thinking: Optional[Any] = None, - ) -> Dict[str, Any]: - """ - Execute a generate request to Ollama's /api/generate endpoint. 
- - Args: - prompt: The prompt text - options: Ollama generation options - stream: Whether to stream the response - system: Optional system prompt - - Returns: - Dictionary containing response data - """ - url = f"{self.api_base_url}/api/generate" - - payload = { - "model": self.model_name, - "prompt": prompt, - "stream": stream, - "options": options, - } - - if system: - payload["system"] = system - if thinking is not None: - payload["think"] = thinking - - self.logger.info( - f"Sending generate request to Ollama model '{self.model_name}' at '{url}'" - ) - self.logger.debug(f"Generate payload: {payload}") - - try: - response = requests.post( - url, - json=payload, - timeout=self.timeout, - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - return response.json() - - except requests.exceptions.ConnectionError as e: - self.logger.error(f"Failed to connect to Ollama at {url}: {e}") - raise OllamaConnectionError( - f"Failed to connect to Ollama at {url}. " - f"Make sure Ollama is running: `ollama serve`" - ) from e - except requests.exceptions.Timeout as e: - self.logger.error(f"Ollama request timed out after {self.timeout}s: {e}") - raise OllamaConnectionError( - f"Ollama request timed out after {self.timeout} seconds" - ) from e - except requests.exceptions.HTTPError as e: - self.logger.error(f"Ollama HTTP error: {e}") - raise + # ---- request shaping ------------------------------------------------ - def _execute_chat( - self, - messages: List[Dict[str, str]], - options: Dict[str, Any], - stream: bool = False, - thinking: Optional[Any] = None, + def _get_completion_parameters( + self, request_data: Dict[str, Any] ) -> Dict[str, Any]: - """ - Execute a chat request to Ollama's /api/chat endpoint. 
- - Args: - messages: List of chat messages with 'role' and 'content' - options: Ollama generation options - stream: Whether to stream the response - - Returns: - Dictionary containing response data - """ - url = f"{self.api_base_url}/api/chat" - - payload = { - "model": self.model_name, - "messages": messages, - "stream": stream, - "options": options, - } - if thinking is not None: - payload["think"] = thinking - - self.logger.info( - f"Sending chat request to Ollama model '{self.model_name}' at '{url}' " - f"with {len(messages)} messages" - ) - self.logger.debug(f"Chat payload: {payload}") - - try: - response = requests.post( - url, - json=payload, - timeout=self.timeout, - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - return response.json() - - except requests.exceptions.ConnectionError as e: - self.logger.error(f"Failed to connect to Ollama at {url}: {e}") - raise OllamaConnectionError( - f"Failed to connect to Ollama at {url}. " - f"Make sure Ollama is running: `ollama serve`" - ) from e - except requests.exceptions.Timeout as e: - self.logger.error(f"Ollama request timed out after {self.timeout}s: {e}") - raise OllamaConnectionError( - f"Ollama request timed out after {self.timeout} seconds" - ) from e - except requests.exceptions.HTTPError as e: - self.logger.error(f"Ollama HTTP error: {e}") - raise - - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """ - Processes an incoming request using Ollama's API. - - This method handles both 'prompt' (for /api/generate) and 'messages' - (for /api/chat) formats, automatically selecting the appropriate endpoint. - - Args: - request_data: The data for the agent to process. 
Expected keys: - - 'prompt' or 'messages': The input for generation - - 'max_tokens' (optional): Override default max tokens - - 'temperature' (optional): Override default temperature - - 'top_p' (optional): Override default top_p - - 'top_k' (optional): Override default top_k - - 'system' (optional): System prompt for generate endpoint - - 'stream' (optional): Enable streaming - - 'thinking' (optional): Forwarded as Ollama `think` - - Returns: - A dictionary containing: - - 'status_code': HTTP-like status code - - 'raw_request': The original request data - - 'raw_response': The raw Ollama response - - 'processed_response': The generated text - - 'error_message': Error message if any - - 'agent_specific_data': Ollama-specific metadata - """ - self.logger.info( - f"OllamaAgent '{self.id}' handling request for model '{self.model_name}'" - ) - - # Validate request using base class method - is_valid, prompt, messages = self._validate_request(request_data) - - if not is_valid: - error_msg = "Request data must include either 'messages' or 'prompt' field." 
- self.logger.warning(error_msg) - return self._build_error_response( - error_message=error_msg, - status_code=400, - raw_request=request_data, - ) - - # Build options from request data - max_tokens_value = request_data.get("max_tokens") - options = self._build_options( - max_tokens=max_tokens_value, - temperature=request_data.get("temperature"), - top_p=request_data.get("top_p"), - top_k=request_data.get("top_k"), - num_ctx=request_data.get("num_ctx"), - seed=request_data.get("seed"), - repeat_penalty=request_data.get("repeat_penalty"), - stop=request_data.get("stop"), - ) - - stream = request_data.get("stream", self.default_stream) - system = request_data.get("system") - thinking = request_data.get("thinking", self.default_thinking) - - try: - if messages: - # Use chat endpoint - raw_response = self._execute_chat( - messages, options, stream, thinking=thinking - ) - # Chat response has message.content - processed_response = raw_response.get("message", {}).get("content", "") + """Inject Ollama-specific defaults (top_k, num_ctx, stream).""" + params = super()._get_completion_parameters(request_data) + if "top_k" not in params and self.default_top_k is not None: + params["top_k"] = self.default_top_k + if "num_ctx" not in params and self.default_num_ctx is not None: + params["num_ctx"] = self.default_num_ctx + if "stream" not in params and self.default_stream: + params["stream"] = self.default_stream + return params + + def _get_excluded_request_keys(self) -> set: + base = super()._get_excluded_request_keys() + return base | {"top_k", "num_ctx", "stream", "system"} + + def _apply_thinking(self, litellm_params: Dict[str, Any], thinking: Any) -> None: + """Translate ``thinking`` into Ollama's native ``think`` field. + + Ollama accepts a boolean (``true``/``false``) or a reasoning level + such as ``"low"``/``"medium"``/``"high"`` depending on the model. + Dicts and ints are coerced into the most reasonable boolean. 
+ """ + if thinking is None: + return + if isinstance(thinking, bool): + litellm_params["think"] = thinking + elif isinstance(thinking, str): + litellm_params["think"] = thinking + elif isinstance(thinking, int): + litellm_params["think"] = thinking > 0 + elif isinstance(thinking, dict): + kind = (thinking.get("type") or "").lower() + if kind == "disabled": + litellm_params["think"] = False else: - # Use generate endpoint - if prompt is None: - raise ValueError("Prompt request resolved to None") - raw_response = self._execute_generate( - prompt, - options, - stream, - system, - thinking=thinking, - ) - # Generate response has 'response' field - processed_response = raw_response.get("response", "") - - self.logger.info( - f"Ollama request successful. Response length: {len(processed_response)} chars" - ) - - return self._build_success_response( - processed_response=processed_response, - raw_request=request_data, - raw_response_body=raw_response, - agent_specific_data={ - "model_name": self.model_name, - "endpoint": self.api_base_url, - "invoked_options": options, - "invoked_thinking": thinking, - "eval_count": raw_response.get("eval_count"), - "eval_duration": raw_response.get("eval_duration"), - "prompt_eval_count": raw_response.get("prompt_eval_count"), - "total_duration": raw_response.get("total_duration"), - }, - ) - - except OllamaConnectionError as e: - error_msg = f"Ollama connection error: {e}" - self.logger.error(error_msg) - return self._build_error_response( - error_message=error_msg, - status_code=503, - raw_request=request_data, - ) + litellm_params["think"] = True + else: + litellm_params["think"] = bool(thinking) - except requests.exceptions.HTTPError as e: - error_msg = f"Ollama HTTP error: {e}" - status_code = e.response.status_code if e.response else 500 - self.logger.error(error_msg) - return self._build_error_response( - error_message=error_msg, - status_code=status_code, - raw_request=request_data, - raw_response_body=e.response.text if 
e.response else None, - ) - - except Exception as e: - error_msg = ( - f"Ollama generation error: [GENERATION_ERROR: {type(e).__name__}] {e}" - ) - self.logger.error(error_msg, exc_info=True) - return self._build_error_response( - error_message=error_msg, - status_code=500, - raw_request=request_data, - ) + # ---- diagnostics passthroughs (kept for callers/tests) -------------- def list_models(self) -> List[Dict[str, Any]]: - """ - List available models from Ollama. + """Return models reported by ``GET {endpoint}/api/tags``.""" + import requests - Returns: - List of model information dictionaries - """ - url = f"{self.api_base_url}/api/tags" try: - response = requests.get(url, timeout=self.timeout) + response = requests.get(f"{self.api_base_url}/api/tags", timeout=30) response.raise_for_status() - data = response.json() - return data.get("models", []) + return response.json().get("models", []) except Exception as e: self.logger.error(f"Failed to list Ollama models: {e}") return [] def model_info(self) -> Dict[str, Any]: - """ - Get information about the current model. + """Return ``POST {endpoint}/api/show`` payload for the current model.""" + import requests - Returns: - Dictionary with model information - """ - url = f"{self.api_base_url}/api/show" try: response = requests.post( - url, json={"name": self.model_name}, timeout=self.timeout + f"{self.api_base_url}/api/show", + json={"name": self.model_name}, + timeout=30, ) response.raise_for_status() return response.json() @@ -502,21 +152,14 @@ def model_info(self) -> Dict[str, Any]: return {} def is_available(self) -> bool: - """ - Check if Ollama is available and the model is loaded. 
- - Returns: - True if Ollama is reachable and the model exists - """ + """True iff the configured model appears in ``/api/tags``.""" try: models = self.list_models() - model_names = [m.get("name", "").split(":")[0] for m in models] - # Check if our model (without tag) exists if not self.model_name: return False base_model = self.model_name.split(":")[0] - return base_model in model_names or self.model_name in [ - m.get("name") for m in models - ] + names: List[Optional[str]] = [m.get("name") for m in models] + base_names = [(m.get("name") or "").split(":")[0] for m in models] + return base_model in base_names or self.model_name in names except Exception: return False diff --git a/hackagent/router/adapters/openai.py b/hackagent/router/adapters/openai.py index 415102ce..597ec050 100644 --- a/hackagent/router/adapters/openai.py +++ b/hackagent/router/adapters/openai.py @@ -1,500 +1,118 @@ # Copyright 2026 - AI4I. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +""" +OpenAI-compatible adapter built on top of LiteLLM. 
-from hackagent.logger import get_logger -import re -import time -from typing import Any, Dict, List, Optional - -from .base import ChatCompletionsAgent, AdapterConfigurationError - -# Lazy-load openai to improve startup time -_openai_module = None -_openai_available = None - -# Module-level names for test patching compatibility -# These will be populated when _get_openai() is first called, -# but tests can patch them directly -OpenAI = None -OPENAI_AVAILABLE = None - - -def _get_openai(): - """Lazily import and return the openai module.""" - global _openai_module, _openai_available, OpenAI, OPENAI_AVAILABLE - if _openai_module is None: - try: - import openai as _openai - - _openai_module = _openai - _openai_available = True - # Also set module-level names for compatibility - OpenAI = _openai.OpenAI - OPENAI_AVAILABLE = True - except ImportError: - _openai_module = False - _openai_available = False - OPENAI_AVAILABLE = False - return _openai_module if _openai_module else None - - -def _get_openai_exceptions(): - """Get OpenAI exception classes, or dummy classes if not available.""" - openai = _get_openai() - if openai: - return ( - openai.OpenAIError, - openai.APIConnectionError, - openai.RateLimitError, - openai.APITimeoutError, - ) - else: - # Return dummy exceptions - return (Exception, Exception, Exception, Exception) - - -def _is_openai_available(): - """Check if openai is available.""" - global _openai_available, OPENAI_AVAILABLE - # Allow test patches to override OPENAI_AVAILABLE - if OPENAI_AVAILABLE is not None: - return OPENAI_AVAILABLE - if _openai_available is None: - _get_openai() - return _openai_available +The OpenAI agent type used to talk to the OpenAI SDK directly. 
As of +issue #379 every chat-completion adapter routes through LiteLLM, so this +class is now a thin specialisation of :class:`LiteLLMAgent` that pins the +provider prefix to ``openai`` and translates the unified ``thinking`` knob +into OpenAI's ``reasoning_effort`` field for the o-series models. +""" +from hackagent.logger import get_logger +from typing import Any, Dict -def _check_openai_available(): - global OPENAI_AVAILABLE - if OPENAI_AVAILABLE is None: - OPENAI_AVAILABLE = _is_openai_available() - return OPENAI_AVAILABLE +from .base import AdapterConfigurationError +from .litellm import LiteLLMAgent -# --- Custom Exceptions (subclass from base) --- +# Keep this exception public for backwards compatibility — downstream code +# (and several tests) import OpenAIConfigurationError from this module. class OpenAIConfigurationError(AdapterConfigurationError): """Custom exception for OpenAI adapter configuration issues.""" pass -logger = get_logger(__name__) # Module-level logger +logger = get_logger(__name__) -_CONTEXT_LIMIT_ERROR_RE = re.compile( - r"maximum context length is\s*(\d+)\s*tokens\s*and\s*your request has\s*(\d+)\s*input tokens", - flags=re.IGNORECASE, -) +# OpenAI reasoning models that natively understand ``reasoning_effort``. +# ``thinking=True`` defaults to "medium" for these; for other models we fall +# back to LiteLLM's generic ``thinking`` payload. +_OPENAI_REASONING_MODEL_PREFIXES = ("o1", "o3", "o4", "gpt-5", "gpt-6") -class OpenAIAgent(ChatCompletionsAgent): - """ - Adapter for interacting with AI agents built using the OpenAI SDK. - This adapter supports OpenAI's chat completions API, including support for - function calling and tool use, which are common patterns in agent implementations. +class OpenAIAgent(LiteLLMAgent): + """ + Adapter for OpenAI-compatible chat endpoints. + + Configured via the ``OPENAI_SDK`` agent type. 
Internally uses LiteLLM, + so any OpenAI-compatible server (the official API, a local model server + exposing ``/v1/chat/completions``, OpenRouter, etc.) works as the + endpoint. + + Reasoning / "thinking": + Set ``thinking`` in the adapter config or per request to enable or + disable the model's reasoning. For the o-series and newer GPT + reasoning models the value is translated to ``reasoning_effort`` + (low/medium/high). """ ADAPTER_TYPE = "OpenAIAgent" - DEFAULT_TEMPERATURE = 1.0 # OpenAI default - MAX_CONNECTION_RETRIES_CAP = 5 + PROVIDER_PREFIX = "openai" + DEFAULT_TEMPERATURE = 1.0 def __init__(self, id: str, config: Dict[str, Any]): - """ - Initializes the OpenAIAgent. - - Args: - id: The unique identifier for this OpenAI agent instance. - config: Configuration dictionary for the OpenAI agent. - Expected keys: - - 'name': Model name (e.g., "gpt-4", "gpt-3.5-turbo"). - - 'endpoint' (optional): Base URL for the API (for custom endpoints). - - 'api_key' (optional): Name of the environment variable holding the API key, - or the API key itself. Defaults to OPENAI_API_KEY env var. - - 'max_tokens' (optional): Default max tokens for generation. - - 'temperature' (optional): Default temperature (defaults to 1.0). - - 'timeout' (optional): Default request timeout. - - 'tools' (optional): List of tool/function definitions for function calling. - - 'tool_choice' (optional): Controls which tools the model can call. - """ - super().__init__(id, config) - - if not _is_openai_available(): - msg = ( - f"OpenAI SDK is not installed. Please install it with: pip install openai. " - f"OpenAIAgent: {self.id}" - ) - self.logger.error(msg) - raise OpenAIConfigurationError(msg) + # Custom endpoints don't always require a model name; default to + # ``"default"`` (the server then decides) when an endpoint is set + # but no model is provided. 
+ if "name" not in config and config.get("endpoint"): + config = {**config, "name": config.get("name", "default")} - self.api_base_url: Optional[str] = self._get_config_key("endpoint") - - # Model name defaults to "default" for custom endpoints (server decides the model) - if "name" not in self.config: - if self.api_base_url: - # Custom endpoint - use a default model name, server will handle it - self.model_name = self._get_config_key("name", "default") - self.logger.info( - "No model name specified for custom endpoint, using 'default'" - ) - else: - self.model_name = self._require_config_key( - "name", OpenAIConfigurationError - ) - else: - self.model_name = self.config["name"] - - # Handle API key resolution - self.actual_api_key = self._resolve_api_key( - config_key="api_key", env_var_fallback="OPENAI_API_KEY" - ) - - # For custom endpoints without API key, use a placeholder - # (some local servers don't require authentication) + try: + super().__init__(id, config) + except AdapterConfigurationError as e: + # Re-raise as the OpenAI-flavoured subclass so legacy callers + # that catch OpenAIConfigurationError keep working. + raise OpenAIConfigurationError(str(e)) from e + + # For custom endpoints without an API key, use a placeholder so the + # OpenAI client (under LiteLLM's hood) doesn't error out. 
if not self.actual_api_key and self.api_base_url: self.actual_api_key = "not-required" self.logger.info( - f"No API key configured for custom endpoint '{self.api_base_url}', using placeholder" + f"No API key configured for custom endpoint " + f"'{self.api_base_url}', using placeholder" ) - # Initialize OpenAI client - # Check for test-patched OpenAI first, then fall back to lazy-loaded module - global OpenAI - if OpenAI is not None: - # Use patched or pre-loaded OpenAI class - openai_client_class = OpenAI - else: - # Lazy load the module - openai = _get_openai() - if openai is None: - raise OpenAIConfigurationError("OpenAI SDK is unavailable") - openai_client_class = openai.OpenAI - - client_kwargs = {} - if self.actual_api_key: - client_kwargs["api_key"] = self.actual_api_key - if self.api_base_url: - client_kwargs["base_url"] = self.api_base_url - - timeout = self._get_config_key( - "timeout", self._get_config_key("request_timeout", 120) - ) - client_kwargs["timeout"] = timeout - - self.client = openai_client_class(**client_kwargs) - - self.logger.info( - f"OpenAIAgent '{self.id}' initialized for model: '{self.model_name}'" - + (f" API Base: '{self.api_base_url}'" if self.api_base_url else "") - ) - - # Store default generation parameters - self.default_max_tokens = self._get_config_key( - "max_tokens", self.DEFAULT_MAX_TOKENS - ) - self.default_temperature = self._get_config_key( - "temperature", self.DEFAULT_TEMPERATURE - ) - # Provider-specific request payload (e.g., OpenRouter "reasoning"). - # This can be overridden per-call via request_data["extra_body"]. - self.default_extra_body = self._get_config_key("extra_body") - self.default_tools = self._get_config_key("tools") - self.default_tool_choice = self._get_config_key("tool_choice") - # Retry only transient transport failures and cap retries to avoid long hangs. 
- self.max_connection_retries = self._get_max_connection_retries() - - def _get_excluded_request_keys(self) -> set: - """Returns keys to exclude when extracting additional kwargs.""" - base_keys = super()._get_excluded_request_keys() - return base_keys | {"tools", "tool_choice"} - - def _get_completion_parameters( - self, request_data: Dict[str, Any] - ) -> Dict[str, Any]: - """Extract parameters including OpenAI-specific tools.""" - params = super()._get_completion_parameters(request_data) - - # Add OpenAI-specific parameters - params["tools"] = request_data.get("tools", self.default_tools) - params["tool_choice"] = request_data.get( - "tool_choice", self.default_tool_choice - ) - if "extra_body" in request_data: - params["extra_body"] = request_data.get("extra_body") - elif self.default_extra_body is not None: - if isinstance(self.default_extra_body, dict): - params["extra_body"] = dict(self.default_extra_body) - else: - params["extra_body"] = self.default_extra_body - - return params + # ---- thinking translation ------------------------------------------- - def _execute_completion( - self, - messages: List[Dict[str, str]], - **kwargs, - ) -> Dict[str, Any]: - """ - Execute the completion request using OpenAI's chat completions API. + def _is_reasoning_model(self) -> bool: + bare = self.model_name.split("/")[-1] + return bare.startswith(_OPENAI_REASONING_MODEL_PREFIXES) - Args: - messages: List of message dictionaries with 'role' and 'content'. - **kwargs: Additional parameters (temperature, max_tokens, tools, etc.) + def _apply_thinking(self, litellm_params: Dict[str, Any], thinking: Any) -> None: + """Map ``thinking`` to ``reasoning_effort`` for OpenAI reasoning models. - Returns: - A dictionary containing the result with 'success', 'content', etc. + Non-reasoning models fall back to LiteLLM's default ``thinking`` + passthrough, so callers can still attach arbitrary provider payload + if they need to. 
""" - max_tokens = kwargs.pop("max_tokens", None) - temperature = kwargs.pop("temperature", self.default_temperature) - tools = kwargs.pop("tools", None) - tool_choice = kwargs.pop("tool_choice", None) - - self.logger.info( - f"Sending request to OpenAI model '{self.model_name}' with {len(messages)} messages..." - ) - - try: - openai_params = { - "model": self.model_name, - "messages": messages, - "temperature": temperature, - } - - if max_tokens is not None: - openai_params["max_tokens"] = max_tokens - - if tools: - openai_params["tools"] = tools - if tool_choice: - openai_params["tool_choice"] = tool_choice - - # Add any additional kwargs - openai_params.update(kwargs) - - # Log request parameters at debug level - self.logger.debug( - f"OpenAI API request params: model={self.model_name}, " - f"base_url={self.api_base_url}, " - f"messages={messages[:1] if messages else []}, " - f"temperature={temperature}, max_tokens={max_tokens}, " - f"extra_kwargs={list(kwargs.keys())}" - ) - - # Make the API call (with one automatic retry for context-limit token errors) - openai = _get_openai() - api_connection_error_type = openai.APIConnectionError if openai else tuple() - - connection_retry_count = 0 - while True: - try: - try: - response = self.client.chat.completions.create(**openai_params) - except Exception as first_error: - adjusted_max_tokens = self._get_adjusted_max_tokens_from_error( - first_error, openai_params.get("max_tokens") - ) - if adjusted_max_tokens is not None: - self.logger.warning( - "OpenAI request exceeded context window; retrying with " - f"max_tokens={adjusted_max_tokens} for model '{self.model_name}'" - ) - openai_params["max_tokens"] = adjusted_max_tokens - response = self.client.chat.completions.create( - **openai_params - ) - else: - raise first_error - break - except Exception as connection_error: - is_connection_error = isinstance( - connection_error, api_connection_error_type - ) - if ( - is_connection_error - and connection_retry_count < 
self.max_connection_retries - ): - connection_retry_count += 1 - backoff_seconds = min( - 0.5 * (2 ** (connection_retry_count - 1)), 4.0 - ) - self.logger.warning( - "OpenAI API connection error for model '%s'; retry %d/%d in %.1fs", - self.model_name, - connection_retry_count, - self.max_connection_retries, - backoff_seconds, - ) - time.sleep(backoff_seconds) - continue - raise connection_error - - # Extract response data - message = response.choices[0].message - content = message.content if message.content else "" - - # For reasoning models (e.g., o1-preview, o1-mini), check reasoning field - if not content and hasattr(message, "reasoning") and message.reasoning: - content = message.reasoning - self.logger.info( - f"OpenAI extracted text from 'reasoning' field (reasoning model) for '{self.model_name}'" - ) - - # Check if there are tool calls - tool_calls = None - if hasattr(message, "tool_calls") and message.tool_calls: - tool_calls = [ - { - "id": tc.id, - "type": tc.type, - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - } - for tc in message.tool_calls - ] - - result = { - "success": True, - "content": content, - "finish_reason": response.choices[0].finish_reason, - "usage": response.usage.model_dump() if response.usage else None, - "model": response.model, - "tool_calls": tool_calls, - "raw_response": response, - } - - self.logger.info( - f"Successfully received response from OpenAI model '{self.model_name}'. 
" - f"Finish reason: {result['finish_reason']}" - ) - - return result - - except Exception as e: - # Get OpenAI exceptions dynamically - openai = _get_openai() - if openai: - OpenAIError = openai.OpenAIError - APITimeoutError = openai.APITimeoutError - RateLimitError = openai.RateLimitError - APIConnectionError = openai.APIConnectionError + if thinking is None: + return + + if self._is_reasoning_model(): + if thinking is True: + litellm_params["reasoning_effort"] = "medium" + elif thinking is False: + # Explicit disable: omit the parameter entirely so the + # provider falls back to whatever its server-side default is. + # (OpenAI doesn't currently accept reasoning_effort="off".) + return + elif isinstance(thinking, str): + litellm_params["reasoning_effort"] = thinking + elif isinstance(thinking, dict): + effort = thinking.get("reasoning_effort") or thinking.get("effort") + if effort: + litellm_params["reasoning_effort"] = effort + else: + litellm_params["thinking"] = dict(thinking) else: - # If openai not available, these will never match - OpenAIError = APITimeoutError = RateLimitError = APIConnectionError = ( - type(None) - ) - - if isinstance(e, APITimeoutError): - self.logger.error( - f"OpenAI API timeout for model '{self.model_name}': {e}", - exc_info=True, - ) - return { - "success": False, - "error_type": "timeout", - "error_message": str(e), - } - elif isinstance(e, RateLimitError): - self.logger.error( - f"OpenAI rate limit exceeded for model '{self.model_name}': {e}", - exc_info=True, - ) - return { - "success": False, - "error_type": "rate_limit", - "error_message": str(e), - } - elif isinstance(e, APIConnectionError): - self.logger.error( - f"OpenAI API connection error for model '{self.model_name}': {e}", - exc_info=True, - ) - return { - "success": False, - "error_type": "connection", - "error_message": str(e), - } - elif isinstance(e, OpenAIError): - self.logger.error( - f"OpenAI API error for model '{self.model_name}': {e}", - exc_info=True, - ) - 
return { - "success": False, - "error_type": "api_error", - "error_message": str(e), - } - else: - self.logger.exception( - f"Unexpected error during OpenAI completion for model '{self.model_name}': {e}" - ) - return { - "success": False, - "error_type": "unexpected", - "error_message": f"{type(e).__name__}: {str(e)}", - } - - def _get_max_connection_retries(self) -> int: - """Get connection-retry budget from config, capped to avoid excessive waits.""" - raw_value = self._get_config_key( - "max_connection_retries", self.MAX_CONNECTION_RETRIES_CAP - ) - try: - parsed = int(raw_value) - except (TypeError, ValueError): - parsed = self.MAX_CONNECTION_RETRIES_CAP - return max(0, min(parsed, self.MAX_CONNECTION_RETRIES_CAP)) - - def _get_adjusted_max_tokens_from_error( - self, error: Exception, current_max_tokens: Any - ) -> Optional[int]: - """Parse context-limit errors and return a safe reduced max_tokens value.""" - if current_max_tokens is None: - return None - - message = str(error) - match = _CONTEXT_LIMIT_ERROR_RE.search(message) - if not match: - return None - - try: - max_context = int(match.group(1)) - input_tokens = int(match.group(2)) - current = int(current_max_tokens) - except (TypeError, ValueError): - return None - - available = max_context - input_tokens - # Keep a small safety margin to avoid repeated boundary errors. - safe_max_tokens = max(1, available - 8) - if safe_max_tokens >= current: - return None - return safe_max_tokens - - def _build_agent_specific_data( - self, - completion_result: Dict[str, Any], - parameters: Dict[str, Any], - ) -> Dict[str, Any]: - """Build OpenAI-specific response data including tool calls.""" - data = super()._build_agent_specific_data(completion_result, parameters) - - # Expose provider completion metadata for latency/quality diagnostics. 
- if completion_result.get("finish_reason") is not None: - data["finish_reason"] = completion_result.get("finish_reason") - if completion_result.get("usage") is not None: - data["usage"] = completion_result.get("usage") - if completion_result.get("model") is not None: - data["provider_model"] = completion_result.get("model") - - # Add tool calls if present - if completion_result.get("tool_calls"): - data["tool_calls"] = completion_result["tool_calls"] - - # Add tools_provided flag - data["invoked_parameters"]["tools_provided"] = ( - parameters.get("tools") is not None - ) + super()._apply_thinking(litellm_params, thinking) + return - return data + # Non-reasoning model: defer to the generic translation. + super()._apply_thinking(litellm_params, thinking) diff --git a/hackagent/router/router.py b/hackagent/router/router.py index de4f3bf3..3c3f36eb 100644 --- a/hackagent/router/router.py +++ b/hackagent/router/router.py @@ -166,144 +166,60 @@ def _configure_and_instantiate_adapter( adapter_operational_config.copy() if adapter_operational_config else {} ) + # Every adapter now subclasses LiteLLMAgent, so the same set of + # config fields applies (with ADK adding a required user_id). + # ``name`` is the model string, ``endpoint`` is the API base URL. + if "name" not in adapter_instance_config: + metadata = self.backend_agent.metadata + if isinstance(metadata, dict) and "name" in metadata: + adapter_instance_config["name"] = metadata["name"] + else: + logger.warning( + f"Agent '{name}' (Type: {agent_type.value}) missing 'name' " + f"(model string) in metadata. Defaulting to agent name " + f"'{self.backend_agent.name}'." + ) + adapter_instance_config["name"] = self.backend_agent.name + + if "endpoint" not in adapter_instance_config and self.backend_agent.endpoint: + adapter_instance_config["endpoint"] = str(self.backend_agent.endpoint) + + # Merge through any optional generation/provider knobs stored on + # the backend agent's metadata so adapter subclasses see them. 
+ optional_passthrough_keys = ( + "api_key", + "max_tokens", + "temperature", + "top_p", + "top_k", + "num_ctx", + "stream", + "timeout", + "thinking", + "tools", + "tool_choice", + "extra_body", + "reasoning_effort", + ) + if isinstance(self.backend_agent.metadata, dict): + for key in optional_passthrough_keys: + if ( + key not in adapter_instance_config + and key in self.backend_agent.metadata + ): + adapter_instance_config[key] = self.backend_agent.metadata[key] + if agent_type == AgentTypeEnum.GOOGLE_ADK: + # ADK uses the agent name as the app_name in its run payload. adapter_instance_config["name"] = self.backend_agent.name - adapter_instance_config["endpoint"] = str(self.backend_agent.endpoint) if "user_id" not in adapter_instance_config: logger.error( - f"CRITICAL: user_id not found in adapter_instance_config for ADK agent '{self.backend_agent.name}' just before adapter instantiation. This should have been set in __init__." + f"CRITICAL: user_id not found in adapter_instance_config " + f"for ADK agent '{self.backend_agent.name}'. Defaulting " + f"to context user_id." ) adapter_instance_config["user_id"] = self.user_id_str - elif agent_type in [AgentTypeEnum.LITELLM, AgentTypeEnum.LANGCHAIN]: - if "name" not in adapter_instance_config: - if ( - isinstance(self.backend_agent.metadata, dict) - and "name" in self.backend_agent.metadata - ): - adapter_instance_config["name"] = self.backend_agent.metadata[ - "name" - ] - else: - logger.warning( - f"Agent '{name}' (Type: {agent_type.value}) missing 'name' (model string) in metadata. " - f"Defaulting to agent name '{self.backend_agent.name}'." 
- ) - adapter_instance_config["name"] = self.backend_agent.name - - # Always use backend agent's endpoint if not already in config - if ( - "endpoint" not in adapter_instance_config - and self.backend_agent.endpoint - ): - adapter_instance_config["endpoint"] = str(self.backend_agent.endpoint) - - optional_litellm_keys = [ - "api_key", - "max_tokens", - "temperature", - "top_p", - ] - if isinstance(self.backend_agent.metadata, dict): - for key in optional_litellm_keys: - if ( - key not in adapter_instance_config - and key in self.backend_agent.metadata - ): - adapter_instance_config[key] = self.backend_agent.metadata[key] - - elif agent_type == AgentTypeEnum.OPENAI_SDK: - if "name" not in adapter_instance_config: - if ( - isinstance(self.backend_agent.metadata, dict) - and "name" in self.backend_agent.metadata - ): - adapter_instance_config["name"] = self.backend_agent.metadata[ - "name" - ] - # For custom endpoints, model name is optional (will default to 'default') - # Only raise error if no endpoint is configured (i.e., using OpenAI API directly) - elif ( - "endpoint" not in adapter_instance_config - and not self.backend_agent.endpoint - ): - raise ValueError( - f"OpenAI SDK agent '{name}' (ID: {registration_key}) missing " - f"'name' (model string) in adapter_operational_config or backend metadata. " - f"Cannot configure OpenAIAgent." - ) - else: - # Fall back to the registered agent name (e.g. full local model path) - logger.warning( - f"Agent '{name}' (Type: {agent_type.value}) missing 'name' in metadata. " - f"Defaulting to agent name '{self.backend_agent.name}'." 
- ) - adapter_instance_config["name"] = self.backend_agent.name - - # Always use backend agent's endpoint if not already in config - if ( - "endpoint" not in adapter_instance_config - and self.backend_agent.endpoint - ): - adapter_instance_config["endpoint"] = str(self.backend_agent.endpoint) - - optional_openai_keys = [ - "api_key", - "max_tokens", - "temperature", - "tools", - "tool_choice", - ] - if isinstance(self.backend_agent.metadata, dict): - for key in optional_openai_keys: - if ( - key not in adapter_instance_config - and key in self.backend_agent.metadata - ): - adapter_instance_config[key] = self.backend_agent.metadata[key] - - elif agent_type == AgentTypeEnum.OLLAMA: - # Configure Ollama adapter - if "name" not in adapter_instance_config: - if ( - isinstance(self.backend_agent.metadata, dict) - and "name" in self.backend_agent.metadata - ): - adapter_instance_config["name"] = self.backend_agent.metadata[ - "name" - ] - else: - logger.warning( - f"Agent '{name}' (Type: {agent_type.value}) missing 'name' (model string) in metadata. " - f"Defaulting to agent name '{self.backend_agent.name}'." 
- ) - adapter_instance_config["name"] = self.backend_agent.name - - # Always use backend agent's endpoint if not already in config - if ( - "endpoint" not in adapter_instance_config - and self.backend_agent.endpoint - ): - adapter_instance_config["endpoint"] = str(self.backend_agent.endpoint) - - optional_ollama_keys = [ - "max_tokens", - "temperature", - "top_p", - "top_k", - "num_ctx", - "stream", - "timeout", - "thinking", - ] - if isinstance(self.backend_agent.metadata, dict): - for key in optional_ollama_keys: - if ( - key not in adapter_instance_config - and key in self.backend_agent.metadata - ): - adapter_instance_config[key] = self.backend_agent.metadata[key] - try: logger.debug( f"ROUTER_DEBUG: About to call adapter_class(id='{registration_key}', config_keys={list(adapter_instance_config.keys())})" diff --git a/tests/integration/adapters/test_openai.py b/tests/integration/adapters/test_openai.py index 74869a41..4790bf8c 100644 --- a/tests/integration/adapters/test_openai.py +++ b/tests/integration/adapters/test_openai.py @@ -55,7 +55,10 @@ def test_adapter_initialization( assert adapter.id == "test_openai_init" assert adapter.model_name == openai_config["name"] - assert adapter.client is not None + # Since #379 the OpenAI adapter routes through LiteLLM, so the + # model string carries an `openai/` provider prefix and there is + # no longer a raw OpenAI SDK client to inspect. + assert adapter.litellm_model.endswith(openai_config["name"]) logger.info(f"OpenAI adapter initialized: model={adapter.model_name}") def test_chat_completion( diff --git a/tests/unit/adapters/test_google_adk.py b/tests/unit/adapters/test_google_adk.py index d83f8a57..ad9f7837 100644 --- a/tests/unit/adapters/test_google_adk.py +++ b/tests/unit/adapters/test_google_adk.py @@ -1,217 +1,290 @@ # Copyright 2026 - AI4I. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +""" +Unit tests for the Google ADK adapter. 
+ +Issue #379 routes ADK through LiteLLM via a custom provider, so the +``ADKAgent`` no longer makes the HTTP calls itself — its custom handler +does. These tests exercise both layers: handler-level (HTTP transport) +and adapter-level (end-to-end via the public ``handle_request``). +""" import logging import unittest +import uuid from unittest.mock import MagicMock, patch -import requests # Added for requests.exceptions +import requests from hackagent.router.adapters.google_adk import ( ADKAgent, AgentConfigurationError, AgentInteractionError, + _get_adk_custom_llm_class, + _extract_final_text, + _last_user_text, ) -# Disable logging for tests to keep output clean logging.disable(logging.CRITICAL) -class TestADKAgentInit(unittest.TestCase): - def test_init_success_with_all_required_config(self): - adapter_id = "adk_test_agent_001" - config = { - "name": "multi_tool_agent_app", - "endpoint": "http://fake-adk-endpoint.com/api", - "user_id": "test_user_adk", - "timeout": 60, - } - try: - adapter = ADKAgent(id=adapter_id, config=config) - self.assertEqual(adapter.id, adapter_id) - self.assertEqual(adapter.name, config["name"]) - self.assertEqual(adapter.endpoint, config["endpoint"].strip("/")) - self.assertEqual(adapter.user_id, config["user_id"]) - self.assertEqual(adapter.timeout, config["timeout"]) - except AgentConfigurationError: - self.fail("ADKAgent initialization failed unexpectedly with valid config.") - - def test_init_uses_default_timeout_if_not_provided(self): - adapter_id = "adk_test_agent_002" - config = { - "name": "another_agent", - "endpoint": "http://another-endpoint.com", - "user_id": "user_abc", - } - adapter = ADKAgent(id=adapter_id, config=config) - self.assertEqual(adapter.timeout, 120) # Default timeout - - def test_init_missing_name_raises_error(self): - with self.assertRaisesRegex( - AgentConfigurationError, "Missing required configuration key 'name'" - ): - ADKAgent(id="err_agent_1", config={"endpoint": "ep", "user_id": "uid"}) - - def 
test_init_missing_endpoint_raises_error(self): - with self.assertRaisesRegex( - AgentConfigurationError, "Missing required configuration key 'endpoint'" - ): - ADKAgent(id="err_agent_2", config={"name": "app_name", "user_id": "uid"}) - - def test_init_missing_user_id_raises_error(self): - with self.assertRaisesRegex( - AgentConfigurationError, "Missing required configuration key 'user_id'" - ): - ADKAgent(id="err_agent_3", config={"name": "app_name", "endpoint": "ep"}) - - def test_init_endpoint_gets_stripped(self): - adapter_id = "adk_strip_test" - config = { - "name": "strip_app", - "endpoint": "http://fake-adk-endpoint.com/api/", # trailing slash - "user_id": "strip_user", - } - adapter = ADKAgent(id=adapter_id, config=config) - self.assertEqual(adapter.endpoint, "http://fake-adk-endpoint.com/api") - - -class TestADKAgentCreateSession(unittest.TestCase): - def setUp(self): - self.adapter_id = "adk_session_test_agent" - self.config = { - "name": "test_app", - "endpoint": "http://fake-adk.com", - "user_id": "test_user", - } - self.adapter = ADKAgent(id=self.adapter_id, config=self.config) - self.session_id = "test_session_123" - +def _make_handler(**overrides): + """Construct an _ADKCustomLLM with sensible defaults for tests.""" + handler_cls = _get_adk_custom_llm_class() + defaults = dict( + endpoint="http://fake-adk.com", + app_name="test_app", + user_id="test_user", + default_session_id="sess-default", + fresh_session_per_request=False, + timeout=30, + log=logging.getLogger("test"), + ) + defaults.update(overrides) + return handler_cls(**defaults) + + +class TestADKHelpers(unittest.TestCase): + def test_last_user_text_returns_last_user_string(self): + messages = [ + {"role": "system", "content": "be terse"}, + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "ack"}, + {"role": "user", "content": "second"}, + ] + self.assertEqual(_last_user_text(messages), "second") + + def test_last_user_text_handles_content_parts(self): + messages = [ + 
{ + "role": "user", + "content": [{"type": "text", "text": "from-parts"}], + } + ] + self.assertEqual(_last_user_text(messages), "from-parts") + + def test_last_user_text_returns_none_when_no_user_message(self): + self.assertIsNone(_last_user_text([{"role": "system", "content": "x"}])) + + def test_extract_final_text_returns_latest_text(self): + events = [ + {"content": {"parts": [{"text": "first"}]}}, + {"content": {"parts": [{"text": "final"}]}}, + ] + self.assertEqual(_extract_final_text(events), "final") + + def test_extract_final_text_handles_escalation(self): + events = [ + {"content": {"parts": [{"text": "x"}]}}, + {"actions": {"escalate": True}, "error_message": "boom"}, + ] + self.assertEqual(_extract_final_text(events), "Agent escalated: boom") + + def test_extract_final_text_returns_none_when_no_text(self): + self.assertIsNone(_extract_final_text([{"content": {}}])) + + +class TestADKCustomLLMTransport(unittest.TestCase): @patch("requests.post") - def test_create_session_internal_success(self, mock_post): - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() # Does not raise for 200 - mock_post.return_value = mock_response - - result = self.adapter._create_session_internal(session_id=self.session_id) - self.assertTrue(result) - expected_url = f"{self.config['endpoint']}/apps/{self.config['name']}/users/{self.config['user_id']}/sessions/{self.session_id}" - mock_post.assert_called_once_with( - expected_url, headers=unittest.mock.ANY, json={}, timeout=30 + def test_create_session_success(self, mock_post): + mock_post.return_value = MagicMock( + status_code=200, raise_for_status=MagicMock() ) + handler = _make_handler() + handler._create_session(session_id="abc") + kwargs = mock_post.call_args.kwargs + self.assertEqual(kwargs["timeout"], 30) + self.assertEqual(kwargs["json"], {}) @patch("requests.post") - def test_create_session_internal_success_with_initial_state(self, mock_post): - mock_response = 
MagicMock() - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - initial_state = {"key": "value"} - - result = self.adapter._create_session_internal( - session_id=self.session_id, initial_state=initial_state - ) - self.assertTrue(result) - expected_url = f"{self.config['endpoint']}/apps/{self.config['name']}/users/{self.config['user_id']}/sessions/{self.session_id}" - mock_post.assert_called_once_with( - expected_url, headers=unittest.mock.ANY, json=initial_state, timeout=30 + def test_create_session_with_initial_state(self, mock_post): + mock_post.return_value = MagicMock( + status_code=200, raise_for_status=MagicMock() ) + handler = _make_handler() + handler._create_session(session_id="abc", initial_state={"k": "v"}) + self.assertEqual(mock_post.call_args.kwargs["json"], {"k": "v"}) @patch("requests.post") - def test_create_session_internal_already_exists_409(self, mock_post): - mock_response = MagicMock() - mock_response.status_code = 409 - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( - response=mock_response + def test_create_session_409_is_idempotent(self, mock_post): + mock_resp = MagicMock(status_code=409) + mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + response=mock_resp ) - mock_post.return_value = mock_response - - result = self.adapter._create_session_internal(session_id=self.session_id) - self.assertTrue(result) + mock_post.return_value = mock_resp + handler = _make_handler() + handler._create_session(session_id="abc") # no raise @patch("requests.post") - def test_create_session_internal_already_exists_400_specific_message( - self, mock_post - ): - mock_response = MagicMock() - mock_response.status_code = 400 - mock_response.text = "Session already exists for this user and app." 
- mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( - response=mock_response + def test_create_session_400_with_already_exists_text_is_idempotent(self, mock_post): + mock_resp = MagicMock( + status_code=400, + text="Session already exists for this user and app.", ) - mock_post.return_value = mock_response - - result = self.adapter._create_session_internal(session_id=self.session_id) - self.assertTrue(result) + mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + response=mock_resp + ) + mock_post.return_value = mock_resp + handler = _make_handler() + handler._create_session(session_id="abc") @patch("requests.post") - def test_create_session_internal_http_error_other(self, mock_post): - mock_response = MagicMock() - mock_response.status_code = 500 # Other server error - mock_response.text = "Internal Server Error" - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( - response=mock_response + def test_create_session_other_http_error_raises(self, mock_post): + mock_resp = MagicMock(status_code=500, text="boom") + mock_resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + response=mock_resp ) - mock_post.return_value = mock_response - - with self.assertRaisesRegex( - AgentInteractionError, "HTTP Error 500 creating session test_session_123" - ): - self.adapter._create_session_internal(session_id=self.session_id) + mock_post.return_value = mock_resp + handler = _make_handler() + with self.assertRaises(AgentInteractionError): + handler._create_session(session_id="abc") @patch("requests.post") - def test_create_session_internal_request_exception_timeout(self, mock_post): - mock_post.side_effect = requests.exceptions.Timeout("Request timed out") - with self.assertRaisesRegex( - AgentInteractionError, - "Request failed creating session test_session_123: Request timed out", - ): - self.adapter._create_session_internal(session_id=self.session_id) + def 
test_create_session_connection_error_raises(self, mock_post): + mock_post.side_effect = requests.exceptions.ConnectionError("nope") + handler = _make_handler() + with self.assertRaises(AgentInteractionError): + handler._create_session(session_id="abc") @patch("requests.post") - def test_create_session_internal_request_exception_connection(self, mock_post): - mock_post.side_effect = requests.exceptions.ConnectionError( - "Connection refused" + def test_run_returns_final_text_from_events(self, mock_post): + events = [ + {"content": {"parts": [{"text": "ignored"}]}}, + {"content": {"parts": [{"text": "the answer"}]}}, + ] + mock_resp = MagicMock(status_code=200, headers={"X": "1"}) + mock_resp.text = "[]" + mock_resp.json.return_value = events + mock_resp.raise_for_status = MagicMock() + mock_post.return_value = mock_resp + handler = _make_handler() + result = handler._run(prompt_text="hi", session_id="s1") + self.assertEqual(result["final_text"], "the answer") + self.assertEqual(result["events"], events) + self.assertEqual(result["status_code"], 200) + + +class TestADKAgentInit(unittest.TestCase): + def test_init_success(self): + adapter = ADKAgent( + id=str(uuid.uuid4()), + config={ + "name": "my_app", + "endpoint": "http://fake-adk.com/", + "user_id": "alice", + "timeout": 60, + }, + ) + self.assertEqual(adapter.endpoint, "http://fake-adk.com") + self.assertEqual(adapter.name, "my_app") + self.assertEqual(adapter.user_id, "alice") + self.assertEqual(adapter.timeout, 60) + # The adapter routes through LiteLLM under a per-instance provider. 
+ self.assertTrue( + adapter.litellm_model.startswith("hackagent_adk_") + and adapter.litellm_model.endswith("/my_app") ) - with self.assertRaisesRegex( - AgentInteractionError, - "Request failed creating session test_session_123: Connection refused", - ): - self.adapter._create_session_internal(session_id=self.session_id) + def test_init_default_timeout(self): + adapter = ADKAgent( + id="t1", + config={"name": "a", "endpoint": "http://x", "user_id": "u"}, + ) + self.assertEqual(adapter.timeout, 120) + + def test_init_missing_name(self): + with self.assertRaises(AgentConfigurationError): + ADKAgent(id="e1", config={"endpoint": "http://x", "user_id": "u"}) + + def test_init_missing_endpoint(self): + with self.assertRaises(AgentConfigurationError): + ADKAgent(id="e2", config={"name": "a", "user_id": "u"}) + + def test_init_missing_user_id(self): + with self.assertRaises(AgentConfigurationError): + ADKAgent(id="e3", config={"name": "a", "endpoint": "http://x"}) + + def test_init_registers_custom_provider(self): + import litellm + + adapter = ADKAgent( + id="reg1", + config={ + "name": "app", + "endpoint": "http://fake-adk.com", + "user_id": "u", + }, + ) + providers = [entry["provider"] for entry in litellm.custom_provider_map] + self.assertIn(f"hackagent_adk_{adapter.id}", providers) -class TestADKAgentHandleRequestValidation(unittest.TestCase): + +class TestADKAgentHandleRequest(unittest.TestCase): def setUp(self): - self.adapter_id = "adk_handle_req_test_agent" - self.config = { - "name": "handle_app", - "endpoint": "http://fake-handle.com", - "user_id": "handle_user", - } - self.adapter = ADKAgent(id=self.adapter_id, config=self.config) - - def test_handle_request_missing_prompt(self): - request_data = {"session_id": "sess_abc"} - response = self.adapter.handle_request(request_data) + self.adapter = ADKAgent( + id="h1", + config={ + "name": "test_app", + "endpoint": "http://fake-adk.com", + "user_id": "u", + "fresh_session_per_request": False, + }, + ) + + def 
test_missing_prompt_returns_400(self): + response = self.adapter.handle_request({}) self.assertEqual(response["status_code"], 400) - self.assertIn( - "Request data must include a 'prompt' field.", response["error_message"] + + @patch("requests.post") + def test_handle_request_success_routes_through_adk(self, mock_post): + # First call creates the session; second call is /run. + session_resp = MagicMock(status_code=200, raise_for_status=MagicMock()) + run_events = [{"content": {"parts": [{"text": "agent reply"}]}}] + run_resp = MagicMock(status_code=200, headers={"X": "1"}, text="[]") + run_resp.json.return_value = run_events + run_resp.raise_for_status = MagicMock() + mock_post.side_effect = [session_resp, run_resp] + + response = self.adapter.handle_request({"prompt": "hello"}) + + self.assertEqual(response["status_code"], 200) + self.assertEqual(response["generated_text"], "agent reply") + self.assertEqual(response["adapter_type"], "ADKAgent") + agent_data = response["agent_specific_data"] + self.assertEqual(agent_data.get("adk_events_list"), run_events) + self.assertEqual(agent_data.get("adk_session_id"), self.adapter.session_id) + + @patch("requests.post") + def test_handle_request_uses_explicit_session_id(self, mock_post): + session_resp = MagicMock(status_code=200, raise_for_status=MagicMock()) + run_resp = MagicMock(status_code=200, headers={}, text="[]") + run_resp.json.return_value = [{"content": {"parts": [{"text": "ok"}]}}] + run_resp.raise_for_status = MagicMock() + mock_post.side_effect = [session_resp, run_resp] + + response = self.adapter.handle_request( + {"prompt": "hi", "session_id": "explicit-123"} + ) + self.assertEqual(response["status_code"], 200) + self.assertEqual( + response["agent_specific_data"]["adk_session_id"], "explicit-123" ) - self.assertEqual(response["raw_request"], request_data) + # Session-create POST should target the explicit id. 
+ session_call_url = mock_post.call_args_list[0][0][0] + self.assertIn("/sessions/explicit-123", session_call_url) - def test_handle_request_missing_session_id(self): - # Session ID is optional - adapter uses default if not provided - # This will fail with 500 when trying to create/verify the session - request_data = {"prompt": "Hello agent"} - response = self.adapter.handle_request(request_data) - self.assertEqual(response["status_code"], 500) - # Check that error message mentions session creation failure - self.assertIn( - "Failed to create/verify ADK session", - response["error_message"], + @patch("requests.post") + def test_handle_request_run_http_error_returns_500(self, mock_post): + session_resp = MagicMock(status_code=200, raise_for_status=MagicMock()) + run_resp = MagicMock(status_code=500, text="boom", headers={}) + run_resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + response=run_resp ) - self.assertEqual(response["raw_request"], request_data) + mock_post.side_effect = [session_resp, run_resp] + response = self.adapter.handle_request({"prompt": "hi"}) + self.assertEqual(response["status_code"], 500) + self.assertIn("HTTP Error: 500", response["error_message"]) if __name__ == "__main__": diff --git a/tests/unit/adapters/test_ollama.py b/tests/unit/adapters/test_ollama.py index 116fde71..86273980 100644 --- a/tests/unit/adapters/test_ollama.py +++ b/tests/unit/adapters/test_ollama.py @@ -4,11 +4,11 @@ """ Unit tests for OllamaAgent. -These tests verify the functionality of the Ollama adapter including: -- Initialization with various configurations -- Request handling (generate and chat endpoints) -- Error handling -- Model information retrieval +Issue #379 moved the Ollama adapter onto LiteLLM (via the +``ollama_chat`` provider), so these tests patch ``litellm.completion`` +rather than ``requests.post`` for the chat path. 
Utility methods such as +``list_models`` and ``model_info`` still talk to the Ollama HTTP API +directly and so still mock ``requests``. """ import logging @@ -23,534 +23,198 @@ OllamaConfigurationError, ) -# Disable logging for tests to keep output clean logging.disable(logging.CRITICAL) -class TestOllamaAgentInit(unittest.TestCase): - """Test initialization of OllamaAgent.""" - - def test_init_success_with_minimal_config(self): - """Test successful initialization with minimum required config.""" - adapter_id = "ollama_test_agent_001" - config = { - "name": "llama3", - } +def _make_litellm_response(content: str = "ok") -> MagicMock: + response = MagicMock() + choice = MagicMock() + message = MagicMock() + message.content = content + message.tool_calls = None + message.reasoning_content = None + message.reasoning = None + message.provider_specific_fields = None + choice.message = message + choice.finish_reason = "stop" + response.choices = [choice] + response.usage = MagicMock(model_dump=MagicMock(return_value={"total_tokens": 5})) + response.model = "ollama_chat/llama3" + return response - adapter = OllamaAgent(id=adapter_id, config=config) - self.assertEqual(adapter.id, adapter_id) +class TestOllamaAgentInit(unittest.TestCase): + def test_init_success_minimal_config(self): + adapter = OllamaAgent(id="ol1", config={"name": "llama3"}) + self.assertEqual(adapter.id, "ol1") self.assertEqual(adapter.model_name, "llama3") self.assertEqual(adapter.api_base_url, "http://localhost:11434") + self.assertEqual(adapter.litellm_model, "ollama_chat/llama3") self.assertEqual(adapter.default_max_tokens, 100) - self.assertEqual(adapter.default_temperature, 0.8) - self.assertEqual(adapter.default_top_p, 0.95) def test_init_with_custom_endpoint(self): - """Test initialization with custom endpoint.""" - adapter_id = "ollama_test_agent_002" - config = { - "name": "mistral", - "endpoint": "http://192.168.1.100:11434", - } - - adapter = OllamaAgent(id=adapter_id, config=config) - - 
self.assertEqual(adapter.api_base_url, "http://192.168.1.100:11434") - - def test_init_with_endpoint_trailing_slash(self): - """Test that trailing slash is removed from endpoint.""" - adapter_id = "ollama_test_agent_003" - config = { - "name": "llama3", - "endpoint": "http://localhost:11434/", - } - - adapter = OllamaAgent(id=adapter_id, config=config) - - self.assertEqual(adapter.api_base_url, "http://localhost:11434") - - @patch.dict(os.environ, {"OLLAMA_BASE_URL": "http://env-ollama:11434"}) - def test_init_with_env_var_endpoint(self): - """Test initialization with endpoint from environment variable.""" - adapter_id = "ollama_test_agent_004" - config = { - "name": "llama3", - } - - adapter = OllamaAgent(id=adapter_id, config=config) - - self.assertEqual(adapter.api_base_url, "http://env-ollama:11434") - - def test_init_with_full_config(self): - """Test initialization with full configuration.""" - adapter_id = "ollama_test_agent_005" - config = { - "name": "codellama", - "endpoint": "http://localhost:11434", - "max_tokens": 200, - "temperature": 0.7, - "top_p": 0.9, - "top_k": 40, - "num_ctx": 4096, - "stream": False, - "timeout": 60, - } - - adapter = OllamaAgent(id=adapter_id, config=config) - - self.assertEqual(adapter.model_name, "codellama") - self.assertEqual(adapter.default_max_tokens, 200) - self.assertEqual(adapter.default_temperature, 0.7) - self.assertEqual(adapter.default_top_p, 0.9) - self.assertEqual(adapter.default_top_k, 40) - self.assertEqual(adapter.default_num_ctx, 4096) - self.assertEqual(adapter.default_stream, False) - self.assertEqual(adapter.timeout, 60) - - def test_init_with_default_thinking(self): - """Test initialization stores default thinking behavior.""" adapter = OllamaAgent( - id="ollama_test_agent_thinking", - config={"name": "qwen3", "thinking": False}, + id="ol2", + config={"name": "mistral", "endpoint": "http://host:11434"}, ) + self.assertEqual(adapter.api_base_url, "http://host:11434") - 
self.assertFalse(adapter.default_thinking) - - def test_init_missing_name_raises_error(self): - """Test that missing 'name' config raises error.""" - with self.assertRaisesRegex( - OllamaConfigurationError, "Missing required configuration key 'name'" - ): - OllamaAgent(id="err_agent_1", config={}) - - -class TestOllamaAgentBuildOptions(unittest.TestCase): - """Test _build_options method of OllamaAgent.""" - - def setUp(self): - """Set up test fixtures.""" - self.adapter = OllamaAgent( - id="test_options_adapter", + def test_init_normalizes_trailing_slash_and_api_suffix(self): + adapter = OllamaAgent( + id="ol3", config={ "name": "llama3", - "max_tokens": 100, - "temperature": 0.8, - "top_p": 0.95, + "endpoint": "http://host:11434/api/chat/", }, ) + self.assertEqual(adapter.api_base_url, "http://host:11434") - def test_build_options_with_defaults(self): - """Test building options with default values.""" - options = self.adapter._build_options() - - self.assertEqual(options["num_predict"], 100) - self.assertEqual(options["temperature"], 0.8) - self.assertEqual(options["top_p"], 0.95) - - def test_build_options_with_overrides(self): - """Test building options with override values.""" - options = self.adapter._build_options( - max_tokens=200, temperature=0.5, top_p=0.7 - ) - - self.assertEqual(options["num_predict"], 200) - self.assertEqual(options["temperature"], 0.5) - self.assertEqual(options["top_p"], 0.7) + @patch.dict(os.environ, {"OLLAMA_BASE_URL": "http://env-ollama:11434"}) + def test_init_picks_up_env_var(self): + adapter = OllamaAgent(id="ol4", config={"name": "llama3"}) + self.assertEqual(adapter.api_base_url, "http://env-ollama:11434") - def test_build_options_with_additional_params(self): - """Test building options with additional Ollama parameters.""" - options = self.adapter._build_options(seed=42, repeat_penalty=1.1, stop=["END"]) + def test_init_missing_name_raises(self): + with self.assertRaises(OllamaConfigurationError): + OllamaAgent(id="err", 
config={}) - self.assertEqual(options["seed"], 42) - self.assertEqual(options["repeat_penalty"], 1.1) - self.assertEqual(options["stop"], ["END"]) + def test_init_preserves_existing_provider_prefix(self): + """If the user supplies ``ollama/`` it shouldn't be re-prefixed.""" + adapter = OllamaAgent(id="ol5", config={"name": "ollama/llama3"}) + self.assertEqual(adapter.litellm_model, "ollama/llama3") class TestOllamaAgentHandleRequest(unittest.TestCase): - """Test handle_request method of OllamaAgent.""" - def setUp(self): - """Set up test fixtures.""" - self.adapter_id = "ollama_handle_req_test" - self.config = { - "name": "llama3", - "endpoint": "http://localhost:11434", - "max_tokens": 50, - "temperature": 0.5, - } - self.adapter = OllamaAgent(id=self.adapter_id, config=self.config) - - def test_handle_request_missing_prompt_and_messages(self): - """Test that missing both prompt and messages returns error.""" - request_data = {"temperature": 0.5} - response = self.adapter.handle_request(request_data) + self.adapter = OllamaAgent( + id="oh1", + config={"name": "llama3", "max_tokens": 50, "temperature": 0.5}, + ) + def test_missing_prompt_and_messages_returns_400(self): + response = self.adapter.handle_request({}) self.assertEqual(response["status_code"], 400) self.assertIn( "Request data must include either 'messages' or 'prompt'", response["error_message"], ) - self.assertEqual(response["raw_request"], request_data) - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_with_prompt_success(self, mock_post): - """Test successful request with prompt text using generate endpoint.""" - # Mock the Ollama API response - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "model": "llama3", - "response": "This is a test response from Ollama.", - "done": True, - "eval_count": 10, - "eval_duration": 1000000, - "prompt_eval_count": 5, - "total_duration": 2000000, - } - 
mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - request_data = {"prompt": "Hello, Ollama!"} - response = self.adapter.handle_request(request_data) + @patch("litellm.completion") + def test_handle_request_with_prompt_success(self, mock_completion): + mock_completion.return_value = _make_litellm_response("Hello!") - self.assertEqual(response["status_code"], 200) - self.assertIsNone(response["error_message"]) - self.assertEqual( - response["processed_response"], "This is a test response from Ollama." - ) - self.assertEqual(response["raw_request"], request_data) - self.assertEqual(response["agent_specific_data"]["model_name"], "llama3") - - # Verify the API was called correctly - mock_post.assert_called_once() - call_args = mock_post.call_args - self.assertIn("/api/generate", call_args[0][0]) - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_with_messages_success(self, mock_post): - """Test successful request with messages using chat endpoint.""" - # Mock the Ollama API response - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "model": "llama3", - "message": {"role": "assistant", "content": "Chat response from Ollama."}, - "done": True, - "eval_count": 15, - "total_duration": 3000000, - } - mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - - request_data = { - "messages": [ - {"role": "user", "content": "Hello!"}, - ] - } - response = self.adapter.handle_request(request_data) + response = self.adapter.handle_request({"prompt": "Hi"}) self.assertEqual(response["status_code"], 200) - self.assertIsNone(response["error_message"]) - self.assertEqual(response["processed_response"], "Chat response from Ollama.") - - # Verify the chat endpoint was called - mock_post.assert_called_once() - call_args = mock_post.call_args - self.assertIn("/api/chat", call_args[0][0]) - - 
@patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_strips_text_before_think_close(self, mock_post): - """Test that text before and including '' is removed.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "model": "llama3", - "response": "analysis pathVisible final output", - "done": True, - } - mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - - response = self.adapter.handle_request({"prompt": "Hello"}) - + self.assertEqual(response["generated_text"], "Hello!") + kwargs = mock_completion.call_args.kwargs + self.assertEqual(kwargs["model"], "ollama_chat/llama3") + self.assertEqual(kwargs["api_base"], "http://localhost:11434") + self.assertEqual(kwargs["messages"], [{"role": "user", "content": "Hi"}]) + + @patch("litellm.completion") + def test_handle_request_with_messages_success(self, mock_completion): + mock_completion.return_value = _make_litellm_response("ack") + messages = [ + {"role": "system", "content": "be terse"}, + {"role": "user", "content": "go"}, + ] + response = self.adapter.handle_request({"messages": messages}) self.assertEqual(response["status_code"], 200) - self.assertEqual(response["processed_response"], "Visible final output") - self.assertEqual(response["generated_text"], "Visible final output") - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_connection_error(self, mock_post): - """Test handling of connection error.""" - mock_post.side_effect = requests.exceptions.ConnectionError( - "Connection refused" - ) - - request_data = {"prompt": "Hello"} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 503) - self.assertIn("connection error", response["error_message"].lower()) - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_timeout_error(self, mock_post): - """Test handling of timeout error.""" - 
mock_post.side_effect = requests.exceptions.Timeout("Request timed out") - - request_data = {"prompt": "Hello"} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 503) - self.assertIn("timed out", response["error_message"].lower()) - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_http_error(self, mock_post): - """Test handling of HTTP error.""" - mock_response = MagicMock() - mock_response.status_code = 404 - mock_response.text = "Model not found" - http_error = requests.exceptions.HTTPError("404 Not Found") - http_error.response = mock_response - mock_post.side_effect = http_error - - request_data = {"prompt": "Hello"} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 404) - self.assertIn("HTTP error", response["error_message"]) - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_with_system_prompt(self, mock_post): - """Test request with system prompt.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "model": "llama3", - "response": "Response with system context.", - "done": True, - } - mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - - request_data = { - "prompt": "What's the weather?", - "system": "You are a helpful weather assistant.", - } - response = self.adapter.handle_request(request_data) + self.assertEqual(mock_completion.call_args.kwargs["messages"], messages) - self.assertEqual(response["status_code"], 200) - - # Verify system prompt was included in the request - call_args = mock_post.call_args - request_body = call_args[1]["json"] - self.assertEqual(request_body["system"], "You are a helpful weather assistant.") - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_with_custom_parameters(self, mock_post): - """Test request with custom generation 
parameters.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "model": "llama3", - "response": "Custom params response.", - "done": True, - } - mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - - request_data = { - "prompt": "Hello", - "max_tokens": 200, - "temperature": 0.3, - "top_k": 20, - "seed": 42, - } - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 200) - - # Verify options were included in the request - call_args = mock_post.call_args - request_body = call_args[1]["json"] - self.assertEqual(request_body["options"]["num_predict"], 200) - self.assertEqual(request_body["options"]["temperature"], 0.3) - self.assertEqual(request_body["options"]["top_k"], 20) - self.assertEqual(request_body["options"]["seed"], 42) - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_forwards_thinking_override(self, mock_post): - """Per-request thinking is forwarded as Ollama `think`.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "model": "llama3", - "response": "Thinking disabled.", - "done": True, - } - mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - - response = self.adapter.handle_request({"prompt": "Hello", "thinking": False}) - - self.assertEqual(response["status_code"], 200) - request_body = mock_post.call_args[1]["json"] - self.assertIn("think", request_body) - self.assertFalse(request_body["think"]) - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_handle_request_uses_default_thinking_from_config(self, mock_post): - """Adapter-level thinking default is applied when request omits it.""" + @patch("litellm.completion") + def test_extra_generation_options_pass_through(self, mock_completion): + mock_completion.return_value = _make_litellm_response("hi") adapter = OllamaAgent( - 
id="ollama_default_think", - config={"name": "llama3", "thinking": False}, - ) - - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "model": "llama3", - "response": "Default thinking applied.", - "done": True, - } - mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - - response = adapter.handle_request({"prompt": "Hello"}) - - self.assertEqual(response["status_code"], 200) - request_body = mock_post.call_args[1]["json"] - self.assertIn("think", request_body) - self.assertFalse(request_body["think"]) - - -class TestOllamaAgentUtilityMethods(unittest.TestCase): - """Test utility methods of OllamaAgent.""" - - def setUp(self): - """Set up test fixtures.""" - self.adapter = OllamaAgent( - id="test_utility_adapter", + id="oh2", config={ "name": "llama3", - "endpoint": "http://localhost:11434", + "top_k": 40, + "num_ctx": 8192, + "stream": True, }, ) + adapter.handle_request({"prompt": "Hi"}) + kwargs = mock_completion.call_args.kwargs + self.assertEqual(kwargs.get("top_k"), 40) + self.assertEqual(kwargs.get("num_ctx"), 8192) + self.assertEqual(kwargs.get("stream"), True) + + @patch("litellm.completion") + def test_thinking_true_translates_to_think(self, mock_completion): + mock_completion.return_value = _make_litellm_response("yo") + adapter = OllamaAgent(id="oh3", config={"name": "llama3", "thinking": True}) + adapter.handle_request({"prompt": "Hi"}) + kwargs = mock_completion.call_args.kwargs + self.assertIs(kwargs.get("think"), True) + self.assertNotIn("thinking", kwargs) + + @patch("litellm.completion") + def test_thinking_false_translates_to_think_false(self, mock_completion): + mock_completion.return_value = _make_litellm_response("yo") + adapter = OllamaAgent(id="oh4", config={"name": "llama3", "thinking": False}) + adapter.handle_request({"prompt": "Hi"}) + self.assertIs(mock_completion.call_args.kwargs.get("think"), False) + + @patch("litellm.completion") + def 
test_thinking_request_overrides_config_default(self, mock_completion): + mock_completion.return_value = _make_litellm_response("yo") + adapter = OllamaAgent(id="oh5", config={"name": "llama3", "thinking": False}) + adapter.handle_request({"prompt": "Hi", "thinking": True}) + self.assertIs(mock_completion.call_args.kwargs.get("think"), True) + + @patch("litellm.completion") + def test_handle_request_api_error(self, mock_completion): + mock_completion.side_effect = RuntimeError("connection refused") + response = self.adapter.handle_request({"prompt": "Hi"}) + self.assertEqual(response["status_code"], 500) + self.assertIn("connection refused", response["error_message"]) + + +class TestOllamaAgentUtilities(unittest.TestCase): + def setUp(self): + self.adapter = OllamaAgent(id="util", config={"name": "llama3"}) - @patch("hackagent.router.adapters.ollama.requests.get") + @patch("requests.get") def test_list_models_success(self, mock_get): - """Test successful model listing.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "models": [ - {"name": "llama3:latest", "size": 4000000000}, - {"name": "mistral:latest", "size": 3500000000}, - ] + mock_resp = MagicMock() + mock_resp.json.return_value = { + "models": [{"name": "llama3"}, {"name": "mistral:latest"}] } - mock_response.raise_for_status = MagicMock() - mock_get.return_value = mock_response - + mock_resp.raise_for_status = MagicMock() + mock_get.return_value = mock_resp models = self.adapter.list_models() - self.assertEqual(len(models), 2) - self.assertEqual(models[0]["name"], "llama3:latest") - - @patch("hackagent.router.adapters.ollama.requests.get") - def test_list_models_error(self, mock_get): - """Test model listing with error.""" - mock_get.side_effect = requests.exceptions.ConnectionError("Connection refused") - models = self.adapter.list_models() - - self.assertEqual(models, []) + @patch("requests.get") + def test_list_models_error_returns_empty_list(self, 
mock_get): + mock_get.side_effect = requests.exceptions.ConnectionError("nope") + self.assertEqual(self.adapter.list_models(), []) - @patch("hackagent.router.adapters.ollama.requests.post") + @patch("requests.post") def test_model_info_success(self, mock_post): - """Test successful model info retrieval.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "modelfile": "FROM llama3...", - "parameters": "temperature 0.8", - "template": "...", - } - mock_response.raise_for_status = MagicMock() - mock_post.return_value = mock_response - - info = self.adapter.model_info() - - self.assertIn("modelfile", info) - - @patch("hackagent.router.adapters.ollama.requests.post") - def test_model_info_error(self, mock_post): - """Test model info with error.""" - mock_post.side_effect = requests.exceptions.ConnectionError( - "Connection refused" - ) - - info = self.adapter.model_info() - - self.assertEqual(info, {}) - - @patch("hackagent.router.adapters.ollama.requests.get") - def test_is_available_true(self, mock_get): - """Test is_available returns True when model exists.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "models": [ - {"name": "llama3:latest"}, - {"name": "mistral:latest"}, - ] - } - mock_response.raise_for_status = MagicMock() - mock_get.return_value = mock_response - + mock_resp = MagicMock() + mock_resp.json.return_value = {"license": "mit"} + mock_resp.raise_for_status = MagicMock() + mock_post.return_value = mock_resp + self.assertEqual(self.adapter.model_info(), {"license": "mit"}) + + @patch("requests.post") + def test_model_info_error_returns_empty_dict(self, mock_post): + mock_post.side_effect = requests.exceptions.ConnectionError("nope") + self.assertEqual(self.adapter.model_info(), {}) + + @patch.object(OllamaAgent, "list_models") + def test_is_available_true_when_model_present(self, mock_list): + mock_list.return_value = [{"name": 
"llama3:latest"}] self.assertTrue(self.adapter.is_available()) - @patch("hackagent.router.adapters.ollama.requests.get") - def test_is_available_false(self, mock_get): - """Test is_available returns False when model doesn't exist.""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "models": [ - {"name": "other-model:latest"}, - ] - } - mock_response.raise_for_status = MagicMock() - mock_get.return_value = mock_response - + @patch.object(OllamaAgent, "list_models") + def test_is_available_false_when_model_missing(self, mock_list): + mock_list.return_value = [{"name": "mistral"}] self.assertFalse(self.adapter.is_available()) - @patch("hackagent.router.adapters.ollama.requests.get") - def test_is_available_connection_error(self, mock_get): - """Test is_available returns False on connection error.""" - mock_get.side_effect = requests.exceptions.ConnectionError("Connection refused") - - self.assertFalse(self.adapter.is_available()) - - -class TestOllamaAgentIntegration(unittest.TestCase): - """Integration-style tests for OllamaAgent.""" - - def test_adapter_identifier(self): - """Test that adapter returns correct identifier.""" - adapter = OllamaAgent( - id="integration_test_adapter", - config={"name": "llama3"}, - ) - - self.assertEqual(adapter.get_identifier(), "integration_test_adapter") - - def test_adapter_with_model_tag(self): - """Test adapter with model name including tag.""" - adapter = OllamaAgent( - id="tagged_model_adapter", - config={"name": "llama3:8b-instruct-q4_0"}, - ) - - self.assertEqual(adapter.model_name, "llama3:8b-instruct-q4_0") - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/adapters/test_openai.py b/tests/unit/adapters/test_openai.py index c83dad14..a6190949 100644 --- a/tests/unit/adapters/test_openai.py +++ b/tests/unit/adapters/test_openai.py @@ -1,8 +1,16 @@ # Copyright 2026 - AI4I. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +""" +Unit tests for the OpenAI agent adapter. + +Issue #379 moved every chat-completion adapter onto LiteLLM, so these +tests exercise the OpenAI adapter by patching ``litellm.completion`` +rather than the OpenAI SDK directly. +""" import logging +import os import unittest from unittest.mock import MagicMock, patch @@ -11,745 +19,217 @@ OpenAIConfigurationError, ) -# Disable logging for tests to keep output clean logging.disable(logging.CRITICAL) +def _make_litellm_response(content: str = "ok", *, tool_calls=None) -> MagicMock: + """Build a minimal mock of a litellm ModelResponse.""" + response = MagicMock() + choice = MagicMock() + message = MagicMock() + message.content = content + message.tool_calls = tool_calls + message.reasoning_content = None + message.reasoning = None + message.provider_specific_fields = None + choice.message = message + choice.finish_reason = "stop" + response.choices = [choice] + response.usage = MagicMock(model_dump=MagicMock(return_value={"total_tokens": 10})) + response.model = "gpt-4" + return response + + class TestOpenAIAgentInit(unittest.TestCase): - """Test initialization of OpenAIAgent.""" - - @patch("hackagent.router.adapters.openai.OPENAI_AVAILABLE", True) - @patch("hackagent.router.adapters.openai.OpenAI") - def test_init_success_with_required_config(self, mock_openai_class): - """Test successful initialization with minimum required config.""" - adapter_id = "openai_test_agent_001" - config = { - "name": "gpt-4", - } - mock_client = MagicMock() - mock_openai_class.return_value = mock_client - - adapter = OpenAIAgent(id=adapter_id, config=config) - - self.assertEqual(adapter.id, adapter_id) + def test_init_success_with_required_config(self): + adapter = OpenAIAgent(id="o1", config={"name": "gpt-4"}) + self.assertEqual(adapter.id, "o1") self.assertEqual(adapter.model_name, "gpt-4") + # OpenAIAgent forces the openai/ provider prefix when none is set. 
+ self.assertEqual(adapter.litellm_model, "openai/gpt-4") self.assertIsNone(adapter.api_base_url) self.assertEqual(adapter.default_temperature, 1.0) - mock_openai_class.assert_called_once() - - @patch("hackagent.router.adapters.openai.OPENAI_AVAILABLE", True) - @patch("hackagent.router.adapters.openai.OpenAI") - @patch.dict("os.environ", {"CUSTOM_API_KEY": "test-key-123"}) - def test_init_with_api_key_from_env(self, mock_openai_class): - """Test initialization with API key from environment variable.""" - adapter_id = "openai_test_agent_002" - config = { - "name": "gpt-3.5-turbo", - "api_key": "CUSTOM_API_KEY", - } - mock_client = MagicMock() - mock_openai_class.return_value = mock_client - - adapter = OpenAIAgent(id=adapter_id, config=config) - - self.assertEqual(adapter.actual_api_key, "test-key-123") - mock_openai_class.assert_called_once_with(api_key="test-key-123", timeout=120) - - @patch("hackagent.router.adapters.openai.OPENAI_AVAILABLE", True) - @patch("hackagent.router.adapters.openai.OpenAI") - def test_init_with_custom_endpoint(self, mock_openai_class): - """Test initialization with custom API endpoint.""" - adapter_id = "openai_test_agent_003" - config = { - "name": "gpt-4", - "endpoint": "https://custom.openai.proxy.com/v1", - } - mock_client = MagicMock() - mock_openai_class.return_value = mock_client - - adapter = OpenAIAgent(id=adapter_id, config=config) - - self.assertEqual(adapter.api_base_url, "https://custom.openai.proxy.com/v1") - # Verify OpenAI was called with the correct base_url (api_key may vary) - mock_openai_class.assert_called_once() - call_kwargs = mock_openai_class.call_args.kwargs - self.assertEqual( - call_kwargs.get("base_url"), "https://custom.openai.proxy.com/v1" + + def test_init_with_custom_endpoint(self): + adapter = OpenAIAgent( + id="o2", + config={ + "name": "gpt-4", + "endpoint": "https://custom.proxy/v1", + }, ) + self.assertEqual(adapter.api_base_url, "https://custom.proxy/v1") + # When there's no API key, a placeholder 
is used so the underlying + # OpenAI client doesn't choke. + self.assertEqual(adapter.actual_api_key, "not-required") - @patch("hackagent.router.adapters.openai.OPENAI_AVAILABLE", True) - @patch("hackagent.router.adapters.openai.OpenAI") - def test_init_with_generation_parameters(self, mock_openai_class): - """Test initialization with custom generation parameters.""" - adapter_id = "openai_test_agent_004" - config = { - "name": "gpt-4", - "max_tokens": 500, - "temperature": 0.7, - "tools": [{"type": "function", "function": {"name": "test_func"}}], - "tool_choice": "auto", - } - mock_client = MagicMock() - mock_openai_class.return_value = mock_client - - adapter = OpenAIAgent(id=adapter_id, config=config) + def test_init_with_custom_endpoint_defaults_model_name(self): + adapter = OpenAIAgent(id="o3", config={"endpoint": "https://example.com/v1"}) + self.assertEqual(adapter.model_name, "default") + @patch.dict(os.environ, {"CUSTOM_API_KEY": "sk-test"}) + def test_init_with_api_key_from_env(self): + adapter = OpenAIAgent( + id="o4", + config={"name": "gpt-4", "api_key": "CUSTOM_API_KEY"}, + ) + self.assertEqual(adapter.actual_api_key, "sk-test") + + def test_init_with_generation_parameters(self): + adapter = OpenAIAgent( + id="o5", + config={ + "name": "gpt-4", + "max_tokens": 500, + "temperature": 0.7, + "tools": [{"type": "function", "function": {"name": "f"}}], + "tool_choice": "auto", + }, + ) self.assertEqual(adapter.default_max_tokens, 500) self.assertEqual(adapter.default_temperature, 0.7) self.assertIsNotNone(adapter.default_tools) self.assertEqual(adapter.default_tool_choice, "auto") - def test_init_missing_name_raises_error(self): - """Test that missing 'name' config raises error.""" - with self.assertRaisesRegex( - OpenAIConfigurationError, "Missing required configuration key 'name'" - ): - OpenAIAgent(id="err_agent_1", config={}) + def test_init_missing_name_no_endpoint_raises(self): + with self.assertRaises(OpenAIConfigurationError): + 
OpenAIAgent(id="err", config={}) - @patch("hackagent.router.adapters.openai.OPENAI_AVAILABLE", False) - def test_init_without_openai_installed_raises_error(self): - """Test that initialization fails gracefully when OpenAI SDK not installed.""" - with self.assertRaisesRegex( - OpenAIConfigurationError, "OpenAI SDK is not installed" - ): - OpenAIAgent(id="err_agent_2", config={"name": "gpt-4"}) + def test_init_preserves_existing_provider_prefix(self): + """A user-supplied ``openai/`` shouldn't get double-prefixed.""" + adapter = OpenAIAgent(id="o6", config={"name": "openai/gpt-4"}) + self.assertEqual(adapter.litellm_model, "openai/gpt-4") class TestOpenAIAgentHandleRequest(unittest.TestCase): - """Test handle_request method of OpenAIAgent.""" - def setUp(self): - """Set up test fixtures.""" - self.adapter_id = "openai_handle_req_test" - self.config = { - "name": "gpt-4", - "max_tokens": 100, - "temperature": 0.8, - } - - # Patch at module level - self.openai_patch = patch( - "hackagent.router.adapters.openai.OPENAI_AVAILABLE", True + self.adapter = OpenAIAgent( + id="oh1", + config={"name": "gpt-4", "max_tokens": 100, "temperature": 0.8}, ) - self.openai_class_patch = patch("hackagent.router.adapters.openai.OpenAI") - - self.openai_patch.start() - self.mock_openai_class = self.openai_class_patch.start() - - self.mock_client = MagicMock() - self.mock_openai_class.return_value = self.mock_client - - self.adapter = OpenAIAgent(id=self.adapter_id, config=self.config) - - def tearDown(self): - """Clean up patches.""" - self.openai_patch.stop() - self.openai_class_patch.stop() - - def test_handle_request_missing_prompt_and_messages(self): - """Test that missing both prompt and messages returns error.""" - request_data = {"temperature": 0.5} - response = self.adapter.handle_request(request_data) + def test_missing_prompt_and_messages_returns_400(self): + response = self.adapter.handle_request({"temperature": 0.5}) self.assertEqual(response["status_code"], 400) 
self.assertIn( "Request data must include either 'messages' or 'prompt'", response["error_message"], ) - self.assertEqual(response["raw_request"], request_data) - - def test_handle_request_with_prompt_success(self): - """Test successful request with prompt text.""" - # Mock the OpenAI API response - mock_message = MagicMock() - mock_message.content = "This is a test response" - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30, - } - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "gpt-4" - self.mock_client.chat.completions.create.return_value = mock_response + @patch("litellm.completion") + def test_handle_request_with_prompt_success(self, mock_completion): + mock_completion.return_value = _make_litellm_response("Hello back") + response = self.adapter.handle_request({"prompt": "Hi"}) - request_data = {"prompt": "Hello, how are you?"} - response = self.adapter.handle_request(request_data) - - # Verify response structure self.assertEqual(response["status_code"], 200) self.assertIsNone(response["error_message"]) - self.assertEqual(response["generated_text"], "This is a test response") - self.assertEqual(response["agent_id"], self.adapter_id) + self.assertEqual(response["generated_text"], "Hello back") self.assertEqual(response["adapter_type"], "OpenAIAgent") + kwargs = mock_completion.call_args.kwargs + self.assertEqual(kwargs["model"], "openai/gpt-4") + self.assertEqual(kwargs["messages"], [{"role": "user", "content": "Hi"}]) - # Verify agent specific data - self.assertEqual(response["agent_specific_data"]["model_name"], "gpt-4") - self.assertEqual(response["agent_specific_data"]["finish_reason"], "stop") - 
self.assertIsNotNone(response["agent_specific_data"]["usage"]) - - # Verify the API was called correctly - self.mock_client.chat.completions.create.assert_called_once() - call_kwargs = self.mock_client.chat.completions.create.call_args[1] - self.assertEqual(call_kwargs["model"], "gpt-4") - self.assertEqual( - call_kwargs["messages"], - [{"role": "user", "content": "Hello, how are you?"}], - ) - - def test_handle_request_with_messages_success(self): - """Test successful request with pre-formatted messages.""" - # Mock the OpenAI API response - mock_message = MagicMock() - mock_message.content = "Response to conversation" - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = { - "prompt_tokens": 15, - "completion_tokens": 25, - "total_tokens": 40, - } - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "gpt-4" - - self.mock_client.chat.completions.create.return_value = mock_response - - request_data = { - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello!"}, - ] - } - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 200) - self.assertIsNone(response["error_message"]) - self.assertEqual(response["generated_text"], "Response to conversation") - - # Verify messages were passed correctly - call_kwargs = self.mock_client.chat.completions.create.call_args[1] - self.assertEqual(len(call_kwargs["messages"]), 2) - self.assertEqual(call_kwargs["messages"][0]["role"], "system") - - def test_handle_request_with_tool_calls(self): - """Test request that returns tool calls.""" - # Mock a tool call response - mock_tool_call = MagicMock() - mock_tool_call.id = "call_123" - mock_tool_call.type = "function" - mock_tool_call.function.name = 
"get_weather" - mock_tool_call.function.arguments = '{"location": "San Francisco"}' - - mock_message = MagicMock() - mock_message.content = None - mock_message.tool_calls = [mock_tool_call] - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "tool_calls" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = { - "prompt_tokens": 50, - "completion_tokens": 30, - "total_tokens": 80, - } - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "gpt-4" - - self.mock_client.chat.completions.create.return_value = mock_response - - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather for a location", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - }, - }, - } - ] - - request_data = { - "prompt": "What's the weather in San Francisco?", - "tools": tools, - "tool_choice": "auto", - } - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 200) - self.assertIsNone(response["error_message"]) - self.assertEqual(response["agent_specific_data"]["finish_reason"], "tool_calls") - - # Verify tool calls in response - tool_calls = response["agent_specific_data"]["tool_calls"] - self.assertIsNotNone(tool_calls) - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["id"], "call_123") - self.assertEqual(tool_calls[0]["function"]["name"], "get_weather") - - def test_handle_request_api_timeout_error(self): - """Test handling of API timeout errors.""" - import openai - - self.mock_client.chat.completions.create.side_effect = openai.APITimeoutError( - "Request timed out" - ) - - request_data = {"prompt": "Hello"} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 500) - self.assertIn("timeout", response["error_message"]) - self.assertIn("Request timed out", 
response["error_message"]) - - def test_handle_request_rate_limit_error(self): - """Test handling of rate limit errors.""" - import openai - - # Create mock response and body for RateLimitError - mock_response = MagicMock() - mock_response.status_code = 429 - mock_body = {"error": {"message": "Rate limit exceeded"}} - - error = openai.RateLimitError( - "Rate limit exceeded", response=mock_response, body=mock_body - ) - self.mock_client.chat.completions.create.side_effect = error - - request_data = {"prompt": "Hello"} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 500) - self.assertIn("rate_limit", response["error_message"]) - - def test_handle_request_connection_error(self): - """Test handling of connection errors.""" - import openai - - # APIConnectionError requires a request parameter - mock_request = MagicMock() - error = openai.APIConnectionError( - message="Connection failed", request=mock_request - ) - self.mock_client.chat.completions.create.side_effect = error - - request_data = {"prompt": "Hello"} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 500) - self.assertIn("connection", response["error_message"]) - - @patch("hackagent.router.adapters.openai.time.sleep", return_value=None) - def test_handle_request_connection_error_retries_then_succeeds(self, _mock_sleep): - """Test transient connection errors are retried and can recover.""" - import openai - - mock_request = MagicMock() - error = openai.APIConnectionError( - message="Connection failed", request=mock_request - ) - - mock_message = MagicMock() - mock_message.content = "Recovered response" - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {"total_tokens": 10} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - 
mock_response.usage = mock_usage - mock_response.model = "gpt-4" - - self.mock_client.chat.completions.create.side_effect = [ - error, - error, - mock_response, - ] - - response = self.adapter.handle_request({"prompt": "Hello"}) - - self.assertEqual(response["status_code"], 200) - self.assertEqual(response["generated_text"], "Recovered response") - self.assertEqual(self.mock_client.chat.completions.create.call_count, 3) - - @patch("hackagent.router.adapters.openai.time.sleep", return_value=None) - def test_handle_request_connection_error_stops_after_five_retries( - self, _mock_sleep - ): - """Test connection retry budget is capped at 5 retries.""" - import openai - - mock_request = MagicMock() - error = openai.APIConnectionError( - message="Connection failed", request=mock_request - ) - - self.mock_client.chat.completions.create.side_effect = [error] * 6 - - response = self.adapter.handle_request({"prompt": "Hello"}) - - self.assertEqual(response["status_code"], 500) - self.assertIn("connection", response["error_message"]) - # First attempt + 5 retries = 6 total calls. 
- self.assertEqual(self.mock_client.chat.completions.create.call_count, 6) - - def test_handle_request_with_parameter_overrides(self): - """Test that request parameters override defaults.""" - mock_message = MagicMock() - mock_message.content = "Response" - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "gpt-4" - - self.mock_client.chat.completions.create.return_value = mock_response - - request_data = { - "prompt": "Test", - "max_tokens": 200, # Override default of 100 - "temperature": 0.5, # Override default of 0.8 - } - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 200) - - # Verify overridden parameters were used - call_kwargs = self.mock_client.chat.completions.create.call_args[1] - self.assertEqual(call_kwargs["max_tokens"], 200) - self.assertEqual(call_kwargs["temperature"], 0.5) - - -class TestOpenAIAgentIntegration(unittest.TestCase): - """Integration-style tests for OpenAIAgent.""" - - @patch("hackagent.router.adapters.openai.OPENAI_AVAILABLE", True) - @patch("hackagent.router.adapters.openai.OpenAI") - def test_full_conversation_flow(self, mock_openai_class): - """Test a full conversation flow with multiple messages.""" - mock_client = MagicMock() - mock_openai_class.return_value = mock_client - - adapter = OpenAIAgent( - id="conversation_test", config={"name": "gpt-4", "temperature": 0.7} - ) - - # Mock response - mock_message = MagicMock() - mock_message.content = "I'm doing great, thank you!" 
- mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {"total_tokens": 50} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "gpt-4" - - mock_client.chat.completions.create.return_value = mock_response - - # Simulate a conversation + @patch("litellm.completion") + def test_handle_request_with_messages_success(self, mock_completion): + mock_completion.return_value = _make_litellm_response("Hi!") messages = [ - {"role": "system", "content": "You are a friendly assistant."}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Hi! How can I help you?"}, - {"role": "user", "content": "How are you?"}, + {"role": "system", "content": "be helpful"}, + {"role": "user", "content": "ping"}, ] - - request_data = {"messages": messages} - response = adapter.handle_request(request_data) + response = self.adapter.handle_request({"messages": messages}) self.assertEqual(response["status_code"], 200) - self.assertEqual(response["generated_text"], "I'm doing great, thank you!") - self.assertEqual(response["agent_specific_data"]["model_name"], "gpt-4") - - -class TestOpenAIAgentReasoningModels(unittest.TestCase): - """Test reasoning model support (e.g., o1-preview, o1-mini).""" - - def setUp(self): - """Set up test fixtures.""" - self.adapter_id = "openai_reasoning_test" - self.config = { - "name": "o1-preview", - "temperature": 1.0, - } - - # Patch at module level - self.openai_patch = patch( - "hackagent.router.adapters.openai.OPENAI_AVAILABLE", True + self.assertEqual(response["generated_text"], "Hi!") + kwargs = mock_completion.call_args.kwargs + self.assertEqual(kwargs["messages"], messages) + + @patch("litellm.completion") + def test_handle_request_with_tool_calls(self, mock_completion): + tool = MagicMock() + tool.id = 
"call_1" + tool.type = "function" + tool.function.name = "get_weather" + tool.function.arguments = '{"loc": "SF"}' + mock_completion.return_value = _make_litellm_response( + "I'll call a tool", tool_calls=[tool] ) - self.openai_class_patch = patch("hackagent.router.adapters.openai.OpenAI") - - self.openai_patch.start() - self.mock_openai_class = self.openai_class_patch.start() - - self.mock_client = MagicMock() - self.mock_openai_class.return_value = self.mock_client - - self.adapter = OpenAIAgent(id=self.adapter_id, config=self.config) - def tearDown(self): - """Clean up patches.""" - self.openai_patch.stop() - self.openai_class_patch.stop() - - def test_handle_request_with_reasoning_field(self): - """Test that reasoning field is extracted when content is empty.""" - # Mock a reasoning model response with reasoning field - mock_message = MagicMock() - mock_message.content = None # Reasoning models may have no content - mock_message.reasoning = "Let me think through this step by step..." - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = { - "prompt_tokens": 20, - "completion_tokens": 50, - "total_tokens": 70, - } - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "o1-preview" - - self.mock_client.chat.completions.create.return_value = mock_response - - request_data = {"prompt": "What is 2+2?"} - response = self.adapter.handle_request(request_data) - - # Verify reasoning field was extracted - self.assertEqual(response["status_code"], 200) - self.assertEqual( - response["generated_text"], "Let me think through this step by step..." - ) - self.assertEqual( - response["processed_response"], "Let me think through this step by step..." 
+ response = self.adapter.handle_request( + { + "prompt": "weather?", + "tools": [{"type": "function", "function": {"name": "x"}}], + "tool_choice": "auto", + } ) - self.assertEqual(response["agent_specific_data"]["model_name"], "o1-preview") - - def test_handle_request_with_empty_content_and_reasoning(self): - """Test extraction when content is empty string but reasoning exists.""" - mock_message = MagicMock() - mock_message.content = "" # Empty content - mock_message.reasoning = "First, I need to analyze the problem..." - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {"total_tokens": 100} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "o1-mini" - - self.mock_client.chat.completions.create.return_value = mock_response - - request_data = {"prompt": "Solve this problem"} - response = self.adapter.handle_request(request_data) self.assertEqual(response["status_code"], 200) - self.assertEqual( - response["generated_text"], "First, I need to analyze the problem..." 
+ tcs = response["agent_specific_data"]["tool_calls"] + self.assertEqual(len(tcs), 1) + self.assertEqual(tcs[0]["function"]["name"], "get_weather") + kwargs = mock_completion.call_args.kwargs + self.assertIn("tools", kwargs) + self.assertEqual(kwargs["tool_choice"], "auto") + + @patch("litellm.completion") + def test_parameter_overrides_apply(self, mock_completion): + mock_completion.return_value = _make_litellm_response("ok") + self.adapter.handle_request( + {"prompt": "go", "max_tokens": 200, "temperature": 0.5} ) + kwargs = mock_completion.call_args.kwargs + self.assertEqual(kwargs["max_tokens"], 200) + self.assertEqual(kwargs["temperature"], 0.5) + + @patch("litellm.completion") + def test_handle_request_api_error(self, mock_completion): + mock_completion.side_effect = RuntimeError("boom") + response = self.adapter.handle_request({"prompt": "Hi"}) + self.assertEqual(response["status_code"], 500) + self.assertIn("boom", response["error_message"]) - def test_handle_request_without_reasoning_attribute(self): - """Test handling when message has no reasoning attribute (non-reasoning model).""" - mock_message = MagicMock() - mock_message.content = "Regular response" - # Don't set reasoning attribute at all - mock_message.tool_calls = None - # Ensure hasattr returns False - del mock_message.reasoning - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {"total_tokens": 50} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "gpt-4" - - self.mock_client.chat.completions.create.return_value = mock_response - - request_data = {"prompt": "Hello"} - response = self.adapter.handle_request(request_data) - - # Should use content, not reasoning - self.assertEqual(response["status_code"], 200) - self.assertEqual(response["generated_text"], "Regular response") - - def 
test_handle_request_content_takes_precedence_over_reasoning(self): - """Test that non-empty content takes precedence over reasoning field.""" - mock_message = MagicMock() - mock_message.content = "This is the actual response" - mock_message.reasoning = "This is the reasoning" - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {"total_tokens": 60} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "o1-preview" - - self.mock_client.chat.completions.create.return_value = mock_response - - request_data = {"prompt": "Test"} - response = self.adapter.handle_request(request_data) - - # Content should be used, not reasoning - self.assertEqual(response["status_code"], 200) - self.assertEqual(response["generated_text"], "This is the actual response") - - def test_handle_request_reasoning_with_messages(self): - """Test reasoning model with pre-formatted messages.""" - mock_message = MagicMock() - mock_message.content = None - mock_message.reasoning = "Analyzing the conversation context..." 
- mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {"total_tokens": 150} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "o1-mini" - self.mock_client.chat.completions.create.return_value = mock_response +class TestOpenAIAgentThinking(unittest.TestCase): + """Issue #379 — verify the unified thinking knob translates correctly.""" - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Help me solve this."}, - ] - request_data = {"messages": messages} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 200) + @patch("litellm.completion") + def test_thinking_true_on_reasoning_model_sets_reasoning_effort( + self, mock_completion + ): + mock_completion.return_value = _make_litellm_response("hi") + adapter = OpenAIAgent(id="r1", config={"name": "o1-mini", "thinking": True}) + adapter.handle_request({"prompt": "hello"}) + kwargs = mock_completion.call_args.kwargs + self.assertEqual(kwargs.get("reasoning_effort"), "medium") + self.assertNotIn("thinking", kwargs) + + @patch("litellm.completion") + def test_thinking_false_on_reasoning_model_omits_effort(self, mock_completion): + mock_completion.return_value = _make_litellm_response("hi") + adapter = OpenAIAgent(id="r2", config={"name": "o3", "thinking": False}) + adapter.handle_request({"prompt": "hello"}) + kwargs = mock_completion.call_args.kwargs + self.assertNotIn("reasoning_effort", kwargs) + self.assertNotIn("thinking", kwargs) + + @patch("litellm.completion") + def test_thinking_string_passes_through_as_effort(self, mock_completion): + mock_completion.return_value = _make_litellm_response("hi") + adapter = OpenAIAgent(id="r3", config={"name": "o1"}) + adapter.handle_request({"prompt": 
"hello", "thinking": "high"}) self.assertEqual( - response["generated_text"], "Analyzing the conversation context..." + mock_completion.call_args.kwargs.get("reasoning_effort"), "high" ) - self.assertIsNone(response["error_message"]) - - def test_handle_request_reasoning_none_and_content_none(self): - """Test when both reasoning and content are None (edge case).""" - mock_message = MagicMock() - mock_message.content = None - mock_message.reasoning = None - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {"total_tokens": 10} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "o1-preview" - - self.mock_client.chat.completions.create.return_value = mock_response - - request_data = {"prompt": "Test"} - response = self.adapter.handle_request(request_data) - # Should return empty string - self.assertEqual(response["status_code"], 200) - self.assertEqual(response["generated_text"], "") - - def test_handle_request_strips_text_before_think_close(self): - """Test that text before and including '' is removed.""" - mock_message = MagicMock() - mock_message.content = "draft stepsFinal visible answer" - mock_message.tool_calls = None - - mock_choice = MagicMock() - mock_choice.message = mock_message - mock_choice.finish_reason = "stop" - - mock_usage = MagicMock() - mock_usage.model_dump.return_value = {"total_tokens": 12} - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_response.usage = mock_usage - mock_response.model = "gpt-4" - - self.mock_client.chat.completions.create.return_value = mock_response - - response = self.adapter.handle_request({"prompt": "Hello"}) - - self.assertEqual(response["status_code"], 200) - self.assertEqual(response["generated_text"], "Final visible answer") - 
self.assertEqual(response["processed_response"], "Final visible answer") + @patch("litellm.completion") + def test_thinking_on_non_reasoning_model_passes_through_generically( + self, mock_completion + ): + mock_completion.return_value = _make_litellm_response("hi") + adapter = OpenAIAgent(id="r4", config={"name": "gpt-4"}) + adapter.handle_request({"prompt": "hello", "thinking": True}) + kwargs = mock_completion.call_args.kwargs + # Non-reasoning OpenAI models get the generic LiteLLM thinking dict. + self.assertEqual(kwargs.get("thinking"), {"type": "enabled"}) if __name__ == "__main__": From 1b3dedfebc98d9c99cf2353d1b2630b9ace28f98 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 16:33:12 +0200 Subject: [PATCH 02/23] refactor(router): extract envelope helpers (#379 Phase A) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase A of the LiteLLM router refactor plan: pull the response-shaping logic out of the adapter classes into pure functions in ``hackagent/router/envelope.py``. Adapter classes now delegate to those helpers; the public dict shape returned by ``handle_request`` is unchanged (snapshot tests in test_envelope.py). Also lands ``hackagent/router/provider_config.py`` — the AgentType → provider-prefix + thinking-translator + extra-passthrough-keys table that Phase C will use to bypass the adapter classes entirely. The table isn't wired in yet; this commit just ships the lookup module with its own unit tests. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- hackagent/router/adapters/base.py | 141 ++++------ hackagent/router/adapters/litellm.py | 202 ++++----------- hackagent/router/envelope.py | 299 ++++++++++++++++++++++ hackagent/router/provider_config.py | 153 +++++++++++ tests/unit/router/test_envelope.py | 249 ++++++++++++++++++ tests/unit/router/test_provider_config.py | 177 +++++++++++++ 6 files changed, 979 insertions(+), 242 deletions(-) create mode 100644 hackagent/router/envelope.py create mode 100644 hackagent/router/provider_config.py create mode 100644 tests/unit/router/test_envelope.py create mode 100644 tests/unit/router/test_provider_config.py diff --git a/hackagent/router/adapters/base.py b/hackagent/router/adapters/base.py index 59505c3a..5f92a3e2 100644 --- a/hackagent/router/adapters/base.py +++ b/hackagent/router/adapters/base.py @@ -15,6 +15,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple +from hackagent.router import envelope as _envelope + # --- Common Exception Classes --- class AdapterConfigurationError(Exception): @@ -266,39 +268,23 @@ def _build_error_response( raw_response_headers: Optional[Dict[str, str]] = None, agent_specific_data: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - """ - Constructs a standardized error response dictionary. - - Args: - error_message: The primary error message string. - status_code: The HTTP status code associated with the error. - raw_request: The original request data that led to the error. - raw_response_body: Raw response body if available. - raw_response_headers: Response headers if available. - agent_specific_data: Additional adapter-specific data. - - Returns: - A dictionary representing a standardized error response. 
- """ - if agent_specific_data is None: - agent_specific_data = {} - - # Include model_name in agent_specific_data if available - if self.model_name and "model_name" not in agent_specific_data: - agent_specific_data["model_name"] = self.model_name - - return { - "raw_request": raw_request, - "processed_response": None, - "generated_text": None, - "status_code": status_code if status_code is not None else 500, - "raw_response_headers": raw_response_headers, - "raw_response_body": raw_response_body, - "agent_specific_data": agent_specific_data, - "error_message": error_message, - "agent_id": self.id, - "adapter_type": self.ADAPTER_TYPE, - } + """Construct HackAgent's standardised error-response dict. + + Delegates to :func:`hackagent.router.envelope.build_error_envelope` + so the dict shape lives in one place (see + ``LITELLM_ROUTER_REFACTOR_PLAN.md`` Phase A). + """ + return _envelope.build_error_envelope( + agent_id=self.id, + adapter_type=self.ADAPTER_TYPE, + error_message=error_message, + status_code=status_code, + raw_request=raw_request, + raw_response_body=raw_response_body, + raw_response_headers=raw_response_headers, + agent_specific_data=agent_specific_data, + model_name=self.model_name, + ) def _build_success_response( self, @@ -309,50 +295,25 @@ def _build_success_response( agent_specific_data: Optional[Dict[str, Any]] = None, status_code: int = 200, ) -> Dict[str, Any]: - """ - Constructs a standardized success response dictionary. - - Args: - processed_response: The processed/generated text response. - raw_request: The original request data. - raw_response_body: Raw response body if available. - raw_response_headers: Response headers if available. - agent_specific_data: Additional adapter-specific data. - status_code: HTTP status code (default: 200). - - Returns: - A dictionary representing a standardized success response. 
- """ - if isinstance(processed_response, str): - processed_response = self._strip_think_prefix(processed_response) - - if agent_specific_data is None: - agent_specific_data = {} - - # Include model_name in agent_specific_data if available - if self.model_name and "model_name" not in agent_specific_data: - agent_specific_data["model_name"] = self.model_name - - return { - "raw_request": raw_request, - "processed_response": processed_response, - "generated_text": processed_response, - "status_code": status_code, - "raw_response_headers": raw_response_headers, - "raw_response_body": raw_response_body, - "agent_specific_data": agent_specific_data, - "error_message": None, - "agent_id": self.id, - "adapter_type": self.ADAPTER_TYPE, - } + """Construct HackAgent's standardised success-response dict. + + Delegates to :func:`hackagent.router.envelope.build_success_envelope`. + """ + return _envelope.build_success_envelope( + agent_id=self.id, + adapter_type=self.ADAPTER_TYPE, + processed_response=processed_response, + raw_request=raw_request, + raw_response_body=raw_response_body, + raw_response_headers=raw_response_headers, + agent_specific_data=agent_specific_data, + model_name=self.model_name, + status_code=status_code, + ) def _strip_think_prefix(self, text: str) -> str: - """Strip hidden reasoning prefix up to and including '</think>' if present.""" - marker = "</think>" - marker_index = text.find(marker) - if marker_index == -1: - return text - return text[marker_index + len(marker) :] + """Strip hidden reasoning prefix up to and including '</think>'.""" + return _envelope.strip_think_prefix(text) @abstractmethod def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: @@ -522,29 +483,17 @@ def _build_agent_specific_data( completion_result: Dict[str, Any], parameters: Dict[str, Any], ) -> Dict[str, Any]: - """ - Build the agent_specific_data dictionary for the response. - - Override this method to add adapter-specific metadata.
+ """Build the standard ``agent_specific_data`` block. - Args: - completion_result: The result dictionary from _execute_completion. - parameters: The parameters used for the completion call. - - Returns: - Dictionary of adapter-specific data to include in response. + Delegates to + :func:`hackagent.router.envelope.build_agent_specific_data`. + Subclasses override to add adapter-specific metadata. """ - data = { - "model_name": self.model_name, - "invoked_parameters": parameters, - } - - # Include any additional data from the completion result - for key in ["usage", "finish_reason", "raw_response"]: - if key in completion_result: - data[key] = completion_result[key] - - return data + return _envelope.build_agent_specific_data( + model_name=self.model_name, + invoked_parameters=parameters, + completion_result=completion_result, + ) def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """ diff --git a/hackagent/router/adapters/litellm.py b/hackagent/router/adapters/litellm.py index 947eb518..4fb0b280 100644 --- a/hackagent/router/adapters/litellm.py +++ b/hackagent/router/adapters/litellm.py @@ -5,6 +5,8 @@ from hackagent.logger import get_logger from typing import Any, Dict, List, Optional +from hackagent.router import envelope as _envelope + from .base import ChatCompletionsAgent, AdapterConfigurationError # Lazy load litellm - only import when actually needed to avoid ~2s startup delay @@ -87,27 +89,9 @@ class LiteLLMConfigurationError(AdapterConfigurationError): logger = get_logger(__name__) # Module-level logger -# Provider prefixes that LiteLLM recognises natively. When a model string -# already starts with one of these, we leave it alone instead of prepending -# our own provider prefix. 
-_KNOWN_LITELLM_PROVIDER_PREFIXES = ( - "openai/", - "anthropic/", - "azure/", - "bedrock/", - "vertex_ai/", - "huggingface/", - "replicate/", - "together_ai/", - "anyscale/", - "ollama/", - "ollama_chat/", - "groq/", - "mistral/", - "cohere/", - "gemini/", - "deepseek/", -) +# Sourced from envelope.py — kept here as a module-level alias so any +# external code that imported it from this module still works. +_KNOWN_LITELLM_PROVIDER_PREFIXES = _envelope.KNOWN_LITELLM_PROVIDER_PREFIXES class LiteLLMAgent(ChatCompletionsAgent): @@ -202,14 +186,13 @@ def __init__(self, id: str, config: Dict[str, Any]): def _resolve_litellm_model(self, raw_model: str) -> str: """Return the model string to pass to ``litellm.completion``. - Honors the subclass ``PROVIDER_PREFIX`` while leaving names that - already carry an explicit LiteLLM provider prefix untouched. + Delegates to :func:`hackagent.router.envelope.resolve_litellm_model`. + Subclasses (notably ADKAgent) override this entirely to inject a + per-instance provider prefix. """ - if self.PROVIDER_PREFIX is None: - return raw_model - if raw_model.startswith(_KNOWN_LITELLM_PROVIDER_PREFIXES): - return raw_model - return f"{self.PROVIDER_PREFIX}/{raw_model}" + return _envelope.resolve_litellm_model( + raw_model, provider_prefix=self.PROVIDER_PREFIX + ) def _default_api_key_env_var(self) -> Optional[str]: """Return the env var used as a fallback when no API key is configured.""" @@ -261,123 +244,65 @@ def _prepare_litellm_params( top_p: float, **kwargs, ) -> Dict[str, Any]: - """Build the kwargs dict for ``litellm.completion``.""" - litellm_params: Dict[str, Any] = { - "model": self.litellm_model, - "messages": messages, - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - } + """Build the kwargs dict for ``litellm.completion``. 
- if self.api_base_url: - litellm_params["api_base"] = self.api_base_url - if self.actual_api_key: - litellm_params["api_key"] = self.actual_api_key - - # When the caller provides a custom endpoint without a recognised - # LiteLLM provider prefix, treat it as OpenAI-compatible. This - # preserves the previous behaviour for plain LiteLLM users and gives - # us a sensible default for LangChain-style endpoints. - if self.api_base_url and not self.litellm_model.startswith( - _KNOWN_LITELLM_PROVIDER_PREFIXES - ): - litellm_params["custom_llm_provider"] = "openai" - litellm_params["extra_headers"] = {"User-Agent": "HackAgent/0.1.0"} - elif self.api_base_url: - # Keep the User-Agent for outbound requests even when a provider - # prefix is supplied — useful for self-hosted proxies. - litellm_params["extra_headers"] = {"User-Agent": "HackAgent/0.1.0"} - - # Thinking handling — config default merged with per-request override. + Delegates the bulk construction to + :func:`hackagent.router.envelope.build_litellm_kwargs`. The + thinking translation still goes through ``_apply_thinking`` so + subclasses can specialise it. + """ + thinking_payload: Dict[str, Any] = {} thinking = kwargs.pop("thinking", self.default_thinking) - self._apply_thinking(litellm_params, thinking) + self._apply_thinking(thinking_payload, thinking) - # Tool calls. tools = kwargs.pop("tools", self.default_tools) tool_choice = kwargs.pop("tool_choice", self.default_tool_choice) - if tools: - litellm_params["tools"] = tools - if tool_choice is not None: - litellm_params["tool_choice"] = tool_choice - - # Provider-specific extra body (e.g. OpenRouter ``reasoning``). 
extra_body = kwargs.pop("extra_body", self.default_extra_body) - if extra_body is not None: - litellm_params["extra_body"] = ( - dict(extra_body) if isinstance(extra_body, dict) else extra_body - ) - litellm_params.update(kwargs) - return litellm_params + return _envelope.build_litellm_kwargs( + model=self.litellm_model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + api_base=self.api_base_url, + api_key=self.actual_api_key, + tools=tools, + tool_choice=tool_choice, + extra_body=extra_body, + thinking_payload=thinking_payload, + extra_kwargs=kwargs, + ) def _extract_raw_response_content(self, response: Any, context: str = "") -> str: - """Extract content from a litellm response object.""" - if not (response and response.choices and response.choices[0].message): + """Extract content from a litellm response object. + + Delegates to + :func:`hackagent.router.envelope.extract_text_from_response`. The + ``context`` argument is preserved for backwards compatibility but + is only used for logging when the response is malformed. + """ + text = _envelope.extract_text_from_response( + response, model_name=self.litellm_model + ) + if text == "[GENERATION_ERROR: UNEXPECTED_RESPONSE]": self.logger.warning( f"LiteLLM received unexpected response structure for model " f"'{self.litellm_model}'{context}. Response: {response}" ) - return "[GENERATION_ERROR: UNEXPECTED_RESPONSE]" - - message = response.choices[0].message - content = message.content if message.content else "" - - # Reasoning models surface their output in a dedicated field; fall - # back to it when the regular content is empty. 
- reasoning_content = None - if hasattr(message, "reasoning_content") and message.reasoning_content: - reasoning_content = message.reasoning_content - elif hasattr(message, "reasoning") and message.reasoning: - reasoning_content = message.reasoning - elif ( - hasattr(message, "provider_specific_fields") - and message.provider_specific_fields - ): - reasoning_content = message.provider_specific_fields.get( - "reasoning_content" - ) or message.provider_specific_fields.get("reasoning") - - if content: - return content - if reasoning_content: - self.logger.debug( - f"LiteLLM using reasoning content for model " - f"'{self.litellm_model}' (content field was empty)" + elif text == "[GENERATION_ERROR: EMPTY_RESPONSE]": + self.logger.warning( + f"LiteLLM received empty content and no reasoning field for " + f"model '{self.litellm_model}'{context}." ) - return reasoning_content - - self.logger.warning( - f"LiteLLM received empty content and no reasoning field for model " - f"'{self.litellm_model}'{context}. Message: {message}" - ) - return "[GENERATION_ERROR: EMPTY_RESPONSE]" + return text def _extract_tool_calls(self, response: Any) -> Optional[List[Dict[str, Any]]]: - """Extract OpenAI-style tool_calls from a LiteLLM response, if any.""" - try: - message = response.choices[0].message - except (AttributeError, IndexError, TypeError): - return None - tool_calls = getattr(message, "tool_calls", None) - if not tool_calls: - return None - result = [] - for tc in tool_calls: - try: - result.append( - { - "id": getattr(tc, "id", None), - "type": getattr(tc, "type", "function"), - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments, - }, - } - ) - except AttributeError: - continue - return result or None + """Return OpenAI-style ``tool_calls`` from a ``ModelResponse``, or ``None``. + + Delegates to :func:`hackagent.router.envelope.extract_tool_calls`. 
+ """ + return _envelope.extract_tool_calls(response) def _get_excluded_request_keys(self) -> set: """Return keys handled explicitly so they aren't re-passed as kwargs.""" @@ -483,24 +408,9 @@ def _execute_completion( "error_message": str(e), } - # ---- response shaping ------------------------------------------------ - - def _build_agent_specific_data( - self, - completion_result: Dict[str, Any], - parameters: Dict[str, Any], - ) -> Dict[str, Any]: - """Include common LiteLLM metadata (finish_reason, usage, tools).""" - data = super()._build_agent_specific_data(completion_result, parameters) - for key in ("finish_reason", "usage", "provider_model"): - value = completion_result.get(key) - if value is not None and key not in data: - data[key] = value - if completion_result.get("tool_calls"): - data["tool_calls"] = completion_result["tool_calls"] - return data - # ---- legacy convenience helpers ------------------------------------- + # (Response-shaping is handled by the base ``ChatCompletionsAgent`` via + # ``envelope.build_agent_specific_data`` since Phase A.) def _execute_litellm_completion_with_messages( self, diff --git a/hackagent/router/envelope.py b/hackagent/router/envelope.py new file mode 100644 index 00000000..f59e3bf1 --- /dev/null +++ b/hackagent/router/envelope.py @@ -0,0 +1,299 @@ +# Copyright 2026 - AI4I. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Envelope helpers — pure functions that translate between LiteLLM's +``ModelResponse`` and HackAgent's standardized response dict. + +This module exists as the Phase A landing zone of the +``LITELLM_ROUTER_REFACTOR_PLAN.md`` plan: extract the response-shaping +logic out of the adapter classes so it can be reused by +``AgentRouter`` once the call path is hoisted in Phase C. + +The functions here are intentionally: +- pure: no I/O, no logging side effects, no LiteLLM imports at module + level. Any LiteLLM import lives behind a lazy helper. 
+- agnostic of agent identity: the caller supplies ``agent_id`` and + ``adapter_type`` as keyword arguments. +- byte-compatible with the previous adapter envelope, so downstream + consumers (``StepTracker``, attacks, evaluators, dashboard) keep + seeing exactly the same dict shape. +""" + +from typing import Any, Dict, List, Optional + + +# Provider prefixes that LiteLLM recognises natively. When a model string +# already starts with one of these we leave it alone instead of prepending +# our own provider prefix. +KNOWN_LITELLM_PROVIDER_PREFIXES = ( + "openai/", + "anthropic/", + "azure/", + "bedrock/", + "vertex_ai/", + "huggingface/", + "replicate/", + "together_ai/", + "anyscale/", + "ollama/", + "ollama_chat/", + "groq/", + "mistral/", + "cohere/", + "gemini/", + "deepseek/", +) + + +# ---- text helpers -------------------------------------------------------- + + +def strip_think_prefix(text: str) -> str: + """Strip hidden reasoning prefix up to and including ``</think>`` if present.""" + if not isinstance(text, str): + return text + marker = "</think>" + marker_index = text.find(marker) + if marker_index == -1: + return text + return text[marker_index + len(marker) :] + + +def extract_text_from_response(response: Any, *, model_name: str = "") -> str: + """Pull the assistant text out of a LiteLLM ``ModelResponse``. + + Falls back to ``reasoning_content`` / ``reasoning`` when ``content`` + is empty so reasoning-only models still produce output. Returns a + sentinel ``[GENERATION_ERROR: ...]`` string when the response is + structurally unusable, mirroring the previous adapter behaviour.
+ """ + if not ( + response and getattr(response, "choices", None) and response.choices[0].message + ): + return "[GENERATION_ERROR: UNEXPECTED_RESPONSE]" + + message = response.choices[0].message + content = getattr(message, "content", "") or "" + + reasoning_content = None + if getattr(message, "reasoning_content", None): + reasoning_content = message.reasoning_content + elif getattr(message, "reasoning", None): + reasoning_content = message.reasoning + else: + provider_specific = getattr(message, "provider_specific_fields", None) + if provider_specific: + reasoning_content = provider_specific.get( + "reasoning_content" + ) or provider_specific.get("reasoning") + + if content: + return content + if reasoning_content: + return reasoning_content + return "[GENERATION_ERROR: EMPTY_RESPONSE]" + + +def extract_tool_calls(response: Any) -> Optional[List[Dict[str, Any]]]: + """Return OpenAI-style ``tool_calls`` from a ``ModelResponse``, or ``None``.""" + try: + message = response.choices[0].message + except (AttributeError, IndexError, TypeError): + return None + tool_calls = getattr(message, "tool_calls", None) + if not tool_calls: + return None + out: List[Dict[str, Any]] = [] + for tc in tool_calls: + try: + out.append( + { + "id": getattr(tc, "id", None), + "type": getattr(tc, "type", "function"), + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + ) + except AttributeError: + continue + return out or None + + +# ---- LiteLLM kwargs assembly -------------------------------------------- + + +def resolve_litellm_model( + raw_model: str, *, provider_prefix: Optional[str] = None +) -> str: + """Return the model string to pass to ``litellm.completion``. + + Honors a caller-supplied ``provider_prefix`` while leaving names that + already carry an explicit LiteLLM provider prefix untouched. 
+ """ + if not provider_prefix: + return raw_model + if raw_model.startswith(KNOWN_LITELLM_PROVIDER_PREFIXES): + return raw_model + return f"{provider_prefix}/{raw_model}" + + +def build_litellm_kwargs( + *, + model: str, + messages: List[Dict[str, str]], + max_tokens: int, + temperature: float, + top_p: float, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + tools: Optional[Any] = None, + tool_choice: Optional[Any] = None, + extra_body: Optional[Any] = None, + thinking_payload: Optional[Dict[str, Any]] = None, + extra_kwargs: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Build the kwargs dict for ``litellm.completion``. + + ``thinking_payload`` is the *already-translated* per-provider dict + (e.g. ``{"reasoning_effort": "medium"}`` or ``{"think": True}``); + the caller is responsible for converting the unified ``thinking`` + knob into the provider-specific shape before passing it in here. + Anything in ``extra_kwargs`` is splat-merged last and wins on + collision, matching the previous adapter behaviour. + """ + params: Dict[str, Any] = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + } + + if api_base: + params["api_base"] = api_base + if api_key: + params["api_key"] = api_key + + # When the caller provides a custom endpoint without a recognised + # LiteLLM provider prefix, treat it as OpenAI-compatible — same + # default the previous LiteLLMAgent used. 
+ if api_base and not model.startswith(KNOWN_LITELLM_PROVIDER_PREFIXES): + params["custom_llm_provider"] = "openai" + params["extra_headers"] = {"User-Agent": "HackAgent/0.1.0"} + elif api_base: + params["extra_headers"] = {"User-Agent": "HackAgent/0.1.0"} + + if thinking_payload: + params.update(thinking_payload) + + if tools: + params["tools"] = tools + if tool_choice is not None: + params["tool_choice"] = tool_choice + + if extra_body is not None: + params["extra_body"] = ( + dict(extra_body) if isinstance(extra_body, dict) else extra_body + ) + + if extra_kwargs: + params.update(extra_kwargs) + + return params + + +# ---- response envelopes ------------------------------------------------- + + +def build_success_envelope( + *, + agent_id: str, + adapter_type: str, + processed_response: Optional[str], + raw_request: Optional[Dict[str, Any]] = None, + raw_response_body: Optional[Any] = None, + raw_response_headers: Optional[Dict[str, str]] = None, + agent_specific_data: Optional[Dict[str, Any]] = None, + model_name: Optional[str] = None, + status_code: int = 200, +) -> Dict[str, Any]: + """Construct HackAgent's standardised success-response dict.""" + if isinstance(processed_response, str): + processed_response = strip_think_prefix(processed_response) + + if agent_specific_data is None: + agent_specific_data = {} + if model_name and "model_name" not in agent_specific_data: + agent_specific_data["model_name"] = model_name + + return { + "raw_request": raw_request, + "processed_response": processed_response, + "generated_text": processed_response, + "status_code": status_code, + "raw_response_headers": raw_response_headers, + "raw_response_body": raw_response_body, + "agent_specific_data": agent_specific_data, + "error_message": None, + "agent_id": agent_id, + "adapter_type": adapter_type, + } + + +def build_error_envelope( + *, + agent_id: str, + adapter_type: str, + error_message: str, + status_code: Optional[int] = None, + raw_request: Optional[Dict[str, Any]] = 
None, + raw_response_body: Optional[Any] = None, + raw_response_headers: Optional[Dict[str, str]] = None, + agent_specific_data: Optional[Dict[str, Any]] = None, + model_name: Optional[str] = None, +) -> Dict[str, Any]: + """Construct HackAgent's standardised error-response dict.""" + if agent_specific_data is None: + agent_specific_data = {} + if model_name and "model_name" not in agent_specific_data: + agent_specific_data["model_name"] = model_name + + return { + "raw_request": raw_request, + "processed_response": None, + "generated_text": None, + "status_code": status_code if status_code is not None else 500, + "raw_response_headers": raw_response_headers, + "raw_response_body": raw_response_body, + "agent_specific_data": agent_specific_data, + "error_message": error_message, + "agent_id": agent_id, + "adapter_type": adapter_type, + } + + +def build_agent_specific_data( + *, + model_name: Optional[str], + invoked_parameters: Dict[str, Any], + completion_result: Optional[Dict[str, Any]] = None, + extra: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Build the standard ``agent_specific_data`` block shared by adapters.""" + data: Dict[str, Any] = { + "model_name": model_name, + "invoked_parameters": invoked_parameters, + } + if completion_result: + for key in ("usage", "finish_reason", "provider_model", "raw_response"): + value = completion_result.get(key) + if value is not None: + data[key] = value + if completion_result.get("tool_calls"): + data["tool_calls"] = completion_result["tool_calls"] + if extra: + data.update(extra) + return data diff --git a/hackagent/router/provider_config.py b/hackagent/router/provider_config.py new file mode 100644 index 00000000..2654683e --- /dev/null +++ b/hackagent/router/provider_config.py @@ -0,0 +1,153 @@ +# Copyright 2026 - AI4I. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +``AgentType`` → ``ProviderConfig`` table. 
+ +The lookup table is the single source of truth for how each agent type +maps to a LiteLLM call: provider prefix, the ``thinking`` knob +translator, the allow-list of extra request keys that should pass +through, and an optional :class:`litellm.CustomLLM` factory for agent +types LiteLLM cannot speak natively (ADK, future MCP/A2A). +""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Tuple + +from hackagent.router.types import AgentTypeEnum + + +# ---- thinking translators ----------------------------------------------- +# Each translator takes the raw ``thinking`` value and the model name and +# returns the (possibly empty) dict of provider-specific request fields +# that should be merged into the LiteLLM kwargs. + + +def default_thinking_translator( + thinking: Any, *, model_name: str = "" +) -> Dict[str, Any]: + """Provider-agnostic translation that matches LiteLLM's own conventions.""" + if thinking is None: + return {} + if isinstance(thinking, dict): + return {"thinking": dict(thinking)} + if isinstance(thinking, str): + return {"reasoning_effort": thinking} + if isinstance(thinking, bool): + return {"thinking": {"type": "enabled" if thinking else "disabled"}} + if isinstance(thinking, int): + return {"thinking": {"type": "enabled", "budget_tokens": int(thinking)}} + return {"thinking": thinking} + + +_OPENAI_REASONING_MODEL_PREFIXES = ("o1", "o3", "o4", "gpt-5", "gpt-6") + + +def openai_thinking_translator( + thinking: Any, *, model_name: str = "" +) -> Dict[str, Any]: + """Map ``thinking`` to ``reasoning_effort`` for OpenAI reasoning models.""" + if thinking is None: + return {} + bare = model_name.split("/")[-1] + is_reasoning = bare.startswith(_OPENAI_REASONING_MODEL_PREFIXES) + if is_reasoning: + if thinking is True: + return {"reasoning_effort": "medium"} + if thinking is False: + return {} + if isinstance(thinking, str): + return {"reasoning_effort": thinking} + if isinstance(thinking, dict): + effort = 
thinking.get("reasoning_effort") or thinking.get("effort") + if effort: + return {"reasoning_effort": effort} + return {"thinking": dict(thinking)} + return default_thinking_translator(thinking, model_name=model_name) + + +def ollama_thinking_translator( + thinking: Any, *, model_name: str = "" +) -> Dict[str, Any]: + """Map ``thinking`` to Ollama's native ``think`` field.""" + if thinking is None: + return {} + if isinstance(thinking, bool): + return {"think": thinking} + if isinstance(thinking, str): + return {"think": thinking} + if isinstance(thinking, int): + return {"think": thinking > 0} + if isinstance(thinking, dict): + kind = (thinking.get("type") or "").lower() + return {"think": False if kind == "disabled" else True} + return {"think": bool(thinking)} + + +# ---- provider config ----------------------------------------------------- + + +@dataclass(frozen=True) +class ProviderConfig: + """Per-``AgentType`` knobs the router uses to drive ``litellm.completion``.""" + + # LiteLLM provider prefix to prepend to ``model`` (``"openai"``, + # ``"ollama_chat"``…). ``None`` means leave the user-supplied model + # string unchanged (the LITELLM passthrough type). + provider_prefix: Optional[str] + + # Translates the unified ``thinking`` value into provider-specific + # request fields. Receives the raw value plus the model name. + thinking_translator: Callable[..., Dict[str, Any]] + + # ``adapter_type`` label that appears in the response envelope. + adapter_label: str + + # Additional request-data keys allowed to pass through into the + # LiteLLM call (e.g. ``top_k`` for Ollama, ``tools`` for OpenAI). + extra_passthrough_keys: Tuple[str, ...] = () + + # Optional zero-arg factory returning a (provider_name, handler) + # tuple to register with LiteLLM's ``custom_provider_map`` — only + # used by agent types whose protocol LiteLLM doesn't speak + # natively (ADK today; MCP/A2A in the future). 
+ custom_llm_factory: Optional[Callable[..., Any]] = None + + +# ---- the table ---------------------------------------------------------- +# ADK isn't in the lookup table because its custom-LLM handler is +# constructed per-instance (it captures endpoint/user_id/session policy +# from the adapter config). It stays driven by ``ADKAgent`` for now and +# moves into ``router/providers/`` in Phase E. + +PROVIDER_CONFIGS: Dict[AgentTypeEnum, ProviderConfig] = { + AgentTypeEnum.LITELLM: ProviderConfig( + provider_prefix=None, + thinking_translator=default_thinking_translator, + adapter_label="LiteLLMAgent", + ), + AgentTypeEnum.OPENAI_SDK: ProviderConfig( + provider_prefix="openai", + thinking_translator=openai_thinking_translator, + adapter_label="OpenAIAgent", + extra_passthrough_keys=("tools", "tool_choice", "extra_body"), + ), + AgentTypeEnum.OLLAMA: ProviderConfig( + provider_prefix="ollama_chat", + thinking_translator=ollama_thinking_translator, + adapter_label="OllamaAgent", + extra_passthrough_keys=("top_k", "num_ctx", "stream"), + ), + AgentTypeEnum.LANGCHAIN: ProviderConfig( + # LangServe endpoints are OpenAI-compatible by convention; the + # generic LiteLLM passthrough already handles them. + provider_prefix=None, + thinking_translator=default_thinking_translator, + adapter_label="LiteLLMAgent", + ), +} + + +def get_provider_config(agent_type: AgentTypeEnum) -> Optional[ProviderConfig]: + """Return the ``ProviderConfig`` for ``agent_type``, or ``None``.""" + return PROVIDER_CONFIGS.get(agent_type) diff --git a/tests/unit/router/test_envelope.py b/tests/unit/router/test_envelope.py new file mode 100644 index 00000000..03c90bb6 --- /dev/null +++ b/tests/unit/router/test_envelope.py @@ -0,0 +1,249 @@ +# Copyright 2026 - AI4I. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for ``hackagent/router/envelope.py``.""" + +import logging +import unittest +from unittest.mock import MagicMock + +from hackagent.router import envelope + +logging.disable(logging.CRITICAL) + + +def _model_response( + content: str = "", + *, + reasoning_content: str = None, + reasoning: str = None, + tool_calls=None, +): + """Build a minimal mock of a litellm ``ModelResponse``.""" + response = MagicMock() + choice = MagicMock() + message = MagicMock() + message.content = content + message.reasoning_content = reasoning_content + message.reasoning = reasoning + message.tool_calls = tool_calls + message.provider_specific_fields = None + choice.message = message + response.choices = [choice] + return response + + +class TestStripThinkPrefix(unittest.TestCase): + def test_removes_prefix_up_to_and_including_marker(self): + self.assertEqual( + envelope.strip_think_prefix("<think>scratch</think>real answer"), + "real answer", + ) + + def test_returns_unchanged_when_marker_absent(self): + self.assertEqual(envelope.strip_think_prefix("plain text"), "plain text") + + def test_handles_non_string_gracefully(self): + self.assertIs(envelope.strip_think_prefix(None), None) + + +class TestExtractTextFromResponse(unittest.TestCase): + def test_returns_content_when_present(self): + response = _model_response("hello world") + self.assertEqual(envelope.extract_text_from_response(response), "hello world") + + def test_falls_back_to_reasoning_content(self): + response = _model_response("", reasoning_content="reasoning trace") + self.assertEqual( + envelope.extract_text_from_response(response), "reasoning trace" + ) + + def test_falls_back_to_reasoning_attribute(self): + response = _model_response("", reasoning="reasoning text") + self.assertEqual( + envelope.extract_text_from_response(response), "reasoning text" + ) + + def test_returns_empty_response_marker_when_nothing_usable(self): + response = _model_response("") + self.assertEqual( +
envelope.extract_text_from_response(response), + "[GENERATION_ERROR: EMPTY_RESPONSE]", + ) + + def test_returns_unexpected_response_marker_when_response_malformed(self): + bad = MagicMock() + bad.choices = [] + self.assertEqual( + envelope.extract_text_from_response(bad), + "[GENERATION_ERROR: UNEXPECTED_RESPONSE]", + ) + + +class TestExtractToolCalls(unittest.TestCase): + def test_returns_none_when_no_tool_calls(self): + self.assertIsNone(envelope.extract_tool_calls(_model_response("hi"))) + + def test_normalises_tool_call_shape(self): + tc = MagicMock() + tc.id = "call_1" + tc.type = "function" + tc.function.name = "do_thing" + tc.function.arguments = '{"x": 1}' + response = _model_response("", tool_calls=[tc]) + result = envelope.extract_tool_calls(response) + self.assertEqual( + result, + [ + { + "id": "call_1", + "type": "function", + "function": {"name": "do_thing", "arguments": '{"x": 1}'}, + } + ], + ) + + def test_returns_none_for_unstructured_response(self): + self.assertIsNone(envelope.extract_tool_calls(MagicMock(choices=[]))) + + +class TestResolveLitellmModel(unittest.TestCase): + def test_no_prefix_returns_raw(self): + self.assertEqual(envelope.resolve_litellm_model("gpt-4"), "gpt-4") + + def test_adds_prefix_when_provided(self): + self.assertEqual( + envelope.resolve_litellm_model("gpt-4", provider_prefix="openai"), + "openai/gpt-4", + ) + + def test_preserves_existing_known_prefix(self): + self.assertEqual( + envelope.resolve_litellm_model("ollama/llama3", provider_prefix="openai"), + "ollama/llama3", + ) + + +class TestBuildLitellmKwargs(unittest.TestCase): + def _common(self): + return dict( + model="openai/gpt-4", + messages=[{"role": "user", "content": "hi"}], + max_tokens=100, + temperature=0.7, + top_p=0.9, + ) + + def test_minimal_kwargs(self): + kwargs = envelope.build_litellm_kwargs(**self._common()) + self.assertEqual(kwargs["model"], "openai/gpt-4") + self.assertEqual(kwargs["temperature"], 0.7) + self.assertNotIn("api_base", kwargs) + + 
def test_attaches_api_base_and_key(self): + kwargs = envelope.build_litellm_kwargs( + api_base="http://host/v1", api_key="sk-x", **self._common() + ) + self.assertEqual(kwargs["api_base"], "http://host/v1") + self.assertEqual(kwargs["api_key"], "sk-x") + + def test_custom_endpoint_without_provider_prefix_falls_back_to_openai(self): + common = self._common() + common["model"] = "local-model" + kwargs = envelope.build_litellm_kwargs(api_base="http://host:8000/v1", **common) + self.assertEqual(kwargs.get("custom_llm_provider"), "openai") + self.assertIn("extra_headers", kwargs) + + def test_thinking_payload_merged_in(self): + kwargs = envelope.build_litellm_kwargs( + thinking_payload={"reasoning_effort": "high"}, **self._common() + ) + self.assertEqual(kwargs["reasoning_effort"], "high") + + def test_tools_and_choice_only_set_when_tools_present(self): + # tool_choice provided but no tools — both omitted. + kwargs = envelope.build_litellm_kwargs(tool_choice="auto", **self._common()) + self.assertNotIn("tools", kwargs) + self.assertNotIn("tool_choice", kwargs) + + kwargs = envelope.build_litellm_kwargs( + tools=[{"type": "function"}], + tool_choice="auto", + **self._common(), + ) + self.assertEqual(kwargs["tool_choice"], "auto") + + def test_extra_kwargs_override_defaults(self): + kwargs = envelope.build_litellm_kwargs( + extra_kwargs={"temperature": 0.1, "custom": "x"}, **self._common() + ) + self.assertEqual(kwargs["temperature"], 0.1) + self.assertEqual(kwargs["custom"], "x") + + +class TestEnvelopeBuilders(unittest.TestCase): + def test_success_strips_think_prefix(self): + env = envelope.build_success_envelope( + agent_id="a1", + adapter_type="X", + processed_response="<think>scratch</think>final", + ) + self.assertEqual(env["processed_response"], "final") + self.assertEqual(env["generated_text"], "final") + self.assertEqual(env["status_code"], 200) + self.assertIsNone(env["error_message"]) + + def test_success_attaches_model_name(self): + env = envelope.build_success_envelope( + 
agent_id="a1", + adapter_type="X", + processed_response="ok", + model_name="gpt-4", + ) + self.assertEqual(env["agent_specific_data"]["model_name"], "gpt-4") + + def test_error_default_status_500(self): + env = envelope.build_error_envelope( + agent_id="a1", adapter_type="X", error_message="boom" + ) + self.assertEqual(env["status_code"], 500) + self.assertEqual(env["error_message"], "boom") + self.assertIsNone(env["processed_response"]) + + def test_error_uses_supplied_status(self): + env = envelope.build_error_envelope( + agent_id="a1", + adapter_type="X", + error_message="bad", + status_code=400, + ) + self.assertEqual(env["status_code"], 400) + + +class TestBuildAgentSpecificData(unittest.TestCase): + def test_merges_completion_metadata(self): + data = envelope.build_agent_specific_data( + model_name="gpt-4", + invoked_parameters={"temperature": 0.7}, + completion_result={ + "usage": {"total_tokens": 12}, + "finish_reason": "stop", + "tool_calls": [{"id": "c1"}], + }, + ) + self.assertEqual(data["model_name"], "gpt-4") + self.assertEqual(data["usage"], {"total_tokens": 12}) + self.assertEqual(data["finish_reason"], "stop") + self.assertEqual(data["tool_calls"], [{"id": "c1"}]) + + def test_extra_dict_overrides(self): + data = envelope.build_agent_specific_data( + model_name="m", + invoked_parameters={}, + extra={"hackagent_call_id": "abc"}, + ) + self.assertEqual(data["hackagent_call_id"], "abc") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/router/test_provider_config.py b/tests/unit/router/test_provider_config.py new file mode 100644 index 00000000..26832eed --- /dev/null +++ b/tests/unit/router/test_provider_config.py @@ -0,0 +1,177 @@ +# Copyright 2026 - AI4I. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for ``hackagent/router/provider_config.py``.""" + +import logging +import unittest + +from hackagent.router.provider_config import ( + PROVIDER_CONFIGS, + default_thinking_translator, + get_provider_config, + ollama_thinking_translator, + openai_thinking_translator, +) +from hackagent.router.types import AgentTypeEnum + +logging.disable(logging.CRITICAL) + + +class TestDefaultThinkingTranslator(unittest.TestCase): + def test_none_returns_empty(self): + self.assertEqual(default_thinking_translator(None), {}) + + def test_dict_passes_through(self): + self.assertEqual( + default_thinking_translator({"budget_tokens": 1024}), + {"thinking": {"budget_tokens": 1024}}, + ) + + def test_string_becomes_reasoning_effort(self): + self.assertEqual( + default_thinking_translator("high"), {"reasoning_effort": "high"} + ) + + def test_true_becomes_enabled_dict(self): + self.assertEqual( + default_thinking_translator(True), + {"thinking": {"type": "enabled"}}, + ) + + def test_false_becomes_disabled_dict(self): + self.assertEqual( + default_thinking_translator(False), + {"thinking": {"type": "disabled"}}, + ) + + def test_int_becomes_budget(self): + self.assertEqual( + default_thinking_translator(2048), + {"thinking": {"type": "enabled", "budget_tokens": 2048}}, + ) + + +class TestOpenAIThinkingTranslator(unittest.TestCase): + def test_reasoning_model_true_maps_to_medium(self): + self.assertEqual( + openai_thinking_translator(True, model_name="openai/o1-mini"), + {"reasoning_effort": "medium"}, + ) + + def test_reasoning_model_false_omits(self): + self.assertEqual(openai_thinking_translator(False, model_name="openai/o3"), {}) + + def test_reasoning_model_string_passes_through(self): + self.assertEqual( + openai_thinking_translator("low", model_name="o1"), + {"reasoning_effort": "low"}, + ) + + def test_reasoning_model_dict_effort_extracted(self): + self.assertEqual( + openai_thinking_translator({"reasoning_effort": "high"}, 
model_name="o3"), + {"reasoning_effort": "high"}, + ) + + def test_non_reasoning_falls_back_to_default(self): + self.assertEqual( + openai_thinking_translator(True, model_name="openai/gpt-4"), + {"thinking": {"type": "enabled"}}, + ) + + def test_none_returns_empty(self): + self.assertEqual(openai_thinking_translator(None, model_name="o1"), {}) + + +class TestOllamaThinkingTranslator(unittest.TestCase): + def test_bool_passes_through_to_think(self): + self.assertEqual( + ollama_thinking_translator(True, model_name="llama3"), + {"think": True}, + ) + self.assertEqual( + ollama_thinking_translator(False, model_name="llama3"), + {"think": False}, + ) + + def test_str_passes_through_to_think(self): + self.assertEqual( + ollama_thinking_translator("low", model_name="llama3"), + {"think": "low"}, + ) + + def test_int_coerces_to_bool(self): + self.assertEqual( + ollama_thinking_translator(1, model_name="llama3"), + {"think": True}, + ) + self.assertEqual( + ollama_thinking_translator(0, model_name="llama3"), + {"think": False}, + ) + + def test_dict_disabled_type_maps_to_false(self): + self.assertEqual( + ollama_thinking_translator({"type": "disabled"}, model_name="llama3"), + {"think": False}, + ) + + def test_dict_enabled_type_maps_to_true(self): + self.assertEqual( + ollama_thinking_translator({"type": "enabled"}, model_name="llama3"), + {"think": True}, + ) + + def test_none_returns_empty(self): + self.assertEqual(ollama_thinking_translator(None, model_name="llama3"), {}) + + +class TestProviderConfigsTable(unittest.TestCase): + def test_openai_config_present_and_correct(self): + cfg = get_provider_config(AgentTypeEnum.OPENAI_SDK) + self.assertIsNotNone(cfg) + self.assertEqual(cfg.provider_prefix, "openai") + self.assertEqual(cfg.adapter_label, "OpenAIAgent") + self.assertIn("tools", cfg.extra_passthrough_keys) + + def test_ollama_config_present_and_correct(self): + cfg = get_provider_config(AgentTypeEnum.OLLAMA) + self.assertIsNotNone(cfg) + 
self.assertEqual(cfg.provider_prefix, "ollama_chat") + self.assertEqual(cfg.adapter_label, "OllamaAgent") + self.assertIn("top_k", cfg.extra_passthrough_keys) + self.assertIn("num_ctx", cfg.extra_passthrough_keys) + + def test_litellm_passthrough_has_no_prefix(self): + cfg = get_provider_config(AgentTypeEnum.LITELLM) + self.assertIsNotNone(cfg) + self.assertIsNone(cfg.provider_prefix) + + def test_langchain_uses_default_passthrough(self): + cfg = get_provider_config(AgentTypeEnum.LANGCHAIN) + self.assertIsNotNone(cfg) + self.assertIsNone(cfg.provider_prefix) + + def test_google_adk_not_in_lookup_table(self): + # ADK still uses per-instance custom-LLM registration; it's not + # in the static table yet. See LITELLM_ROUTER_REFACTOR_PLAN.md + # Phase E for the move into router/providers/. + self.assertIsNone(get_provider_config(AgentTypeEnum.GOOGLE_ADK)) + + def test_unknown_agent_type_returns_none(self): + self.assertIsNone(get_provider_config(AgentTypeEnum.UNKNOWN)) + + def test_provider_configs_dict_is_complete(self): + """All chat-completion agent types appear in the table.""" + expected = { + AgentTypeEnum.LITELLM, + AgentTypeEnum.OPENAI_SDK, + AgentTypeEnum.OLLAMA, + AgentTypeEnum.LANGCHAIN, + } + self.assertEqual(expected, set(PROVIDER_CONFIGS.keys())) + + +if __name__ == "__main__": + unittest.main() From 67cd38fa8c0109533b99b990a80d1a85f97f4953 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 16:38:38 +0200 Subject: [PATCH 03/23] refactor(router): drive adapters from ProviderConfig (#379 Phase B) Phase B of the LiteLLM router refactor. LiteLLMAgent.__init__ now accepts an optional ProviderConfig. OpenAIAgent and OllamaAgent look their config up from hackagent/router/provider_config.py instead of overriding ``_apply_thinking`` and ``PROVIDER_PREFIX``. The class-level ``PROVIDER_PREFIX`` path stays for backwards compatibility (and is still how ADKAgent injects its per-instance provider name). 
Phase C will use this same ProviderConfig lookup inside ``AgentRouter`` to bypass the adapter classes entirely for chat-completion agent types. Co-Authored-By: Claude Opus 4.7 (1M context) --- hackagent/router/adapters/litellm.py | 74 +++++++++++++++---------- hackagent/router/adapters/ollama.py | 35 ++++-------- hackagent/router/adapters/openai.py | 81 ++++++++-------------------- 3 files changed, 76 insertions(+), 114 deletions(-) diff --git a/hackagent/router/adapters/litellm.py b/hackagent/router/adapters/litellm.py index 4fb0b280..0606076d 100644 --- a/hackagent/router/adapters/litellm.py +++ b/hackagent/router/adapters/litellm.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional from hackagent.router import envelope as _envelope +from hackagent.router.provider_config import ProviderConfig from .base import ChatCompletionsAgent, AdapterConfigurationError @@ -124,9 +125,16 @@ class LiteLLMAgent(ChatCompletionsAgent): # When set, the model string passed to LiteLLM is prefixed with # ``"{PROVIDER_PREFIX}/"`` unless it already starts with a known # LiteLLM provider prefix. ``None`` means "let LiteLLM auto-detect". + # Subclasses can either set this class attribute or pass a + # ``ProviderConfig`` to ``__init__`` (the new Phase B path). PROVIDER_PREFIX: Optional[str] = None - def __init__(self, id: str, config: Dict[str, Any]): + def __init__( + self, + id: str, + config: Dict[str, Any], + provider_config: Optional[ProviderConfig] = None, + ): """ Initialise the adapter from configuration. @@ -142,8 +150,15 @@ def __init__(self, id: str, config: Dict[str, Any]): definitions, passed through to LiteLLM. - ``thinking`` (optional): see class docstring. - ``extra_body`` (optional): provider-specific request body. + provider_config: Optional :class:`ProviderConfig` looked up + from ``hackagent.router.provider_config``. When supplied, + it takes precedence over the class-level + ``PROVIDER_PREFIX`` and ``_apply_thinking`` override. 
This + is the path Phase C will use to drive the adapter from + ``router.py`` without subclassing. """ super().__init__(id, config) + self._provider_config: Optional[ProviderConfig] = provider_config # Require model name self.model_name = self._require_config_key("name", LiteLLMConfigurationError) @@ -186,13 +201,18 @@ def __init__(self, id: str, config: Dict[str, Any]): def _resolve_litellm_model(self, raw_model: str) -> str: """Return the model string to pass to ``litellm.completion``. - Delegates to :func:`hackagent.router.envelope.resolve_litellm_model`. - Subclasses (notably ADKAgent) override this entirely to inject a - per-instance provider prefix. + Honors ``self._provider_config.provider_prefix`` when a + :class:`ProviderConfig` was supplied, otherwise falls back to the + class-level ``PROVIDER_PREFIX``. Subclasses (notably + :class:`ADKAgent`) override this entirely to inject a per-instance + provider prefix. """ - return _envelope.resolve_litellm_model( - raw_model, provider_prefix=self.PROVIDER_PREFIX + prefix = ( + self._provider_config.provider_prefix + if self._provider_config is not None + else self.PROVIDER_PREFIX ) + return _envelope.resolve_litellm_model(raw_model, provider_prefix=prefix) def _default_api_key_env_var(self) -> Optional[str]: """Return the env var used as a fallback when no API key is configured.""" @@ -207,32 +227,28 @@ def _default_api_key_env_var(self) -> Optional[str]: def _apply_thinking(self, litellm_params: Dict[str, Any], thinking: Any) -> None: """Translate the unified ``thinking`` value into LiteLLM params. - The default implementation mirrors LiteLLM's own conventions: - - dict: forwarded verbatim as ``thinking=...`` - - str: forwarded as ``reasoning_effort=...`` - - int: forwarded as ``thinking={"type": "enabled", - "budget_tokens": int}`` - - True/False: forwarded as - ``thinking={"type": "enabled" | "disabled"}`` - Subclasses override this method when their provider needs different - field names (e.g. 
Ollama's ``think``). + When a :class:`ProviderConfig` was supplied, its + ``thinking_translator`` is consulted; otherwise the default + generic translator is used. Subclasses may still override this + method for backwards compatibility, but the recommended path is + to supply a ``ProviderConfig`` instead. """ if thinking is None: return - if isinstance(thinking, dict): - litellm_params["thinking"] = dict(thinking) - elif isinstance(thinking, str): - litellm_params["reasoning_effort"] = thinking - elif isinstance(thinking, bool): - litellm_params["thinking"] = {"type": "enabled" if thinking else "disabled"} - elif isinstance(thinking, int): - litellm_params["thinking"] = { - "type": "enabled", - "budget_tokens": int(thinking), - } - else: - # Best-effort passthrough for unknown shapes. - litellm_params["thinking"] = thinking + if self._provider_config is not None: + payload = self._provider_config.thinking_translator( + thinking, model_name=self.litellm_model + ) + if payload: + litellm_params.update(payload) + return + + # Fallback path for adapters built without a ProviderConfig. 
+ from hackagent.router.provider_config import default_thinking_translator + + payload = default_thinking_translator(thinking, model_name=self.litellm_model) + if payload: + litellm_params.update(payload) # ---- request preparation -------------------------------------------- diff --git a/hackagent/router/adapters/ollama.py b/hackagent/router/adapters/ollama.py index 77b3fe49..5b4931fc 100644 --- a/hackagent/router/adapters/ollama.py +++ b/hackagent/router/adapters/ollama.py @@ -15,6 +15,9 @@ from hackagent.logger import get_logger from typing import Any, Dict, List, Optional +from hackagent.router.provider_config import get_provider_config +from hackagent.router.types import AgentTypeEnum + from .base import AdapterConfigurationError from .litellm import LiteLLMAgent @@ -45,7 +48,6 @@ class OllamaAgent(LiteLLMAgent): """ ADAPTER_TYPE = "OllamaAgent" - PROVIDER_PREFIX = "ollama_chat" DEFAULT_ENDPOINT = "http://localhost:11434" def __init__(self, id: str, config: Dict[str, Any]): @@ -57,7 +59,11 @@ def __init__(self, id: str, config: Dict[str, Any]): config = {**config, "endpoint": effective_endpoint} try: - super().__init__(id, config) + super().__init__( + id, + config, + provider_config=get_provider_config(AgentTypeEnum.OLLAMA), + ) except AdapterConfigurationError as e: raise OllamaConfigurationError(str(e)) from e @@ -97,29 +103,8 @@ def _get_excluded_request_keys(self) -> set: base = super()._get_excluded_request_keys() return base | {"top_k", "num_ctx", "stream", "system"} - def _apply_thinking(self, litellm_params: Dict[str, Any], thinking: Any) -> None: - """Translate ``thinking`` into Ollama's native ``think`` field. - - Ollama accepts a boolean (``true``/``false``) or a reasoning level - such as ``"low"``/``"medium"``/``"high"`` depending on the model. - Dicts and ints are coerced into the most reasonable boolean. 
- """ - if thinking is None: - return - if isinstance(thinking, bool): - litellm_params["think"] = thinking - elif isinstance(thinking, str): - litellm_params["think"] = thinking - elif isinstance(thinking, int): - litellm_params["think"] = thinking > 0 - elif isinstance(thinking, dict): - kind = (thinking.get("type") or "").lower() - if kind == "disabled": - litellm_params["think"] = False - else: - litellm_params["think"] = True - else: - litellm_params["think"] = bool(thinking) + # Thinking translation is driven by the ``OLLAMA`` ``ProviderConfig`` + # (see ``hackagent/router/provider_config.py``); no override needed. # ---- diagnostics passthroughs (kept for callers/tests) -------------- diff --git a/hackagent/router/adapters/openai.py b/hackagent/router/adapters/openai.py index 597ec050..f87f6599 100644 --- a/hackagent/router/adapters/openai.py +++ b/hackagent/router/adapters/openai.py @@ -6,14 +6,17 @@ The OpenAI agent type used to talk to the OpenAI SDK directly. As of issue #379 every chat-completion adapter routes through LiteLLM, so this -class is now a thin specialisation of :class:`LiteLLMAgent` that pins the -provider prefix to ``openai`` and translates the unified ``thinking`` knob -into OpenAI's ``reasoning_effort`` field for the o-series models. +class is now a thin specialisation of :class:`LiteLLMAgent` that pulls +its provider prefix + thinking translator out of the +:class:`ProviderConfig` table. """ from hackagent.logger import get_logger from typing import Any, Dict +from hackagent.router.provider_config import get_provider_config +from hackagent.router.types import AgentTypeEnum + from .base import AdapterConfigurationError from .litellm import LiteLLMAgent @@ -29,30 +32,23 @@ class OpenAIConfigurationError(AdapterConfigurationError): logger = get_logger(__name__) -# OpenAI reasoning models that natively understand ``reasoning_effort``. 
-# ``thinking=True`` defaults to "medium" for these; for other models we fall -# back to LiteLLM's generic ``thinking`` payload. -_OPENAI_REASONING_MODEL_PREFIXES = ("o1", "o3", "o4", "gpt-5", "gpt-6") - - class OpenAIAgent(LiteLLMAgent): """ Adapter for OpenAI-compatible chat endpoints. - Configured via the ``OPENAI_SDK`` agent type. Internally uses LiteLLM, - so any OpenAI-compatible server (the official API, a local model server - exposing ``/v1/chat/completions``, OpenRouter, etc.) works as the - endpoint. + Configured via the ``OPENAI_SDK`` agent type. Internally uses + LiteLLM, so any OpenAI-compatible server (the official API, a local + model server exposing ``/v1/chat/completions``, OpenRouter, etc.) + works as the endpoint. Reasoning / "thinking": - Set ``thinking`` in the adapter config or per request to enable or - disable the model's reasoning. For the o-series and newer GPT - reasoning models the value is translated to ``reasoning_effort`` - (low/medium/high). + Driven by the :class:`ProviderConfig` for + ``AgentTypeEnum.OPENAI_SDK`` — for the o-series and newer GPT + reasoning models the unified ``thinking`` value is translated + to ``reasoning_effort``. """ ADAPTER_TYPE = "OpenAIAgent" - PROVIDER_PREFIX = "openai" DEFAULT_TEMPERATURE = 1.0 def __init__(self, id: str, config: Dict[str, Any]): @@ -63,56 +59,21 @@ def __init__(self, id: str, config: Dict[str, Any]): config = {**config, "name": config.get("name", "default")} try: - super().__init__(id, config) + super().__init__( + id, + config, + provider_config=get_provider_config(AgentTypeEnum.OPENAI_SDK), + ) except AdapterConfigurationError as e: # Re-raise as the OpenAI-flavoured subclass so legacy callers # that catch OpenAIConfigurationError keep working. raise OpenAIConfigurationError(str(e)) from e - # For custom endpoints without an API key, use a placeholder so the - # OpenAI client (under LiteLLM's hood) doesn't error out. 
+ # For custom endpoints without an API key, use a placeholder so + # the OpenAI client (under LiteLLM's hood) doesn't error out. if not self.actual_api_key and self.api_base_url: self.actual_api_key = "not-required" self.logger.info( f"No API key configured for custom endpoint " f"'{self.api_base_url}', using placeholder" ) - - # ---- thinking translation ------------------------------------------- - - def _is_reasoning_model(self) -> bool: - bare = self.model_name.split("/")[-1] - return bare.startswith(_OPENAI_REASONING_MODEL_PREFIXES) - - def _apply_thinking(self, litellm_params: Dict[str, Any], thinking: Any) -> None: - """Map ``thinking`` to ``reasoning_effort`` for OpenAI reasoning models. - - Non-reasoning models fall back to LiteLLM's default ``thinking`` - passthrough, so callers can still attach arbitrary provider payload - if they need to. - """ - if thinking is None: - return - - if self._is_reasoning_model(): - if thinking is True: - litellm_params["reasoning_effort"] = "medium" - elif thinking is False: - # Explicit disable: omit the parameter entirely so the - # provider falls back to whatever its server-side default is. - # (OpenAI doesn't currently accept reasoning_effort="off".) - return - elif isinstance(thinking, str): - litellm_params["reasoning_effort"] = thinking - elif isinstance(thinking, dict): - effort = thinking.get("reasoning_effort") or thinking.get("effort") - if effort: - litellm_params["reasoning_effort"] = effort - else: - litellm_params["thinking"] = dict(thinking) - else: - super()._apply_thinking(litellm_params, thinking) - return - - # Non-reasoning model: defer to the generic translation. 
- super()._apply_thinking(litellm_params, thinking) From c14b2e0b262117e87ae0ef5b608f305052d41187 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 16:44:46 +0200 Subject: [PATCH 04/23] refactor(router): hoist call path into AgentRouter (#379 Phase C) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase C of the LiteLLM router refactor. ``AgentRouter.route_request`` now dispatches chat-completion AgentTypes (LITELLM, OPENAI_SDK, OLLAMA, LANGCHAIN) directly through ``litellm.completion`` via the new ``_dispatch_via_litellm`` method, looking up the provider prefix, thinking translator, and passthrough keys from ``hackagent/router/provider_config.py``. The adapter classes are still instantiated and still expose ``handle_request`` for backwards compatibility — that's the path GOOGLE_ADK and any future protocol-specific AgentType continue to use. For chat types, ``handle_request`` is no longer on the hot path; the adapter instance is consulted only for its already-resolved model name, endpoint, API key, and generation defaults. The HackAgent envelope dict shape is byte-identical to the previous adapter-driven path (verified by leaving all existing adapter tests green plus six new tests covering the new dispatch path). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- hackagent/router/router.py | 253 ++++++++++++++++++++++++++++- tests/unit/router/test_dispatch.py | 218 +++++++++++++++++++++++++ 2 files changed, 468 insertions(+), 3 deletions(-) create mode 100644 tests/unit/router/test_dispatch.py diff --git a/hackagent/router/router.py b/hackagent/router/router.py index 3c3f36eb..19a8254f 100644 --- a/hackagent/router/router.py +++ b/hackagent/router/router.py @@ -2,17 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, List, Optional, Tuple, Type from hackagent.server.storage.base import AgentRecord, StorageBackend +from hackagent.router import envelope as _envelope from hackagent.router.adapters.base import Agent +from hackagent.router.provider_config import ProviderConfig, get_provider_config from hackagent.router.types import AgentTypeEnum # Adapter imports - these are imported at module level for backwards compatibility # with test patching (tests patch hackagent.router.router.LiteLLMAgent etc.) # The actual heavy dependency (litellm) is lazy-loaded within LiteLLMAgent from hackagent.router.adapters import ADKAgent -from hackagent.router.adapters.litellm import LiteLLMAgent +from hackagent.router.adapters.litellm import LiteLLMAgent, _get_litellm from hackagent.router.adapters.openai import OpenAIAgent from hackagent.router.adapters.ollama import OllamaAgent @@ -77,6 +79,11 @@ def __init__( """ self.backend = backend self._agent_registry: dict = {} + # Tracks the AgentTypeEnum each registration was created under, so + # ``route_request`` can pick the right dispatch path (chat + # AgentTypes go through ``_dispatch_via_litellm`` directly; + # everything else still calls ``adapter.handle_request``). 
+ self._agent_types: Dict[str, AgentTypeEnum] = {} context = self.backend.get_context() self.organization_id = context.org_id @@ -231,6 +238,7 @@ def _configure_and_instantiate_adapter( f"ROUTER_DEBUG: Called adapter_class. Resulting instance: {adapter_instance}, type: {type(adapter_instance)}" ) self._agent_registry[registration_key] = adapter_instance + self._agent_types[registration_key] = agent_type logger.info( f"Agent '{name}' (Backend ID: {registration_key}, Type: {agent_type.value}) " f"successfully initialized and registered with adapter {adapter_class.__name__}. " @@ -351,8 +359,26 @@ def route_request( registration_key=registration_key, ) + agent_type = self._agent_types.get(registration_key) + provider_config = ( + get_provider_config(agent_type) if agent_type is not None else None + ) + try: - response = agent_instance.handle_request(request_data) + if provider_config is not None: + # Chat-completion AgentType: drive LiteLLM directly via the + # router instead of going through the adapter's + # ``handle_request``. Phase C of #379. + response = self._dispatch_via_litellm( + registration_key=registration_key, + agent_instance=agent_instance, + provider_config=provider_config, + request_data=request_data, + ) + else: + # ADK and other gap-filler AgentTypes still use the + # adapter path. 
+ response = agent_instance.handle_request(request_data) logger.debug( f"Successfully routed request for agent key: {registration_key}" ) @@ -374,3 +400,224 @@ def route_request( raw_request=request_data, registration_key=registration_key, ) + + # ------------------------------------------------------------------ # + # Phase C: LiteLLM dispatch path + # ------------------------------------------------------------------ # + + @staticmethod + def _extract_messages( + request_data: Dict[str, Any], + ) -> Tuple[Optional[List[Dict[str, str]]], Optional[str]]: + """Return ``(messages, error_msg)`` for a chat-completion request.""" + messages = request_data.get("messages") + prompt = request_data.get("prompt") + if messages: + return messages, None + if prompt: + return [{"role": "user", "content": prompt}], None + return ( + None, + "Request data must include either 'messages' or 'prompt' field.", + ) + + def _dispatch_via_litellm( + self, + *, + registration_key: str, + agent_instance: Agent, + provider_config: ProviderConfig, + request_data: Dict[str, Any], + ) -> Dict[str, Any]: + """Route a chat-completion request through ``litellm.completion``. + + Reads the model string, endpoint, API key, and generation + defaults off the already-configured adapter instance, looks up + the provider-specific thinking translator from + ``provider_config``, then calls LiteLLM directly. The response + is shaped via :mod:`hackagent.router.envelope` so downstream + consumers see exactly the same dict as the adapter-driven path. + """ + adapter_label = provider_config.adapter_label or agent_instance.ADAPTER_TYPE + model_name = getattr(agent_instance, "litellm_model", None) or getattr( + agent_instance, "model_name", None + ) + if model_name is None: + return _envelope.build_error_envelope( + agent_id=registration_key, + adapter_type=adapter_label, + error_message=( + f"Adapter for '{registration_key}' has no model name; " + "cannot dispatch via LiteLLM." 
+ ), + status_code=500, + raw_request=request_data, + ) + + messages, validation_error = self._extract_messages(request_data) + if validation_error: + return _envelope.build_error_envelope( + agent_id=registration_key, + adapter_type=adapter_label, + error_message=validation_error, + status_code=400, + raw_request=request_data, + ) + + # Generation defaults come from the adapter instance. + max_tokens = request_data.get( + "max_tokens", getattr(agent_instance, "default_max_tokens", 100) + ) + temperature = request_data.get( + "temperature", getattr(agent_instance, "default_temperature", 0.8) + ) + top_p = request_data.get( + "top_p", getattr(agent_instance, "default_top_p", 0.95) + ) + + # Translate the unified thinking knob via the provider config. + thinking = request_data.get( + "thinking", getattr(agent_instance, "default_thinking", None) + ) + thinking_payload = provider_config.thinking_translator( + thinking, model_name=model_name + ) + + # Provider-specific pass-throughs (tools, extras, …) plus any + # adapter-specific extra knobs (top_k, num_ctx for Ollama, etc.). + tools = request_data.get( + "tools", getattr(agent_instance, "default_tools", None) + ) + tool_choice = request_data.get( + "tool_choice", getattr(agent_instance, "default_tool_choice", None) + ) + extra_body = request_data.get( + "extra_body", getattr(agent_instance, "default_extra_body", None) + ) + + excluded_keys = { + "prompt", + "messages", + "max_tokens", + "temperature", + "top_p", + "tools", + "tool_choice", + "thinking", + "extra_body", + } + extra_kwargs: Dict[str, Any] = { + k: v for k, v in request_data.items() if k not in excluded_keys + } + # Add adapter-instance defaults for the extra passthrough keys. 
+ for key in provider_config.extra_passthrough_keys: + if key in request_data or key in extra_kwargs: + continue + default = getattr(agent_instance, f"default_{key}", None) + if default is not None: + extra_kwargs[key] = default + + kwargs = _envelope.build_litellm_kwargs( + model=model_name, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + api_base=getattr(agent_instance, "api_base_url", None), + api_key=getattr(agent_instance, "actual_api_key", None), + tools=tools, + tool_choice=tool_choice, + extra_body=extra_body, + thinking_payload=thinking_payload, + extra_kwargs=extra_kwargs, + ) + + litellm, available = _get_litellm() + if not available: + return _envelope.build_error_envelope( + agent_id=registration_key, + adapter_type=adapter_label, + error_message="litellm is not installed", + status_code=500, + raw_request=request_data, + model_name=model_name, + ) + + try: + response = litellm.completion(**kwargs) + except Exception as exc: + logger.exception( + f"LiteLLM dispatch failed for agent {registration_key} " + f"(model={model_name}): {exc}" + ) + return _envelope.build_error_envelope( + agent_id=registration_key, + adapter_type=adapter_label, + error_message=f"{adapter_label} error ({type(exc).__name__}): {exc}", + status_code=500, + raw_request=request_data, + model_name=model_name, + ) + + text = _envelope.extract_text_from_response(response, model_name=model_name) + if isinstance(text, str) and text.startswith("[GENERATION_ERROR:"): + return _envelope.build_error_envelope( + agent_id=registration_key, + adapter_type=adapter_label, + error_message=f"{adapter_label} generation error: {text}", + status_code=500, + raw_request=request_data, + model_name=model_name, + ) + + # Build completion_result + agent_specific_data the same way + # ChatCompletionsAgent did, so the envelope dict matches byte + # for byte. 
+ invoked_parameters: Dict[str, Any] = { + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + } + invoked_parameters.update(extra_kwargs) + if tools is not None: + invoked_parameters["tools"] = tools + if tool_choice is not None: + invoked_parameters["tool_choice"] = tool_choice + + completion_result: Dict[str, Any] = { + "success": True, + "content": text, + "raw_response": response, + } + tool_calls = _envelope.extract_tool_calls(response) + if tool_calls is not None: + completion_result["tool_calls"] = tool_calls + try: + completion_result["finish_reason"] = response.choices[0].finish_reason + except (AttributeError, IndexError, TypeError): + pass + try: + if response.usage is not None: + completion_result["usage"] = response.usage.model_dump() + except AttributeError: + pass + try: + completion_result["provider_model"] = response.model + except AttributeError: + pass + + agent_specific_data = _envelope.build_agent_specific_data( + model_name=model_name, + invoked_parameters=invoked_parameters, + completion_result=completion_result, + ) + + return _envelope.build_success_envelope( + agent_id=registration_key, + adapter_type=adapter_label, + processed_response=text, + raw_request=request_data, + raw_response_body=response, + agent_specific_data=agent_specific_data, + model_name=model_name, + ) diff --git a/tests/unit/router/test_dispatch.py b/tests/unit/router/test_dispatch.py new file mode 100644 index 00000000..00ef4623 --- /dev/null +++ b/tests/unit/router/test_dispatch.py @@ -0,0 +1,218 @@ +# Copyright 2026 - AI4I. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for ``AgentRouter._dispatch_via_litellm`` — Phase C of #379. + +The dispatch path lives on the router itself and is exercised end-to-end +by going through ``AgentRouter.route_request``. These tests mock the +backend so the router can be initialised with a real adapter instance, +then patch ``litellm.completion`` to control the response. 
+""" + +import logging +import unittest +import uuid +from unittest.mock import MagicMock, patch + +from hackagent.router.router import AgentRouter +from hackagent.router.types import AgentTypeEnum +from hackagent.server.storage.base import OrganizationContext + +logging.disable(logging.CRITICAL) + + +def _make_litellm_response(content: str = "ok") -> MagicMock: + response = MagicMock() + choice = MagicMock() + message = MagicMock() + message.content = content + message.tool_calls = None + message.reasoning_content = None + message.reasoning = None + message.provider_specific_fields = None + choice.message = message + choice.finish_reason = "stop" + response.choices = [choice] + response.usage = MagicMock(model_dump=MagicMock(return_value={"total_tokens": 7})) + response.model = "openai/gpt-4" + return response + + +def _make_context(org_id=None, user_id="test_user"): + ctx = MagicMock(spec=OrganizationContext) + ctx.org_id = org_id or uuid.uuid4() + ctx.user_id = user_id + return ctx + + +def _make_agent_rec(*, agent_id, name, agent_type_str, endpoint, metadata=None): + rec = MagicMock() + rec.id = agent_id + rec.name = name + rec.agent_type = agent_type_str + rec.endpoint = endpoint + rec.metadata = metadata or {} + rec.organization = uuid.uuid4() + rec.owner = "local" + return rec + + +def _make_backend(*, agent_id, name, agent_type_str, endpoint, metadata=None): + backend = MagicMock() + backend.get_context.return_value = _make_context() + backend.get_api_key.return_value = None + backend.create_or_update_agent.return_value = _make_agent_rec( + agent_id=agent_id, + name=name, + agent_type_str=agent_type_str, + endpoint=endpoint, + metadata=metadata, + ) + return backend + + +class TestDispatchViaLiteLLM(unittest.TestCase): + """Verify the chat path goes through litellm.completion directly.""" + + def _make_router_for_openai(self): + agent_id = uuid.uuid4() + backend = _make_backend( + agent_id=agent_id, + name="gpt-4-router-test", + 
agent_type_str=AgentTypeEnum.OPENAI_SDK.value, + endpoint="", + metadata={"name": "gpt-4"}, + ) + router = AgentRouter( + backend=backend, + name="gpt-4-router-test", + agent_type=AgentTypeEnum.OPENAI_SDK, + endpoint="", + metadata={"name": "gpt-4"}, + adapter_operational_config={"name": "gpt-4"}, + ) + return router, str(agent_id) + + @patch("litellm.completion") + def test_chat_request_goes_through_litellm_completion(self, mock_completion): + """OPENAI_SDK request lands at litellm.completion via the router.""" + mock_completion.return_value = _make_litellm_response("hi there") + router, reg_key = self._make_router_for_openai() + + # Patch the adapter's handle_request so we can verify it's NOT called. + adapter = router.get_agent_instance(reg_key) + adapter.handle_request = MagicMock(name="should_not_be_called") + + response = router.route_request(reg_key, {"prompt": "hi"}) + + self.assertEqual(response["status_code"], 200) + self.assertEqual(response["generated_text"], "hi there") + self.assertEqual(response["adapter_type"], "OpenAIAgent") + mock_completion.assert_called_once() + adapter.handle_request.assert_not_called() + kwargs = mock_completion.call_args.kwargs + self.assertEqual(kwargs["model"], "openai/gpt-4") + self.assertEqual(kwargs["messages"], [{"role": "user", "content": "hi"}]) + + @patch("litellm.completion") + def test_missing_prompt_returns_400_envelope_from_router(self, mock_completion): + router, reg_key = self._make_router_for_openai() + response = router.route_request(reg_key, {"temperature": 0.5}) + self.assertEqual(response["status_code"], 400) + self.assertIn( + "Request data must include either 'messages' or 'prompt'", + response["error_message"], + ) + mock_completion.assert_not_called() + + @patch("litellm.completion") + def test_litellm_exception_becomes_500_envelope(self, mock_completion): + mock_completion.side_effect = RuntimeError("boom") + router, reg_key = self._make_router_for_openai() + response = router.route_request(reg_key, 
{"prompt": "hi"}) + self.assertEqual(response["status_code"], 500) + self.assertIn("boom", response["error_message"]) + self.assertEqual(response["adapter_type"], "OpenAIAgent") + + @patch("litellm.completion") + def test_thinking_translation_applied_through_router(self, mock_completion): + """Per-request ``thinking`` flag is translated by the ProviderConfig.""" + mock_completion.return_value = _make_litellm_response("ok") + agent_id = uuid.uuid4() + backend = _make_backend( + agent_id=agent_id, + name="o1-mini", + agent_type_str=AgentTypeEnum.OPENAI_SDK.value, + endpoint="", + metadata={"name": "o1-mini"}, + ) + router = AgentRouter( + backend=backend, + name="o1-mini", + agent_type=AgentTypeEnum.OPENAI_SDK, + endpoint="", + metadata={"name": "o1-mini"}, + adapter_operational_config={"name": "o1-mini"}, + ) + reg_key = str(agent_id) + router.route_request(reg_key, {"prompt": "hi", "thinking": True}) + kwargs = mock_completion.call_args.kwargs + self.assertEqual(kwargs.get("reasoning_effort"), "medium") + + def test_unknown_registration_key_returns_404_envelope(self): + router, _ = self._make_router_for_openai() + response = router.route_request("nonexistent-key", {"prompt": "hi"}) + # The legacy router-level AgentNotFound envelope uses + # ``raw_response_status`` rather than ``status_code``. Phase F may + # unify these; we just check the actual current behaviour here. 
+ self.assertEqual(response["raw_response_status"], 404) + self.assertIn("Agent not found", response["error_message"]) + + +class TestDispatchADKBypassesLiteLLM(unittest.TestCase): + """Verify ADK requests still flow through the adapter's handle_request.""" + + def test_adk_uses_adapter_handle_request_not_litellm(self): + agent_id = uuid.uuid4() + backend = _make_backend( + agent_id=agent_id, + name="my_app", + agent_type_str=AgentTypeEnum.GOOGLE_ADK.value, + endpoint="http://fake-adk.com", + metadata={"name": "my_app"}, + ) + router = AgentRouter( + backend=backend, + name="my_app", + agent_type=AgentTypeEnum.GOOGLE_ADK, + endpoint="http://fake-adk.com", + metadata={"name": "my_app"}, + adapter_operational_config={ + "name": "my_app", + "endpoint": "http://fake-adk.com", + "user_id": "alice", + }, + ) + reg_key = str(agent_id) + adapter = router.get_agent_instance(reg_key) + adapter.handle_request = MagicMock( + return_value={ + "status_code": 200, + "generated_text": "adk reply", + "adapter_type": "ADKAgent", + "agent_id": reg_key, + "error_message": None, + } + ) + + with patch("litellm.completion") as mock_completion: + response = router.route_request(reg_key, {"prompt": "hi"}) + + self.assertEqual(response["generated_text"], "adk reply") + adapter.handle_request.assert_called_once() + mock_completion.assert_not_called() + + +if __name__ == "__main__": + unittest.main() From 68746c755e69e6a899bafb1e2c6600f4a76517da Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 16:49:13 +0200 Subject: [PATCH 05/23] feat(router): capture I/O via LiteLLM CustomLogger (#379 Phase D) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase D of the LiteLLM router refactor. 
Adds ``hackagent/router/tracking_logger.py`` — a ``CustomLogger`` subclass that hooks ``log_pre_api_call``, ``log_success_event``, and ``log_failure_event``, emitting structured records via ``hackagent.logger`` that downstream sinks (TUI, dashboard) can pick up. ``AgentRouter.__init__`` registers the logger on ``litellm.callbacks`` exactly once per process. ``_dispatch_via_litellm`` attaches ``metadata={"hackagent_agent_id": ..., "hackagent_adapter_type": ...}`` to every call so the logger can filter HackAgent-owned traffic and correlate input ↔ output ↔ cost via LiteLLM's call_id. Caller-supplied ``metadata`` (e.g. trace_id for Langfuse/OTEL) is merged in and wins on collision. Co-Authored-By: Claude Opus 4.7 (1M context) --- hackagent/router/router.py | 20 ++ hackagent/router/tracking_logger.py | 243 ++++++++++++++++++++++ tests/unit/router/test_dispatch.py | 29 +++ tests/unit/router/test_tracking_logger.py | 129 ++++++++++++ 4 files changed, 421 insertions(+) create mode 100644 hackagent/router/tracking_logger.py create mode 100644 tests/unit/router/test_tracking_logger.py diff --git a/hackagent/router/router.py b/hackagent/router/router.py index 19a8254f..1f83523c 100644 --- a/hackagent/router/router.py +++ b/hackagent/router/router.py @@ -6,6 +6,7 @@ from hackagent.server.storage.base import AgentRecord, StorageBackend from hackagent.router import envelope as _envelope +from hackagent.router import tracking_logger as _tracking_logger from hackagent.router.adapters.base import Agent from hackagent.router.provider_config import ProviderConfig, get_provider_config from hackagent.router.types import AgentTypeEnum @@ -85,6 +86,10 @@ def __init__( # everything else still calls ``adapter.handle_request``). self._agent_types: Dict[str, AgentTypeEnum] = {} + # Phase D: register the LiteLLM CustomLogger that captures input + # and output for every HackAgent-owned call. Idempotent. 
+ _tracking_logger.ensure_registered() + context = self.backend.get_context() self.organization_id = context.org_id self.user_id_str = context.user_id @@ -505,6 +510,7 @@ def _dispatch_via_litellm( "tool_choice", "thinking", "extra_body", + "metadata", } extra_kwargs: Dict[str, Any] = { k: v for k, v in request_data.items() if k not in excluded_keys @@ -517,6 +523,20 @@ def _dispatch_via_litellm( if default is not None: extra_kwargs[key] = default + # Phase D: attach correlation metadata so the registered + # HackAgentTrackingLogger can join input ↔ output ↔ cost. Any + # ``metadata`` already in ``request_data`` is preserved and + # augmented (caller-supplied keys win on collision so user + # tracing identifiers aren't overwritten). + caller_metadata = request_data.get("metadata") + merged_metadata: Dict[str, Any] = { + _tracking_logger.HACKAGENT_AGENT_ID_KEY: registration_key, + _tracking_logger.HACKAGENT_ADAPTER_TYPE_KEY: adapter_label, + } + if isinstance(caller_metadata, dict): + merged_metadata.update(caller_metadata) + extra_kwargs["metadata"] = merged_metadata + kwargs = _envelope.build_litellm_kwargs( model=model_name, messages=messages, diff --git a/hackagent/router/tracking_logger.py b/hackagent/router/tracking_logger.py new file mode 100644 index 00000000..efaaffd6 --- /dev/null +++ b/hackagent/router/tracking_logger.py @@ -0,0 +1,243 @@ +# Copyright 2026 - AI4I. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +LiteLLM callback that captures every ``litellm.completion`` call. + +LiteLLM exposes a ``CustomLogger`` base class with hook methods that +fire pre-call, on success, and on failure. We register a single +:class:`HackAgentTrackingLogger` instance on ``litellm.callbacks`` and +attach ``metadata`` to every call so the logger can correlate the I/O +back to the originating HackAgent registration. + +The logger only emits structured records to ``hackagent.logger``; it +does not write to the backend storage directly. 
Downstream sinks (TUI +event bus, dashboard, file logs) can pick the records up from there. +""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from hackagent.logger import get_logger + + +# Singleton — one logger per process so we don't double-register on +# ``litellm.callbacks``. ``ensure_registered`` is idempotent and is +# called from :meth:`AgentRouter.__init__`. The instance type is +# dynamically built by ``_build_handler_class`` once litellm is +# importable, so we annotate it as ``Optional[Any]`` here. +_REGISTERED: bool = False +_LOGGER_INSTANCE: Optional[Any] = None +_TRACKING_LOGGER = get_logger("hackagent.router.tracking_logger") + +# Sentinel metadata keys the logger uses to identify HackAgent-owned +# calls. Other tools wiring their own ``litellm.callbacks`` won't see +# their calls double-logged because we filter on this key. +HACKAGENT_AGENT_ID_KEY = "hackagent_agent_id" +HACKAGENT_ADAPTER_TYPE_KEY = "hackagent_adapter_type" + + +def _try_import_custom_logger() -> Optional[type]: + """Return ``litellm.integrations.custom_logger.CustomLogger`` or ``None``.""" + try: + from litellm.integrations.custom_logger import CustomLogger + + return CustomLogger + except ImportError: + return None + + +def _extract_hackagent_metadata(kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Pull the HackAgent metadata block out of a LiteLLM callback ``kwargs``. + + LiteLLM nests user-supplied ``metadata`` under ``litellm_params``. We + only return a dict when the metadata carries our sentinel key so + that callbacks fired by other libraries' calls don't get logged. 
+ """ + litellm_params = kwargs.get("litellm_params") or {} + metadata = ( + litellm_params.get("metadata") if isinstance(litellm_params, dict) else None + ) + if not isinstance(metadata, dict): + return None + if HACKAGENT_AGENT_ID_KEY not in metadata: + return None + return metadata + + +def _extract_response_text(response_obj: Any) -> Optional[str]: + """Best-effort string extraction from a LiteLLM ``ModelResponse``.""" + try: + message = response_obj.choices[0].message + except (AttributeError, IndexError, TypeError): + return None + content = getattr(message, "content", None) + if isinstance(content, str) and content: + return content + reasoning = getattr(message, "reasoning_content", None) or getattr( + message, "reasoning", None + ) + if isinstance(reasoning, str) and reasoning: + return reasoning + return None + + +def _last_user_message(kwargs: Dict[str, Any]) -> Optional[str]: + messages = kwargs.get("messages") or [] + if not isinstance(messages, list): + return None + for msg in reversed(messages): + if not isinstance(msg, dict): + continue + if msg.get("role") != "user": + continue + content = msg.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text = part.get("text") + if isinstance(text, str): + return text + return None + + +def _build_handler_class(): + """Build the ``HackAgentTrackingLogger`` class once litellm is importable.""" + CustomLogger = _try_import_custom_logger() + if CustomLogger is None: + return None + + class HackAgentTrackingLogger(CustomLogger): # type: ignore[misc, valid-type] + """Capture every HackAgent-owned ``litellm.completion`` call.""" + + def log_pre_api_call(self, model, messages, kwargs): + metadata = _extract_hackagent_metadata(kwargs) + if metadata is None: + return + preview = "" + for msg in reversed(messages or []): + if isinstance(msg, dict) and msg.get("role") == "user": + content = 
msg.get("content") or "" + preview = (content if isinstance(content, str) else "")[:120] + break + _TRACKING_LOGGER.info( + "litellm.pre", + extra={ + "hackagent_agent_id": metadata.get(HACKAGENT_AGENT_ID_KEY), + "hackagent_adapter_type": metadata.get(HACKAGENT_ADAPTER_TYPE_KEY), + "litellm_model": model, + "prompt_preview": preview, + }, + ) + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + metadata = _extract_hackagent_metadata(kwargs) + if metadata is None: + return + text = _extract_response_text(response_obj) or "" + response_preview = text[:200] if text else "" + duration_ms = None + try: + duration_ms = (end_time - start_time).total_seconds() * 1000 + except (AttributeError, TypeError): + pass + _TRACKING_LOGGER.info( + "litellm.success", + extra={ + "hackagent_agent_id": metadata.get(HACKAGENT_AGENT_ID_KEY), + "hackagent_adapter_type": metadata.get(HACKAGENT_ADAPTER_TYPE_KEY), + "litellm_model": kwargs.get("model"), + "litellm_call_id": kwargs.get("litellm_call_id"), + "response_preview": response_preview, + "response_cost": kwargs.get("response_cost"), + "duration_ms": duration_ms, + "prompt_preview": _last_user_message(kwargs), + }, + ) + + async def async_log_success_event( + self, kwargs, response_obj, start_time, end_time + ): + self.log_success_event(kwargs, response_obj, start_time, end_time) + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + metadata = _extract_hackagent_metadata(kwargs) + if metadata is None: + return + duration_ms = None + try: + duration_ms = (end_time - start_time).total_seconds() * 1000 + except (AttributeError, TypeError): + pass + _TRACKING_LOGGER.warning( + "litellm.failure", + extra={ + "hackagent_agent_id": metadata.get(HACKAGENT_AGENT_ID_KEY), + "hackagent_adapter_type": metadata.get(HACKAGENT_ADAPTER_TYPE_KEY), + "litellm_model": kwargs.get("model"), + "litellm_call_id": kwargs.get("litellm_call_id"), + "exception_repr": repr(kwargs.get("exception", response_obj)), 
+ "duration_ms": duration_ms, + "prompt_preview": _last_user_message(kwargs), + }, + ) + + async def async_log_failure_event( + self, kwargs, response_obj, start_time, end_time + ): + self.log_failure_event(kwargs, response_obj, start_time, end_time) + + return HackAgentTrackingLogger + + +def ensure_registered() -> bool: + """Register the tracking logger on ``litellm.callbacks`` exactly once. + + Idempotent — safe to call from every ``AgentRouter.__init__``. + Returns ``True`` when registration is in effect (either because we + just registered or because we already had). + """ + global _REGISTERED, _LOGGER_INSTANCE + if _REGISTERED: + return True + + handler_cls = _build_handler_class() + if handler_cls is None: + _TRACKING_LOGGER.debug( + "litellm.integrations.custom_logger.CustomLogger unavailable; " + "skipping HackAgentTrackingLogger registration." + ) + return False + + try: + import litellm + except ImportError: + return False + + instance = handler_cls() + callbacks = list(getattr(litellm, "callbacks", None) or []) + # Guard against re-adding ourselves if the user already imported us + # in another module. 
+ already = any(getattr(cb, "__class__", None) is handler_cls for cb in callbacks) + if not already: + callbacks.append(instance) + litellm.callbacks = callbacks + + _LOGGER_INSTANCE = instance + _REGISTERED = True + return True + + +def get_instance() -> Optional[Any]: + """Return the singleton logger instance (mainly for tests).""" + return _LOGGER_INSTANCE + + +def _reset_for_tests() -> None: + """Reset the singleton state — only used by the unit tests.""" + global _REGISTERED, _LOGGER_INSTANCE + _REGISTERED = False + _LOGGER_INSTANCE = None diff --git a/tests/unit/router/test_dispatch.py b/tests/unit/router/test_dispatch.py index 00ef4623..4bf50772 100644 --- a/tests/unit/router/test_dispatch.py +++ b/tests/unit/router/test_dispatch.py @@ -160,6 +160,35 @@ def test_thinking_translation_applied_through_router(self, mock_completion): kwargs = mock_completion.call_args.kwargs self.assertEqual(kwargs.get("reasoning_effort"), "medium") + @patch("litellm.completion") + def test_dispatch_attaches_hackagent_metadata(self, mock_completion): + """Phase D — every call carries metadata for the tracking logger.""" + mock_completion.return_value = _make_litellm_response("ok") + router, reg_key = self._make_router_for_openai() + router.route_request(reg_key, {"prompt": "hi"}) + metadata = mock_completion.call_args.kwargs.get("metadata") + self.assertIsInstance(metadata, dict) + self.assertEqual(metadata.get("hackagent_agent_id"), reg_key) + self.assertEqual(metadata.get("hackagent_adapter_type"), "OpenAIAgent") + + @patch("litellm.completion") + def test_caller_supplied_metadata_is_merged_and_wins(self, mock_completion): + mock_completion.return_value = _make_litellm_response("ok") + router, reg_key = self._make_router_for_openai() + router.route_request( + reg_key, + { + "prompt": "hi", + "metadata": {"trace_id": "xyz", "hackagent_agent_id": "override"}, + }, + ) + metadata = mock_completion.call_args.kwargs.get("metadata") + self.assertEqual(metadata["trace_id"], "xyz") + # 
Caller-supplied keys win on collision. + self.assertEqual(metadata["hackagent_agent_id"], "override") + # Adapter-type still set by the router. + self.assertEqual(metadata["hackagent_adapter_type"], "OpenAIAgent") + def test_unknown_registration_key_returns_404_envelope(self): router, _ = self._make_router_for_openai() response = router.route_request("nonexistent-key", {"prompt": "hi"}) diff --git a/tests/unit/router/test_tracking_logger.py b/tests/unit/router/test_tracking_logger.py new file mode 100644 index 00000000..5f8fd86b --- /dev/null +++ b/tests/unit/router/test_tracking_logger.py @@ -0,0 +1,129 @@ +# Copyright 2026 - AI4I. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for ``hackagent/router/tracking_logger.py``.""" + +import datetime as _dt +import logging +import unittest +from unittest.mock import MagicMock, patch + +from hackagent.router import tracking_logger + +logging.disable(logging.CRITICAL) + + +def _hackagent_kwargs(**overrides): + """Build a kwargs dict shaped the way LiteLLM passes one to a callback.""" + base = { + "model": "openai/gpt-4", + "messages": [{"role": "user", "content": "hello"}], + "litellm_call_id": "call-1", + "response_cost": 0.0001, + "litellm_params": { + "metadata": { + "hackagent_agent_id": "agent-123", + "hackagent_adapter_type": "OpenAIAgent", + } + }, + } + base.update(overrides) + return base + + +def _model_response(content: str = "ok"): + response = MagicMock() + message = MagicMock() + message.content = content + message.reasoning_content = None + message.reasoning = None + choice = MagicMock() + choice.message = message + response.choices = [choice] + return response + + +class TestEnsureRegistered(unittest.TestCase): + def setUp(self): + tracking_logger._reset_for_tests() + import litellm + + # Snapshot callbacks so we can restore them. 
+ self._saved_callbacks = list(getattr(litellm, "callbacks", None) or []) + litellm.callbacks = list(self._saved_callbacks) + self._litellm = litellm + + def tearDown(self): + self._litellm.callbacks = self._saved_callbacks + tracking_logger._reset_for_tests() + + def test_idempotent_registration(self): + self.assertTrue(tracking_logger.ensure_registered()) + first_callbacks = list(self._litellm.callbacks) + self.assertTrue(tracking_logger.ensure_registered()) + self.assertEqual(list(self._litellm.callbacks), first_callbacks) + + def test_logger_instance_exposed(self): + tracking_logger.ensure_registered() + self.assertIsNotNone(tracking_logger.get_instance()) + + +class TestCallbackFilteringByMetadata(unittest.TestCase): + """Calls without the HackAgent sentinel metadata are ignored.""" + + def setUp(self): + tracking_logger._reset_for_tests() + tracking_logger.ensure_registered() + self.logger = tracking_logger.get_instance() + + def tearDown(self): + tracking_logger._reset_for_tests() + + @patch.object(tracking_logger._TRACKING_LOGGER, "info") + def test_pre_call_with_no_metadata_is_skipped(self, mock_info): + self.logger.log_pre_api_call( + "openai/gpt-4", + [{"role": "user", "content": "hi"}], + {"litellm_params": {}}, # no metadata + ) + mock_info.assert_not_called() + + @patch.object(tracking_logger._TRACKING_LOGGER, "info") + def test_pre_call_with_hackagent_metadata_is_logged(self, mock_info): + self.logger.log_pre_api_call( + "openai/gpt-4", + [{"role": "user", "content": "hi"}], + _hackagent_kwargs(), + ) + mock_info.assert_called_once() + extra = mock_info.call_args.kwargs["extra"] + self.assertEqual(extra["hackagent_agent_id"], "agent-123") + self.assertEqual(extra["hackagent_adapter_type"], "OpenAIAgent") + + @patch.object(tracking_logger._TRACKING_LOGGER, "info") + def test_success_logs_cost_call_id_and_preview(self, mock_info): + start = _dt.datetime(2026, 1, 1, 0, 0, 0) + end = _dt.datetime(2026, 1, 1, 0, 0, 1) + self.logger.log_success_event( + 
_hackagent_kwargs(), _model_response("hi"), start, end + ) + mock_info.assert_called_once() + extra = mock_info.call_args.kwargs["extra"] + self.assertEqual(extra["litellm_call_id"], "call-1") + self.assertEqual(extra["response_cost"], 0.0001) + self.assertEqual(extra["response_preview"], "hi") + self.assertAlmostEqual(extra["duration_ms"], 1000.0, places=1) + + @patch.object(tracking_logger._TRACKING_LOGGER, "warning") + def test_failure_logs_exception_repr(self, mock_warning): + start = _dt.datetime(2026, 1, 1, 0, 0, 0) + end = _dt.datetime(2026, 1, 1, 0, 0, 1) + kwargs = _hackagent_kwargs(exception=RuntimeError("boom")) + self.logger.log_failure_event(kwargs, None, start, end) + mock_warning.assert_called_once() + extra = mock_warning.call_args.kwargs["extra"] + self.assertIn("boom", extra["exception_repr"]) + + +if __name__ == "__main__": + unittest.main() From e065713392f52429a2b009afa0795fc00306cf79 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 16:55:22 +0200 Subject: [PATCH 06/23] refactor(router): move ADK CustomLLM to providers/ (#379 Phase E) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase E (partial) of the LiteLLM router refactor. The Google ADK provider — which is fundamentally a ``litellm.CustomLLM`` wrapping the ADK ``POST /run`` + sessions + events protocol — now lives at ``hackagent/router/providers/adk.py``, its logical home. The old ``hackagent/router/adapters/google_adk.py`` path is preserved as a thin re-export shim so existing imports keep working. The chat adapter classes (``LiteLLMAgent``, ``OpenAIAgent``, ``OllamaAgent``) stay in ``hackagent/router/adapters/`` for now. They no longer run on the hot path — ``AgentRouter._dispatch_via_litellm`` calls ``litellm.completion`` directly — but their public symbols are still imported by external callers, so deleting them is a separate decision documented as Phase E.2 in ``LITELLM_ROUTER_REFACTOR_PLAN.md``. 
The plan markdown is updated with a Status section enumerating which commits landed which phase and what's deferred. Co-Authored-By: Claude Opus 4.7 (1M context) --- LITELLM_ROUTER_REFACTOR_PLAN.md | 56 +++ hackagent/router/adapters/__init__.py | 20 +- hackagent/router/adapters/google_adk.py | 453 +----------------------- hackagent/router/providers/__init__.py | 0 hackagent/router/providers/adk.py | 444 +++++++++++++++++++++++ tests/unit/adapters/test_google_adk.py | 15 + 6 files changed, 550 insertions(+), 438 deletions(-) create mode 100644 hackagent/router/providers/__init__.py create mode 100644 hackagent/router/providers/adk.py diff --git a/LITELLM_ROUTER_REFACTOR_PLAN.md b/LITELLM_ROUTER_REFACTOR_PLAN.md index c8e9822c..675f4ec5 100644 --- a/LITELLM_ROUTER_REFACTOR_PLAN.md +++ b/LITELLM_ROUTER_REFACTOR_PLAN.md @@ -302,3 +302,59 @@ green. request, error path, and ADK request. - All existing example scripts under `hackagent/examples/` keep working without code edits. + +--- + +## 9. Status (2026-05-23) + +Phases A–E (partial) landed in five commits on +`feat/litellm-unified-adapters-379`: + +- **Phase A** (1b3dedf): `hackagent/router/envelope.py` extracted; the + adapter classes delegate to its pure functions. `provider_config.py` + added alongside but not wired in yet. +- **Phase B** (67cd38f): `LiteLLMAgent.__init__` now accepts an optional + ``ProviderConfig``. `OpenAIAgent` and `OllamaAgent` look their config + up from the table; their `_apply_thinking` overrides are gone. +- **Phase C** (c14b2e0): `AgentRouter._dispatch_via_litellm` calls + `litellm.completion` directly for every chat-completion AgentType + (LITELLM, OPENAI_SDK, OLLAMA, LANGCHAIN). `adapter.handle_request` is + no longer on the hot path for those. ADK still goes through its + adapter (CustomLLM registration is per-instance, by design). 
+- **Phase D**: `HackAgentTrackingLogger` (CustomLogger subclass) + registered on `litellm.callbacks` from `AgentRouter.__init__`; + `_dispatch_via_litellm` attaches `metadata={...}` so the logger can + correlate input ↔ output ↔ cost via `litellm_call_id`. +- **Phase E (partial)**: ADK moved to + `hackagent/router/providers/adk.py`; the old + `hackagent/router/adapters/google_adk.py` path is a thin re-export + shim for backwards compatibility. The chat adapter classes + (`LiteLLMAgent`, `OpenAIAgent`, `OllamaAgent`) are kept and now act + as config containers — they no longer run on the hot path but are + still instantiated so external callers that `from + hackagent.router.adapters.openai import OpenAIAgent` keep working. + +### Remaining work (deferred) + +- **Phase E.2 — full deletion of the chat adapter classes.** Requires + replacing the per-registration adapter instance with a lightweight + config dataclass (`_ChatRegistration`) and dropping the public + `LiteLLMAgent` / `OpenAIAgent` / `OllamaAgent` symbols. Hold off + until we know no downstream code in `hackagent-api`, + `hackagent-webapp`, or external consumers depends on those imports. +- **Phase F — optional follow-ups.** Adopt `litellm.Router` for + multi-deployment load balancing; surface `response_cost` and + `x-litellm-call-id` in the envelope; streaming support; full + `_build_error_response` shape unification (currently the + `AgentNotFound` envelope still uses ``raw_response_status`` while + the chat-dispatch envelope uses ``status_code``). + +### Tests + +1776 unit tests pass after Phase E. 
New coverage: + +- `tests/unit/router/test_envelope.py` (26 tests) +- `tests/unit/router/test_provider_config.py` (25 tests) +- `tests/unit/router/test_dispatch.py` (8 tests — chat dispatch + ADK + bypass + metadata flow) +- `tests/unit/router/test_tracking_logger.py` (6 tests) diff --git a/hackagent/router/adapters/__init__.py b/hackagent/router/adapters/__init__.py index c5a744e8..37e73fb7 100644 --- a/hackagent/router/adapters/__init__.py +++ b/hackagent/router/adapters/__init__.py @@ -1,6 +1,24 @@ # Copyright 2026 - AI4I. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +""" +Legacy adapter package. + +Issue #379 routed every chat-completion AgentType through LiteLLM +(Phase A–B), hoisted the call path into ``AgentRouter`` (Phase C), and +wired a ``litellm.CustomLogger`` for I/O capture (Phase D). The +adapter classes ``LiteLLMAgent`` / ``OpenAIAgent`` / ``OllamaAgent`` +are no longer on the hot path — ``AgentRouter._dispatch_via_litellm`` +calls ``litellm.completion`` directly, reading the resolved config off +the instance only for its model name, endpoint, API key, and +generation defaults. + +The classes remain available for backwards compatibility with external +callers that import them. ADK has moved to +``hackagent.router.providers.adk``; this package re-exports +``ADKAgent`` from there so old imports keep working. 
+""" + # Lazy imports for adapters to improve startup time # These adapters import heavy dependencies (litellm ~2s, google-adk ~0.1s) from .base import ( @@ -15,7 +33,7 @@ def __getattr__(name): """Lazy load adapter classes on first access.""" if name == "ADKAgent": - from .google_adk import ADKAgent + from hackagent.router.providers.adk import ADKAgent return ADKAgent elif name == "LiteLLMAgent": diff --git a/hackagent/router/adapters/google_adk.py b/hackagent/router/adapters/google_adk.py index 66d881cd..6f00a6c8 100644 --- a/hackagent/router/adapters/google_adk.py +++ b/hackagent/router/adapters/google_adk.py @@ -2,443 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 """ -Google ADK (Agent Development Kit) adapter built on top of LiteLLM. - -LiteLLM has no built-in provider for the ADK server protocol (POST /run -with sessions and events), so issue #379 routes ADK through LiteLLM by -registering a per-instance :class:`litellm.CustomLLM` handler under a -unique provider name. The HTTP transport against the deployed ADK server -lives in the lazily-defined ``_ADKCustomLLM`` class, while -:class:`ADKAgent` itself is a thin :class:`LiteLLMAgent` subclass that -registers the handler and asks LiteLLM to route through it. +Backwards-compatibility re-export. + +Issue #379 Phase E moved the Google ADK provider to +``hackagent/router/providers/adk.py`` — its logical home, since +``ADKAgent`` is a :class:`litellm.CustomLLM` wrapper rather than a chat +adapter. This module re-exports the public names so existing +``from hackagent.router.adapters.google_adk import ...`` imports keep +working. 
""" -import json -import uuid -from hackagent.logger import get_logger -from typing import Any, Dict, List, Optional - -import requests - -from .base import ( - AdapterConfigurationError, - AdapterInteractionError, - AdapterResponseParsingError, +from hackagent.router.providers.adk import ( # noqa: F401 + ADKAgent, + AgentConfigurationError, + AgentInteractionError, + ResponseParsingError, + _extract_final_text, + _get_adk_custom_llm_class, + _last_user_text, ) -from .litellm import LiteLLMAgent, _get_litellm - - -logger = get_logger(__name__) - - -# --- Custom exceptions (kept for backwards compatibility) --- -class AgentConfigurationError(AdapterConfigurationError): - """ADK adapter configuration issues.""" - - pass - - -class AgentInteractionError(AdapterInteractionError): - """Errors interacting with the ADK agent server.""" - - pass - - -class ResponseParsingError(AdapterResponseParsingError): - """Errors parsing the ADK server's event-list response.""" - - pass - - -_ADK_PROVIDER_PREFIX = "hackagent_adk" - - -def _last_user_text(messages: List[Dict[str, Any]]) -> Optional[str]: - """Return the text of the last user message in ``messages``.""" - for msg in reversed(messages or []): - if (msg or {}).get("role") != "user": - continue - content = msg.get("content") - if isinstance(content, str): - return content - # OpenAI-style content lists. 
- if isinstance(content, list): - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - text = part.get("text") - if isinstance(text, str): - return text - return None - - -def _extract_final_text(events: List[Dict[str, Any]]) -> Optional[str]: - """Walk ``events`` newest-first and return the agent's final reply.""" - for event in reversed(events): - actions = event.get("actions") - if actions and isinstance(actions, dict) and actions.get("escalate"): - error_msg = event.get( - "error_message", - "No specific message provided by agent for escalation.", - ) - return f"Agent escalated: {error_msg}" - - content = event.get("content") - if not isinstance(content, dict): - continue - parts = content.get("parts") - if not isinstance(parts, list) or not parts: - continue - first = parts[0] - if not isinstance(first, dict): - continue - text = first.get("text") - if isinstance(text, str) and text.strip(): - return text - return None - - -_ADK_CUSTOM_LLM_CLASS = None - - -def _get_adk_custom_llm_class(): - """Lazily build the CustomLLM subclass once litellm is importable. - - Defined as a function instead of a module-level class so this module - keeps loading even when litellm is missing — the LiteLLMAgent base - will raise a clear error before anyone tries to actually use ADK. 
- """ - global _ADK_CUSTOM_LLM_CLASS - if _ADK_CUSTOM_LLM_CLASS is not None: - return _ADK_CUSTOM_LLM_CLASS - - from litellm import CustomLLM - from litellm.types.utils import ModelResponse - - class _ADKCustomLLM(CustomLLM): - """LiteLLM CustomLLM handler that proxies to an ADK server.""" - - def __init__( - self, - *, - endpoint: str, - app_name: str, - user_id: str, - default_session_id: str, - fresh_session_per_request: bool, - timeout: int, - log, - ): - super().__init__() - self.endpoint = endpoint.rstrip("/") - self.app_name = app_name - self.user_id = user_id - self.default_session_id = default_session_id - self.fresh_session_per_request = fresh_session_per_request - self.timeout = timeout - self.logger = log - - # ---- ADK transport (kept close to the previous implementation) --- - - def _create_session( - self, session_id: str, initial_state: Optional[dict] = None - ) -> None: - url = ( - f"{self.endpoint}/apps/{self.app_name}/users/" - f"{self.user_id}/sessions/{session_id}" - ) - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - } - payload = initial_state or {} - try: - response = requests.post(url, headers=headers, json=payload, timeout=30) - response.raise_for_status() - return - except requests.exceptions.HTTPError as http_err: - response_text = "" - status_code = None - if http_err.response is not None: - status_code = http_err.response.status_code - try: - response_text = http_err.response.text or "" - except Exception: - response_text = "" - if status_code == 409: - return - if ( - status_code == 400 - and "session already exists" in response_text.lower() - ): - return - raise AgentInteractionError( - f"HTTP Error {status_code} creating session " - f"{session_id}: {response_text[:200]}" - ) from http_err - except requests.exceptions.RequestException as e: - raise AgentInteractionError( - f"Request failed creating session {session_id}: {e}" - ) from e - - def _run(self, prompt_text: str, session_id: str) -> 
Dict[str, Any]: - url = f"{self.endpoint}/run" - headers = { - "Content-Type": "application/json", - "Accept": "application/json", - } - payload = { - "app_name": self.app_name, - "user_id": self.user_id, - "session_id": session_id, - "new_message": { - "role": "user", - "parts": [{"text": prompt_text}], - }, - } - - try: - response = requests.post( - url, headers=headers, json=payload, timeout=self.timeout - ) - except requests.exceptions.Timeout as e: - raise AgentInteractionError(f"Request timed out: {e}") from e - except requests.exceptions.RequestException as e: - raise AgentInteractionError(f"Request failed: {e}") from e - - response_body = response.text - try: - response.raise_for_status() - except requests.exceptions.HTTPError as http_err: - raise AgentInteractionError( - f"HTTP Error: {response.status_code}" - ) from http_err - - try: - events = response.json() - except (json.JSONDecodeError, ValueError) as parse_err: - raise ResponseParsingError( - f"JSON parse failed: {parse_err}. Body: {response_body[:200]}" - ) from parse_err - - if not isinstance(events, list): - if isinstance(events, dict) and "detail" in events: - raise ResponseParsingError( - f"ADK returned detail message: {events['detail']}" - ) - raise ResponseParsingError( - "ADK response format unrecognized (not a list)." 
- ) - - return { - "events": events, - "raw_request": payload, - "raw_response_body": response_body, - "raw_response_headers": dict(response.headers), - "status_code": response.status_code, - "final_text": _extract_final_text(events), - } - - # ---- LiteLLM CustomLLM API --------------------------------------- - - def completion(self, *args, **kwargs): - """Translate a LiteLLM completion call into an ADK /run request.""" - messages = kwargs.get("messages") or [] - optional_params = kwargs.get("optional_params") or {} - model_response: ModelResponse = ( - kwargs.get("model_response") or ModelResponse() - ) - - prompt_text = _last_user_text(messages) - if not prompt_text: - raise AgentInteractionError( - "ADK adapter requires at least one user message with text content." - ) - - session_id = optional_params.get("session_id") - if not session_id: - session_id = ( - str(uuid.uuid4()) - if self.fresh_session_per_request - else self.default_session_id - ) - initial_state = optional_params.get("initial_session_state") - - self.logger.info( - f"🌐 ADK run for app '{self.app_name}' (session {session_id})" - ) - self._create_session(session_id=session_id, initial_state=initial_state) - result = self._run(prompt_text=prompt_text, session_id=session_id) - - final_text = result["final_text"] or "" - model_response.choices[0].message.content = final_text # type: ignore[attr-defined] - try: - model_response.choices[0].finish_reason = "stop" # type: ignore[attr-defined] - except Exception: - pass - model_response.model = ( - kwargs.get("model") or f"{_ADK_PROVIDER_PREFIX}/{self.app_name}" - ) - - # Stash ADK-specific bits where the outer adapter can find them. 
- try: - model_response.choices[0].message.provider_specific_fields = { # type: ignore[attr-defined] - "adk_events_list": result["events"], - "adk_session_id": session_id, - "adk_raw_response_body": result["raw_response_body"], - "adk_raw_request": result["raw_request"], - "adk_status_code": result["status_code"], - } - except Exception: - pass - return model_response - - async def acompletion(self, *args, **kwargs): - """Async wrapper — run the sync ADK transport in a worker thread.""" - import asyncio - - return await asyncio.get_event_loop().run_in_executor( - None, lambda: self.completion(*args, **kwargs) - ) - - _ADK_CUSTOM_LLM_CLASS = _ADKCustomLLM - return _ADKCustomLLM - - -class ADKAgent(LiteLLMAgent): - """ - Adapter for a deployed Google ADK agent server. - - The request travels through LiteLLM via a per-instance - :class:`CustomLLM` handler registered as - ``hackagent_adk_/``. From the router's perspective this - is just another LiteLLM agent. - - Required config: - - ``name``: ADK app name (used as both the model string and the - ``app_name`` in the request payload). - - ``endpoint``: ADK server base URL. - - ``user_id``: User ID for ADK sessions. - - Optional config: - - ``timeout`` (seconds, default 120). - - ``session_id``: sticky session ID; if unset a UUID is generated. - - ``fresh_session_per_request`` (default True): if True, every - request gets a brand-new session unless the caller supplies one. - """ - - ADAPTER_TYPE = "ADKAgent" - - def __init__(self, id: str, config: Dict[str, Any]): - for key in ("name", "endpoint", "user_id"): - if key not in config: - raise AgentConfigurationError( - f"Missing required configuration key '{key}' for ADKAgent: {id}" - ) - - # Provider name is per-instance so each ADKAgent gets its own handler. - # Set on self before super().__init__ runs so that the base's call to - # _resolve_litellm_model (overridden below) sees the right value. 
- self._provider_name = f"{_ADK_PROVIDER_PREFIX}_{id}" - - adk_endpoint = str(config["endpoint"]).strip("/") - adk_user_id = config["user_id"] - adk_app_name = config["name"] - adk_timeout = int(config.get("timeout", 120)) - fresh = bool(config.get("fresh_session_per_request", True)) - session_id = config.get("session_id") or str(uuid.uuid4()) - - # The base passes ``endpoint`` along to LiteLLM as ``api_base``; we - # don't want that since our custom provider hits ADK directly. - base_config = {k: v for k, v in config.items() if k != "endpoint"} - super().__init__(id, base_config) - - self.endpoint = adk_endpoint - self.user_id = adk_user_id - self.name = adk_app_name - self.timeout = adk_timeout - self.fresh_session_per_request = fresh - self.session_id = session_id - - self._register_custom_provider() - - self.logger.info( - f"ADKAgent '{self.id}' registered as LiteLLM provider " - f"'{self._provider_name}' targeting {self.endpoint} " - f"(app={self.name}, session={self.session_id}, " - f"fresh_session_per_request={self.fresh_session_per_request})" - ) - - def _register_custom_provider(self) -> None: - litellm, available = _get_litellm() - if not available: - raise AgentConfigurationError( - "litellm is required for ADKAgent but is not installed." - ) - - handler_cls = _get_adk_custom_llm_class() - handler = handler_cls( - endpoint=self.endpoint, - app_name=self.name, - user_id=self.user_id, - default_session_id=self.session_id, - fresh_session_per_request=self.fresh_session_per_request, - timeout=self.timeout, - log=self.logger, - ) - - provider = self._provider_name - # Replace any stale entry for this provider name (e.g. when an - # ADKAgent with the same id is re-created during tests). 
- litellm.custom_provider_map = [ - entry - for entry in litellm.custom_provider_map - if entry.get("provider") != provider - ] - litellm.custom_provider_map.append( - {"provider": provider, "custom_handler": handler} - ) - if provider not in litellm._custom_providers: - litellm._custom_providers.append(provider) - - self._custom_handler = handler - - def _resolve_litellm_model(self, raw_model: str) -> str: - return f"{self._provider_name}/{raw_model}" - - # ---- forward ADK-specific request fields ---------------------------- - - def _get_completion_parameters( - self, request_data: Dict[str, Any] - ) -> Dict[str, Any]: - params = super()._get_completion_parameters(request_data) - session_id = request_data.get("session_id", request_data.get("adk_session_id")) - if session_id: - params["session_id"] = session_id - if "initial_session_state" in request_data: - params["initial_session_state"] = request_data["initial_session_state"] - return params - - def _get_excluded_request_keys(self) -> set: - base = super()._get_excluded_request_keys() - return base | {"session_id", "adk_session_id", "initial_session_state"} - - def _build_agent_specific_data( - self, - completion_result: Dict[str, Any], - parameters: Dict[str, Any], - ) -> Dict[str, Any]: - data = super()._build_agent_specific_data(completion_result, parameters) - raw = completion_result.get("raw_response") - adk_fields: Dict[str, Any] = {} - try: - adk_fields = ( - getattr(raw.choices[0].message, "provider_specific_fields", None) or {} - ) - except (AttributeError, IndexError, TypeError): - adk_fields = {} - events = adk_fields.get("adk_events_list") - if events is not None: - data["adk_events_list"] = events - if "adk_session_id" in adk_fields: - data["adk_session_id"] = adk_fields["adk_session_id"] - return data diff --git a/hackagent/router/providers/__init__.py b/hackagent/router/providers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/hackagent/router/providers/adk.py 
b/hackagent/router/providers/adk.py new file mode 100644 index 00000000..5459fd9d --- /dev/null +++ b/hackagent/router/providers/adk.py @@ -0,0 +1,444 @@ +# Copyright 2026 - AI4I. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Google ADK (Agent Development Kit) adapter built on top of LiteLLM. + +LiteLLM has no built-in provider for the ADK server protocol (POST /run +with sessions and events), so issue #379 routes ADK through LiteLLM by +registering a per-instance :class:`litellm.CustomLLM` handler under a +unique provider name. The HTTP transport against the deployed ADK server +lives in the lazily-defined ``_ADKCustomLLM`` class, while +:class:`ADKAgent` itself is a thin :class:`LiteLLMAgent` subclass that +registers the handler and asks LiteLLM to route through it. +""" + +import json +import uuid +from hackagent.logger import get_logger +from typing import Any, Dict, List, Optional + +import requests + +from hackagent.router.adapters.base import ( + AdapterConfigurationError, + AdapterInteractionError, + AdapterResponseParsingError, +) +from hackagent.router.adapters.litellm import LiteLLMAgent, _get_litellm + + +logger = get_logger(__name__) + + +# --- Custom exceptions (kept for backwards compatibility) --- +class AgentConfigurationError(AdapterConfigurationError): + """ADK adapter configuration issues.""" + + pass + + +class AgentInteractionError(AdapterInteractionError): + """Errors interacting with the ADK agent server.""" + + pass + + +class ResponseParsingError(AdapterResponseParsingError): + """Errors parsing the ADK server's event-list response.""" + + pass + + +_ADK_PROVIDER_PREFIX = "hackagent_adk" + + +def _last_user_text(messages: List[Dict[str, Any]]) -> Optional[str]: + """Return the text of the last user message in ``messages``.""" + for msg in reversed(messages or []): + if (msg or {}).get("role") != "user": + continue + content = msg.get("content") + if isinstance(content, str): + return content + # OpenAI-style content 
lists. + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + text = part.get("text") + if isinstance(text, str): + return text + return None + + +def _extract_final_text(events: List[Dict[str, Any]]) -> Optional[str]: + """Walk ``events`` newest-first and return the agent's final reply.""" + for event in reversed(events): + actions = event.get("actions") + if actions and isinstance(actions, dict) and actions.get("escalate"): + error_msg = event.get( + "error_message", + "No specific message provided by agent for escalation.", + ) + return f"Agent escalated: {error_msg}" + + content = event.get("content") + if not isinstance(content, dict): + continue + parts = content.get("parts") + if not isinstance(parts, list) or not parts: + continue + first = parts[0] + if not isinstance(first, dict): + continue + text = first.get("text") + if isinstance(text, str) and text.strip(): + return text + return None + + +_ADK_CUSTOM_LLM_CLASS = None + + +def _get_adk_custom_llm_class(): + """Lazily build the CustomLLM subclass once litellm is importable. + + Defined as a function instead of a module-level class so this module + keeps loading even when litellm is missing — the LiteLLMAgent base + will raise a clear error before anyone tries to actually use ADK. 
+ """ + global _ADK_CUSTOM_LLM_CLASS + if _ADK_CUSTOM_LLM_CLASS is not None: + return _ADK_CUSTOM_LLM_CLASS + + from litellm import CustomLLM + from litellm.types.utils import ModelResponse + + class _ADKCustomLLM(CustomLLM): + """LiteLLM CustomLLM handler that proxies to an ADK server.""" + + def __init__( + self, + *, + endpoint: str, + app_name: str, + user_id: str, + default_session_id: str, + fresh_session_per_request: bool, + timeout: int, + log, + ): + super().__init__() + self.endpoint = endpoint.rstrip("/") + self.app_name = app_name + self.user_id = user_id + self.default_session_id = default_session_id + self.fresh_session_per_request = fresh_session_per_request + self.timeout = timeout + self.logger = log + + # ---- ADK transport (kept close to the previous implementation) --- + + def _create_session( + self, session_id: str, initial_state: Optional[dict] = None + ) -> None: + url = ( + f"{self.endpoint}/apps/{self.app_name}/users/" + f"{self.user_id}/sessions/{session_id}" + ) + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + payload = initial_state or {} + try: + response = requests.post(url, headers=headers, json=payload, timeout=30) + response.raise_for_status() + return + except requests.exceptions.HTTPError as http_err: + response_text = "" + status_code = None + if http_err.response is not None: + status_code = http_err.response.status_code + try: + response_text = http_err.response.text or "" + except Exception: + response_text = "" + if status_code == 409: + return + if ( + status_code == 400 + and "session already exists" in response_text.lower() + ): + return + raise AgentInteractionError( + f"HTTP Error {status_code} creating session " + f"{session_id}: {response_text[:200]}" + ) from http_err + except requests.exceptions.RequestException as e: + raise AgentInteractionError( + f"Request failed creating session {session_id}: {e}" + ) from e + + def _run(self, prompt_text: str, session_id: str) -> 
Dict[str, Any]: + url = f"{self.endpoint}/run" + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + payload = { + "app_name": self.app_name, + "user_id": self.user_id, + "session_id": session_id, + "new_message": { + "role": "user", + "parts": [{"text": prompt_text}], + }, + } + + try: + response = requests.post( + url, headers=headers, json=payload, timeout=self.timeout + ) + except requests.exceptions.Timeout as e: + raise AgentInteractionError(f"Request timed out: {e}") from e + except requests.exceptions.RequestException as e: + raise AgentInteractionError(f"Request failed: {e}") from e + + response_body = response.text + try: + response.raise_for_status() + except requests.exceptions.HTTPError as http_err: + raise AgentInteractionError( + f"HTTP Error: {response.status_code}" + ) from http_err + + try: + events = response.json() + except (json.JSONDecodeError, ValueError) as parse_err: + raise ResponseParsingError( + f"JSON parse failed: {parse_err}. Body: {response_body[:200]}" + ) from parse_err + + if not isinstance(events, list): + if isinstance(events, dict) and "detail" in events: + raise ResponseParsingError( + f"ADK returned detail message: {events['detail']}" + ) + raise ResponseParsingError( + "ADK response format unrecognized (not a list)." 
+ ) + + return { + "events": events, + "raw_request": payload, + "raw_response_body": response_body, + "raw_response_headers": dict(response.headers), + "status_code": response.status_code, + "final_text": _extract_final_text(events), + } + + # ---- LiteLLM CustomLLM API --------------------------------------- + + def completion(self, *args, **kwargs): + """Translate a LiteLLM completion call into an ADK /run request.""" + messages = kwargs.get("messages") or [] + optional_params = kwargs.get("optional_params") or {} + model_response: ModelResponse = ( + kwargs.get("model_response") or ModelResponse() + ) + + prompt_text = _last_user_text(messages) + if not prompt_text: + raise AgentInteractionError( + "ADK adapter requires at least one user message with text content." + ) + + session_id = optional_params.get("session_id") + if not session_id: + session_id = ( + str(uuid.uuid4()) + if self.fresh_session_per_request + else self.default_session_id + ) + initial_state = optional_params.get("initial_session_state") + + self.logger.info( + f"🌐 ADK run for app '{self.app_name}' (session {session_id})" + ) + self._create_session(session_id=session_id, initial_state=initial_state) + result = self._run(prompt_text=prompt_text, session_id=session_id) + + final_text = result["final_text"] or "" + model_response.choices[0].message.content = final_text # type: ignore[attr-defined] + try: + model_response.choices[0].finish_reason = "stop" # type: ignore[attr-defined] + except Exception: + pass + model_response.model = ( + kwargs.get("model") or f"{_ADK_PROVIDER_PREFIX}/{self.app_name}" + ) + + # Stash ADK-specific bits where the outer adapter can find them. 
+ try: + model_response.choices[0].message.provider_specific_fields = { # type: ignore[attr-defined] + "adk_events_list": result["events"], + "adk_session_id": session_id, + "adk_raw_response_body": result["raw_response_body"], + "adk_raw_request": result["raw_request"], + "adk_status_code": result["status_code"], + } + except Exception: + pass + return model_response + + async def acompletion(self, *args, **kwargs): + """Async wrapper — run the sync ADK transport in a worker thread.""" + import asyncio + + return await asyncio.get_event_loop().run_in_executor( + None, lambda: self.completion(*args, **kwargs) + ) + + _ADK_CUSTOM_LLM_CLASS = _ADKCustomLLM + return _ADKCustomLLM + + +class ADKAgent(LiteLLMAgent): + """ + Adapter for a deployed Google ADK agent server. + + The request travels through LiteLLM via a per-instance + :class:`CustomLLM` handler registered as + ``hackagent_adk_<id>/<name>``. From the router's perspective this + is just another LiteLLM agent. + + Required config: + - ``name``: ADK app name (used as both the model string and the + ``app_name`` in the request payload). + - ``endpoint``: ADK server base URL. + - ``user_id``: User ID for ADK sessions. + + Optional config: + - ``timeout`` (seconds, default 120). + - ``session_id``: sticky session ID; if unset a UUID is generated. + - ``fresh_session_per_request`` (default True): if True, every + request gets a brand-new session unless the caller supplies one. + """ + + ADAPTER_TYPE = "ADKAgent" + + def __init__(self, id: str, config: Dict[str, Any]): + for key in ("name", "endpoint", "user_id"): + if key not in config: + raise AgentConfigurationError( + f"Missing required configuration key '{key}' for ADKAgent: {id}" + ) + + # Provider name is per-instance so each ADKAgent gets its own handler. + # Set on self before super().__init__ runs so that the base's call to + # _resolve_litellm_model (overridden below) sees the right value. 
+ self._provider_name = f"{_ADK_PROVIDER_PREFIX}_{id}" + + adk_endpoint = str(config["endpoint"]).strip("/") + adk_user_id = config["user_id"] + adk_app_name = config["name"] + adk_timeout = int(config.get("timeout", 120)) + fresh = bool(config.get("fresh_session_per_request", True)) + session_id = config.get("session_id") or str(uuid.uuid4()) + + # The base passes ``endpoint`` along to LiteLLM as ``api_base``; we + # don't want that since our custom provider hits ADK directly. + base_config = {k: v for k, v in config.items() if k != "endpoint"} + super().__init__(id, base_config) + + self.endpoint = adk_endpoint + self.user_id = adk_user_id + self.name = adk_app_name + self.timeout = adk_timeout + self.fresh_session_per_request = fresh + self.session_id = session_id + + self._register_custom_provider() + + self.logger.info( + f"ADKAgent '{self.id}' registered as LiteLLM provider " + f"'{self._provider_name}' targeting {self.endpoint} " + f"(app={self.name}, session={self.session_id}, " + f"fresh_session_per_request={self.fresh_session_per_request})" + ) + + def _register_custom_provider(self) -> None: + litellm, available = _get_litellm() + if not available: + raise AgentConfigurationError( + "litellm is required for ADKAgent but is not installed." + ) + + handler_cls = _get_adk_custom_llm_class() + handler = handler_cls( + endpoint=self.endpoint, + app_name=self.name, + user_id=self.user_id, + default_session_id=self.session_id, + fresh_session_per_request=self.fresh_session_per_request, + timeout=self.timeout, + log=self.logger, + ) + + provider = self._provider_name + # Replace any stale entry for this provider name (e.g. when an + # ADKAgent with the same id is re-created during tests). 
+ litellm.custom_provider_map = [ + entry + for entry in litellm.custom_provider_map + if entry.get("provider") != provider + ] + litellm.custom_provider_map.append( + {"provider": provider, "custom_handler": handler} + ) + if provider not in litellm._custom_providers: + litellm._custom_providers.append(provider) + + self._custom_handler = handler + + def _resolve_litellm_model(self, raw_model: str) -> str: + return f"{self._provider_name}/{raw_model}" + + # ---- forward ADK-specific request fields ---------------------------- + + def _get_completion_parameters( + self, request_data: Dict[str, Any] + ) -> Dict[str, Any]: + params = super()._get_completion_parameters(request_data) + session_id = request_data.get("session_id", request_data.get("adk_session_id")) + if session_id: + params["session_id"] = session_id + if "initial_session_state" in request_data: + params["initial_session_state"] = request_data["initial_session_state"] + return params + + def _get_excluded_request_keys(self) -> set: + base = super()._get_excluded_request_keys() + return base | {"session_id", "adk_session_id", "initial_session_state"} + + def _build_agent_specific_data( + self, + completion_result: Dict[str, Any], + parameters: Dict[str, Any], + ) -> Dict[str, Any]: + data = super()._build_agent_specific_data(completion_result, parameters) + raw = completion_result.get("raw_response") + adk_fields: Dict[str, Any] = {} + try: + adk_fields = ( + getattr(raw.choices[0].message, "provider_specific_fields", None) or {} + ) + except (AttributeError, IndexError, TypeError): + adk_fields = {} + events = adk_fields.get("adk_events_list") + if events is not None: + data["adk_events_list"] = events + if "adk_session_id" in adk_fields: + data["adk_session_id"] = adk_fields["adk_session_id"] + return data diff --git a/tests/unit/adapters/test_google_adk.py b/tests/unit/adapters/test_google_adk.py index ad9f7837..86f75a7a 100644 --- a/tests/unit/adapters/test_google_adk.py +++ 
b/tests/unit/adapters/test_google_adk.py @@ -25,6 +25,7 @@ _extract_final_text, _last_user_text, ) +from hackagent.router.providers import adk as adk_provider_module logging.disable(logging.CRITICAL) @@ -45,6 +46,20 @@ def _make_handler(**overrides): return handler_cls(**defaults) +class TestADKModuleLayout(unittest.TestCase): + """Phase E — ADK lives at router/providers/adk.py with a back-compat shim.""" + + def test_adk_classes_resolve_to_same_object_through_both_paths(self): + self.assertIs(ADKAgent, adk_provider_module.ADKAgent) + self.assertIs( + AgentConfigurationError, adk_provider_module.AgentConfigurationError + ) + + def test_provider_module_exposes_helpers(self): + self.assertIs(_extract_final_text, adk_provider_module._extract_final_text) + self.assertIs(_last_user_text, adk_provider_module._last_user_text) + + class TestADKHelpers(unittest.TestCase): def test_last_user_text_returns_last_user_string(self): messages = [ From 0b0a2e50a4025cf384126f172e151268d3fe3049 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 17:08:22 +0200 Subject: [PATCH 07/23] feat(router): surface response_cost + call_id, unify status_code (#379 Phase F.1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase F.1 of the LiteLLM router refactor. - ``hackagent/router/envelope.py`` gains ``extract_response_cost`` and ``extract_litellm_call_id`` helpers that pull LiteLLM's ``_hidden_params['response_cost']`` and ``_hidden_params['litellm_call_id']`` (falling back to ``response.id``). Both fields, when present, flow through ``build_agent_specific_data`` into the envelope so downstream traces can correlate input ↔ output ↔ spend without rooting in the raw response object. - ``AgentRouter._build_error_response`` now sets ``status_code`` (the canonical field used by the chat-dispatch envelope) alongside the legacy ``raw_response_status`` alias, eliminating the inconsistency between the two error paths. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- hackagent/router/envelope.py | 47 +++++++++++++++++++++++++++++- hackagent/router/router.py | 13 +++++++++ tests/unit/router/test_dispatch.py | 37 +++++++++++++++++++++-- tests/unit/router/test_envelope.py | 29 ++++++++++++++++++ 4 files changed, 122 insertions(+), 4 deletions(-) diff --git a/hackagent/router/envelope.py b/hackagent/router/envelope.py index f59e3bf1..6ff89a9c 100644 --- a/hackagent/router/envelope.py +++ b/hackagent/router/envelope.py @@ -288,7 +288,18 @@ def build_agent_specific_data( "invoked_parameters": invoked_parameters, } if completion_result: - for key in ("usage", "finish_reason", "provider_model", "raw_response"): + # ``response_cost`` and ``litellm_call_id`` are free metrics that + # LiteLLM attaches to every successful call — surface them so + # downstream traces can use them without rooting in the raw + # response object. + for key in ( + "usage", + "finish_reason", + "provider_model", + "raw_response", + "response_cost", + "litellm_call_id", + ): value = completion_result.get(key) if value is not None: data[key] = value @@ -297,3 +308,37 @@ def build_agent_specific_data( if extra: data.update(extra) return data + + +# ---- LiteLLM response metadata extraction -------------------------------- + + +def extract_response_cost(response: Any) -> Optional[float]: + """Pull ``response_cost`` off a LiteLLM ``ModelResponse`` if present. + + LiteLLM exposes the per-call cost (when the model is in its pricing + catalogue) via the ``_hidden_params`` attribute. Returns ``None`` + when unavailable rather than raising, since cost tracking is + best-effort. 
+ """ + hidden = getattr(response, "_hidden_params", None) or {} + cost = hidden.get("response_cost") if isinstance(hidden, dict) else None + try: + return float(cost) if cost is not None else None + except (TypeError, ValueError): + return None + + +def extract_litellm_call_id(response: Any) -> Optional[str]: + """Pull ``litellm_call_id`` (or ``x-litellm-call-id``) off a response.""" + hidden = getattr(response, "_hidden_params", None) or {} + if isinstance(hidden, dict): + for key in ("litellm_call_id", "x-litellm-call-id"): + value = hidden.get(key) + if value: + return str(value) + # LiteLLM also sets ``response.id`` to a unique value per call. + response_id = getattr(response, "id", None) + if response_id: + return str(response_id) + return None diff --git a/hackagent/router/router.py b/hackagent/router/router.py index 1f83523c..03b3b9a5 100644 --- a/hackagent/router/router.py +++ b/hackagent/router/router.py @@ -300,6 +300,10 @@ def _build_error_response( "raw_request": raw_request, "processed_response": None, "generated_text": None, + # Phase F.1 — ``status_code`` is the canonical field used by + # the new chat-dispatch envelope; ``raw_response_status`` is + # kept as an alias for legacy callers that read it. + "status_code": status_code, "raw_response_status": status_code, "raw_response_headers": None, "raw_response_body": None, @@ -625,6 +629,15 @@ def _dispatch_via_litellm( completion_result["provider_model"] = response.model except AttributeError: pass + # Phase F.1 — surface LiteLLM's response_cost and call_id so + # downstream traces can join input ↔ output ↔ spend without + # poking at private attributes on the raw response object. 
+ response_cost = _envelope.extract_response_cost(response) + if response_cost is not None: + completion_result["response_cost"] = response_cost + call_id = _envelope.extract_litellm_call_id(response) + if call_id is not None: + completion_result["litellm_call_id"] = call_id agent_specific_data = _envelope.build_agent_specific_data( model_name=model_name, diff --git a/tests/unit/router/test_dispatch.py b/tests/unit/router/test_dispatch.py index 4bf50772..f9b3cce7 100644 --- a/tests/unit/router/test_dispatch.py +++ b/tests/unit/router/test_dispatch.py @@ -192,12 +192,43 @@ def test_caller_supplied_metadata_is_merged_and_wins(self, mock_completion): def test_unknown_registration_key_returns_404_envelope(self): router, _ = self._make_router_for_openai() response = router.route_request("nonexistent-key", {"prompt": "hi"}) - # The legacy router-level AgentNotFound envelope uses - # ``raw_response_status`` rather than ``status_code``. Phase F may - # unify these; we just check the actual current behaviour here. + # Phase F.1 unified the field name; legacy ``raw_response_status`` + # stays as a back-compat alias. 
+ self.assertEqual(response["status_code"], 404) self.assertEqual(response["raw_response_status"], 404) self.assertIn("Agent not found", response["error_message"]) + @patch("litellm.completion") + def test_response_cost_and_call_id_surface_in_envelope(self, mock_completion): + """Phase F.1 — LiteLLM's response_cost / call_id show up in the envelope.""" + response = _make_litellm_response("ok") + response._hidden_params = { + "response_cost": 0.000123, + "litellm_call_id": "call-abc", + } + mock_completion.return_value = response + + router, reg_key = self._make_router_for_openai() + env = router.route_request(reg_key, {"prompt": "hi"}) + + self.assertEqual(env["status_code"], 200) + agent_data = env["agent_specific_data"] + self.assertAlmostEqual(agent_data["response_cost"], 0.000123) + self.assertEqual(agent_data["litellm_call_id"], "call-abc") + + @patch("litellm.completion") + def test_response_cost_absent_when_not_in_hidden_params(self, mock_completion): + """No ``response_cost`` from LiteLLM → envelope omits the field.""" + response = _make_litellm_response("ok") + # Don't set _hidden_params; cost should just be missing. 
+ if hasattr(response, "_hidden_params"): + del response._hidden_params + mock_completion.return_value = response + + router, reg_key = self._make_router_for_openai() + env = router.route_request(reg_key, {"prompt": "hi"}) + self.assertNotIn("response_cost", env["agent_specific_data"]) + class TestDispatchADKBypassesLiteLLM(unittest.TestCase): """Verify ADK requests still flow through the adapter's handle_request.""" diff --git a/tests/unit/router/test_envelope.py b/tests/unit/router/test_envelope.py index 03c90bb6..d8823598 100644 --- a/tests/unit/router/test_envelope.py +++ b/tests/unit/router/test_envelope.py @@ -220,6 +220,35 @@ def test_error_uses_supplied_status(self): self.assertEqual(env["status_code"], 400) +class TestExtractResponseCostAndCallId(unittest.TestCase): + def test_response_cost_pulled_from_hidden_params(self): + response = MagicMock() + response._hidden_params = {"response_cost": 0.0005} + self.assertAlmostEqual(envelope.extract_response_cost(response), 0.0005) + + def test_response_cost_returns_none_when_missing(self): + response = MagicMock() + response._hidden_params = {} + self.assertIsNone(envelope.extract_response_cost(response)) + + def test_response_cost_handles_non_numeric_gracefully(self): + response = MagicMock() + response._hidden_params = {"response_cost": "n/a"} + self.assertIsNone(envelope.extract_response_cost(response)) + + def test_call_id_prefers_hidden_params_over_response_id(self): + response = MagicMock() + response._hidden_params = {"litellm_call_id": "hidden-id"} + response.id = "id-field" + self.assertEqual(envelope.extract_litellm_call_id(response), "hidden-id") + + def test_call_id_falls_back_to_response_id(self): + response = MagicMock() + response._hidden_params = {} + response.id = "id-field" + self.assertEqual(envelope.extract_litellm_call_id(response), "id-field") + + class TestBuildAgentSpecificData(unittest.TestCase): def test_merges_completion_metadata(self): data = envelope.build_agent_specific_data( From 
c3090e907e87f2ae4bf070721b70771e981d9607 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 17:14:15 +0200 Subject: [PATCH 08/23] refactor(router): refactor ADKAgent off LiteLLMAgent (#379 Phase E.2a) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase E.2a of the LiteLLM router refactor. ``ADKAgent`` no longer inherits from ``LiteLLMAgent``; it extends :class:`Agent` directly and implements ``handle_request`` itself, calling ``litellm.completion(model="hackagent_adk_<id>/<name>", messages=…)`` which still routes through the per-instance ``_ADKCustomLLM``. The lazy ``_get_litellm`` helper is now defined locally in this module so ADK doesn't depend on ``hackagent.router.adapters.litellm``, which Phase E.2c is about to delete. All ADK public attributes (``litellm_model``, ``name``, ``endpoint``, ``user_id``, ``timeout``, ``session_id``, ``fresh_session_per_request``, ``default_*``) are preserved so the router's dispatch path and external code that pokes at the adapter keep working unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- hackagent/router/providers/adk.py | 210 +++++++++++++++++++++--------- 1 file changed, 149 insertions(+), 61 deletions(-) diff --git a/hackagent/router/providers/adk.py b/hackagent/router/providers/adk.py index 5459fd9d..40bf9e93 100644 --- a/hackagent/router/providers/adk.py +++ b/hackagent/router/providers/adk.py @@ -2,15 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 """ -Google ADK (Agent Development Kit) adapter built on top of LiteLLM. +Google ADK (Agent Development Kit) provider built on top of LiteLLM. LiteLLM has no built-in provider for the ADK server protocol (POST /run with sessions and events), so issue #379 routes ADK through LiteLLM by registering a per-instance :class:`litellm.CustomLLM` handler under a unique provider name.
The HTTP transport against the deployed ADK server lives in the lazily-defined ``_ADKCustomLLM`` class, while -:class:`ADKAgent` itself is a thin :class:`LiteLLMAgent` subclass that -registers the handler and asks LiteLLM to route through it. +:class:`ADKAgent` registers the handler and dispatches requests via +``litellm.completion``. Since Phase E.2a, :class:`ADKAgent` extends +:class:`Agent` directly (not :class:`LiteLLMAgent`) so the chat-adapter +classes can be deleted in Phase E.2c without affecting ADK. """ import json @@ -20,12 +22,33 @@ import requests +from hackagent.router import envelope as _envelope from hackagent.router.adapters.base import ( + Agent, AdapterConfigurationError, AdapterInteractionError, AdapterResponseParsingError, ) -from hackagent.router.adapters.litellm import LiteLLMAgent, _get_litellm + + +# Local copy of the LiteLLM lazy importer so ADK no longer depends on +# ``hackagent.router.adapters.litellm`` (which is being deleted in +# Phase E.2c). +_litellm_module = None + + +def _get_litellm(): + """Lazily import litellm. Returns ``(module, is_available)``.""" + global _litellm_module + if _litellm_module is not None: + return _litellm_module, True + try: + import litellm + + _litellm_module = litellm + return litellm, True + except ImportError: + return None, False logger = get_logger(__name__) @@ -304,14 +327,15 @@ async def acompletion(self, *args, **kwargs): return _ADKCustomLLM -class ADKAgent(LiteLLMAgent): +class ADKAgent(Agent): """ Adapter for a deployed Google ADK agent server. - The request travels through LiteLLM via a per-instance - :class:`CustomLLM` handler registered as - ``hackagent_adk_/``. From the router's perspective this - is just another LiteLLM agent. 
+ Each instance registers its own :class:`litellm.CustomLLM` handler + under a unique provider name (``hackagent_adk_``) so the call + goes through ``litellm.completion`` like every other LiteLLM + provider — even though LiteLLM has no built-in knowledge of the + ADK ``POST /run`` + sessions + events protocol. Required config: - ``name``: ADK app name (used as both the model string and the @@ -335,29 +359,32 @@ def __init__(self, id: str, config: Dict[str, Any]): f"Missing required configuration key '{key}' for ADKAgent: {id}" ) - # Provider name is per-instance so each ADKAgent gets its own handler. - # Set on self before super().__init__ runs so that the base's call to - # _resolve_litellm_model (overridden below) sees the right value. - self._provider_name = f"{_ADK_PROVIDER_PREFIX}_{id}" + super().__init__(id, config) + self._init_generation_params() - adk_endpoint = str(config["endpoint"]).strip("/") - adk_user_id = config["user_id"] - adk_app_name = config["name"] - adk_timeout = int(config.get("timeout", 120)) - fresh = bool(config.get("fresh_session_per_request", True)) - session_id = config.get("session_id") or str(uuid.uuid4()) - - # The base passes ``endpoint`` along to LiteLLM as ``api_base``; we - # don't want that since our custom provider hits ADK directly. 
- base_config = {k: v for k, v in config.items() if k != "endpoint"} - super().__init__(id, base_config) - - self.endpoint = adk_endpoint - self.user_id = adk_user_id - self.name = adk_app_name - self.timeout = adk_timeout - self.fresh_session_per_request = fresh - self.session_id = session_id + self.name: str = config["name"] + self.model_name = self.name # for the base ``Agent`` envelope helpers + self.endpoint: str = str(config["endpoint"]).strip("/") + self.user_id: str = config["user_id"] + self.timeout: int = int(config.get("timeout", 120)) + self.fresh_session_per_request: bool = bool( + config.get("fresh_session_per_request", True) + ) + self.session_id: str = config.get("session_id") or str(uuid.uuid4()) + + # Per-instance LiteLLM provider name + the model string the + # router will call ``litellm.completion(model=...)`` with. + self._provider_name = f"{_ADK_PROVIDER_PREFIX}_{id}" + self.litellm_model = f"{self._provider_name}/{self.name}" + # Kept for backwards compatibility with code that read these off + # the legacy ``LiteLLMAgent`` base; ADK has no API base/key of + # its own (the custom provider talks to the ADK server itself). + self.api_base_url: Optional[str] = None + self.actual_api_key: Optional[str] = None + self.default_thinking = None + self.default_tools = None + self.default_tool_choice = None + self.default_extra_body = None self._register_custom_provider() @@ -402,43 +429,104 @@ def _register_custom_provider(self) -> None: self._custom_handler = handler - def _resolve_litellm_model(self, raw_model: str) -> str: - return f"{self._provider_name}/{raw_model}" - - # ---- forward ADK-specific request fields ---------------------------- + # ---- request handling ---------------------------------------------- + + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Send a single ADK turn via ``litellm.completion``. 
+ + Implemented directly on :class:`ADKAgent` so the class no longer + depends on ``LiteLLMAgent`` (which Phase E.2c deletes). The + request flow is the same as before: + + request_data → litellm.completion(model="hackagent_adk_/", + messages=…, session_id=…) + → _ADKCustomLLM.completion → ADK ``/run`` + """ + is_valid, prompt_text, messages = self._validate_request(request_data) + if not is_valid: + return self._build_error_response( + error_message=( + "Request data must include either 'messages' or 'prompt' field." + ), + status_code=400, + raw_request=request_data, + ) + if not messages: + messages = self._prompt_to_messages(prompt_text) # type: ignore[arg-type] - def _get_completion_parameters( - self, request_data: Dict[str, Any] - ) -> Dict[str, Any]: - params = super()._get_completion_parameters(request_data) + # ADK-specific knobs that the custom handler reads out of + # ``optional_params``. ``adk_session_id`` is a legacy alias. session_id = request_data.get("session_id", request_data.get("adk_session_id")) + initial_session_state = request_data.get("initial_session_state") + + litellm, available = _get_litellm() + if not available: + return self._build_error_response( + error_message="litellm is not installed", + status_code=500, + raw_request=request_data, + ) + + kwargs: Dict[str, Any] = { + "model": self.litellm_model, + "messages": messages, + } if session_id: - params["session_id"] = session_id - if "initial_session_state" in request_data: - params["initial_session_state"] = request_data["initial_session_state"] - return params - - def _get_excluded_request_keys(self) -> set: - base = super()._get_excluded_request_keys() - return base | {"session_id", "adk_session_id", "initial_session_state"} - - def _build_agent_specific_data( - self, - completion_result: Dict[str, Any], - parameters: Dict[str, Any], - ) -> Dict[str, Any]: - data = super()._build_agent_specific_data(completion_result, parameters) - raw = completion_result.get("raw_response") + 
kwargs["session_id"] = session_id + if initial_session_state is not None: + kwargs["initial_session_state"] = initial_session_state + + try: + response = litellm.completion(**kwargs) + except Exception as exc: + self.logger.exception( + f"ADK litellm dispatch failed for agent {self.id}: {exc}" + ) + return self._build_error_response( + error_message=( + f"{self.ADAPTER_TYPE} error ({type(exc).__name__}): {exc}" + ), + status_code=500, + raw_request=request_data, + ) + + text = _envelope.extract_text_from_response( + response, model_name=self.litellm_model + ) + if isinstance(text, str) and text.startswith("[GENERATION_ERROR:"): + return self._build_error_response( + error_message=f"{self.ADAPTER_TYPE} generation error: {text}", + status_code=500, + raw_request=request_data, + ) + + # The custom handler stashes ADK events/session_id on + # ``provider_specific_fields`` — pull them back out for the + # envelope. adk_fields: Dict[str, Any] = {} try: adk_fields = ( - getattr(raw.choices[0].message, "provider_specific_fields", None) or {} + getattr(response.choices[0].message, "provider_specific_fields", None) + or {} ) except (AttributeError, IndexError, TypeError): adk_fields = {} - events = adk_fields.get("adk_events_list") - if events is not None: - data["adk_events_list"] = events + + invoked_parameters: Dict[str, Any] = {} + if session_id: + invoked_parameters["session_id"] = session_id + agent_specific_data = _envelope.build_agent_specific_data( + model_name=self.litellm_model, + invoked_parameters=invoked_parameters, + ) + if adk_fields.get("adk_events_list") is not None: + agent_specific_data["adk_events_list"] = adk_fields["adk_events_list"] if "adk_session_id" in adk_fields: - data["adk_session_id"] = adk_fields["adk_session_id"] - return data + agent_specific_data["adk_session_id"] = adk_fields["adk_session_id"] + + return self._build_success_response( + processed_response=text, + raw_request=request_data, + raw_response_body=response, + 
agent_specific_data=agent_specific_data, + ) From 5b7b9aff8345fe4815604642a032b1d801d43374 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 17:26:11 +0200 Subject: [PATCH 09/23] refactor(router): chat AgentTypes use _ChatRegistration (#379 Phase E.2b) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase E.2b of the LiteLLM router refactor. ``AgentRouter`` no longer instantiates ``LiteLLMAgent`` / ``OpenAIAgent`` / ``OllamaAgent`` for chat-completion AgentTypes; it builds a lightweight ``_ChatRegistration`` config holder instead. The dispatch path is unchanged — it reads the same attribute names (``litellm_model``, ``api_base_url``, ``actual_api_key``, ``default_*``…) off whichever object is stored in ``_agent_registry``. ``_ChatRegistration`` covers the three adapter-class quirks the dispatch path didn't yet own: - OpenAI custom endpoint without API key → placeholder ``"not-required"``. - OpenAI custom endpoint without model name → defaults to ``"default"``. - Ollama default endpoint resolution + trailing ``/api/*`` stripping. ADK still uses ``ADKAgent`` because its CustomLLM registration with LiteLLM is a per-instance side-effect. The chat adapter classes (``LiteLLMAgent`` / ``OpenAIAgent`` / ``OllamaAgent``) remain importable but are no longer on any hot path or instantiated by the router. Integration tests updated to assert against the ``_ChatRegistration`` shape; Phase E.2c will delete the adapter classes and the obsolete unit tests.
Co-Authored-By: Claude Opus 4.7 (1M context) --- hackagent/router/_chat_registration.py | 183 ++++++++++++++++++++ hackagent/router/router.py | 42 +++-- tests/integration/adapters/test_litellm.py | 16 +- tests/integration/adapters/test_openai.py | 16 +- tests/unit/router/test_chat_registration.py | 116 +++++++++++++ tests/unit/router/test_router.py | 103 +++++------ 6 files changed, 388 insertions(+), 88 deletions(-) create mode 100644 hackagent/router/_chat_registration.py create mode 100644 tests/unit/router/test_chat_registration.py diff --git a/hackagent/router/_chat_registration.py b/hackagent/router/_chat_registration.py new file mode 100644 index 00000000..af3c3d5e --- /dev/null +++ b/hackagent/router/_chat_registration.py @@ -0,0 +1,183 @@ +# Copyright 2026 - AI4I. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Lightweight per-registration config used by ``AgentRouter`` for +chat-completion AgentTypes. + +Phase E.2 of the LiteLLM router refactor (issue #379) replaces the +``LiteLLMAgent`` / ``OpenAIAgent`` / ``OllamaAgent`` adapter instances +in ``AgentRouter._agent_registry`` with instances of this class. The +router's ``_dispatch_via_litellm`` reads the same attributes off either +object (``litellm_model``, ``api_base_url``, ``actual_api_key``, +``default_*``…), so consumers that mutate ``adapter.default_max_tokens`` +or similar keep working. + +``ADKAgent`` is unaffected — it stays as an :class:`Agent` subclass +because its custom-LLM registration with LiteLLM is per-instance and +needs construction-time side effects. 
+""" + +from __future__ import annotations + +import os +from typing import Any, Dict, Optional + +from hackagent.logger import get_logger +from hackagent.router import envelope as _envelope +from hackagent.router.provider_config import ProviderConfig +from hackagent.router.types import AgentTypeEnum + +logger = get_logger(__name__) + + +# ---- per-AgentType config normalisation --------------------------------- +# These helpers cover the small adapter-class quirks that used to live +# in ``OpenAIAgent.__init__`` and ``OllamaAgent.__init__``. + +_OLLAMA_DEFAULT_ENDPOINT = "http://localhost:11434" + + +def _normalise_ollama_endpoint(endpoint: Optional[str]) -> str: + """Resolve & normalise the Ollama endpoint URL the way OllamaAgent did.""" + resolved = endpoint or os.environ.get("OLLAMA_BASE_URL", _OLLAMA_DEFAULT_ENDPOINT) + resolved = resolved.rstrip("/") + for suffix in ("/api/generate", "/api/chat", "/api/tags", "/api/show", "/api"): + if resolved.endswith(suffix): + resolved = resolved[: -len(suffix)] + break + return resolved + + +def _resolve_api_key_from_config( + config: Dict[str, Any], env_var_fallback: Optional[str] +) -> Optional[str]: + """Mirror the API-key resolution path in ``Agent._resolve_api_key``.""" + api_key_config = config.get("api_key") + if api_key_config: + # The config value may itself be an env-var name. + env_val = os.environ.get(api_key_config) + if env_val: + return env_val + return api_key_config + if env_var_fallback: + env_val = os.environ.get(env_var_fallback) + if env_val: + return env_val + return None + + +def _default_api_key_env_var( + litellm_model: str, api_base_url: Optional[str] +) -> Optional[str]: + if api_base_url: + return None + if litellm_model.startswith(("openai/", "gpt-")): + return "OPENAI_API_KEY" + if litellm_model.startswith(("anthropic/", "claude-")): + return "ANTHROPIC_API_KEY" + return None + + +class _ChatRegistration: + """Mutable config holder consumed by ``AgentRouter._dispatch_via_litellm``. 
+ + Exposes exactly the attributes the dispatch path and external code + used to read off ``LiteLLMAgent`` / ``OpenAIAgent`` / ``OllamaAgent`` + instances: ``id``, ``ADAPTER_TYPE``, ``model_name``, + ``litellm_model``, ``api_base_url``, ``actual_api_key``, + ``default_max_tokens``, ``default_temperature``, ``default_top_p``, + ``default_thinking``, ``default_tools``, ``default_tool_choice``, + ``default_extra_body``, and any Ollama-specific extras + (``default_top_k``, ``default_num_ctx``, ``default_stream``). + """ + + DEFAULT_MAX_TOKENS: int = 100 + DEFAULT_TEMPERATURE: float = 0.8 + DEFAULT_TOP_P: float = 0.95 + + def __init__( + self, + *, + id: str, + agent_type: AgentTypeEnum, + provider_config: ProviderConfig, + config: Dict[str, Any], + ): + self.id = id + self.agent_type = agent_type + self.config: Dict[str, Any] = dict(config) + self.ADAPTER_TYPE: str = provider_config.adapter_label + + # ---- model + endpoint ---- + # OpenAI custom-endpoint quirk: if endpoint is set but no model + # name, default to ``"default"`` so the server decides. + if "name" not in self.config: + if agent_type == AgentTypeEnum.OPENAI_SDK and self.config.get("endpoint"): + self.model_name = "default" + else: + raise ValueError( + f"Missing required configuration key 'name' for " + f"{provider_config.adapter_label}: {id}" + ) + else: + self.model_name = self.config["name"] + + # Ollama special-cases the endpoint default + normalisation. 
+ if agent_type == AgentTypeEnum.OLLAMA: + self.api_base_url: Optional[str] = _normalise_ollama_endpoint( + self.config.get("endpoint") + ) + else: + self.api_base_url = self.config.get("endpoint") or None + + self.litellm_model = _envelope.resolve_litellm_model( + self.model_name, provider_prefix=provider_config.provider_prefix + ) + + # ---- API key resolution ---- + env_var_fallback = _default_api_key_env_var( + self.litellm_model, self.api_base_url + ) + self.actual_api_key: Optional[str] = _resolve_api_key_from_config( + self.config, env_var_fallback + ) + # OpenAI custom-endpoint quirk: when no key is configured but an + # endpoint is, use a placeholder so the OpenAI client (under + # LiteLLM) doesn't choke. + if ( + agent_type == AgentTypeEnum.OPENAI_SDK + and not self.actual_api_key + and self.api_base_url + ): + self.actual_api_key = "not-required" + + # ---- generation defaults ---- + self.default_max_tokens: int = self.config.get( + "max_tokens", self.DEFAULT_MAX_TOKENS + ) + # OpenAI's default temperature historically was 1.0; everyone else is 0.8. + self.default_temperature: float = self.config.get( + "temperature", + 1.0 if agent_type == AgentTypeEnum.OPENAI_SDK else self.DEFAULT_TEMPERATURE, + ) + self.default_top_p: float = self.config.get("top_p", self.DEFAULT_TOP_P) + self.default_thinking = self.config.get("thinking") + self.default_tools = self.config.get("tools") + self.default_tool_choice = self.config.get("tool_choice") + self.default_extra_body = self.config.get("extra_body") + + # Ollama extras — also tolerated for the other AgentTypes but only + # used by LiteLLM for ``ollama_chat/``. 
+ self.default_top_k = self.config.get("top_k") + self.default_num_ctx = self.config.get("num_ctx") + self.default_stream = self.config.get("stream", False) + + logger.info( + f"{self.ADAPTER_TYPE} '{self.id}' registered for LiteLLM model: " + f"'{self.litellm_model}'" + + (f" API Base: '{self.api_base_url}'" if self.api_base_url else "") + ) + + def get_identifier(self) -> str: + return self.id diff --git a/hackagent/router/router.py b/hackagent/router/router.py index 03b3b9a5..0d295f8a 100644 --- a/hackagent/router/router.py +++ b/hackagent/router/router.py @@ -7,13 +7,16 @@ from hackagent.server.storage.base import AgentRecord, StorageBackend from hackagent.router import envelope as _envelope from hackagent.router import tracking_logger as _tracking_logger +from hackagent.router._chat_registration import _ChatRegistration from hackagent.router.adapters.base import Agent from hackagent.router.provider_config import ProviderConfig, get_provider_config from hackagent.router.types import AgentTypeEnum -# Adapter imports - these are imported at module level for backwards compatibility -# with test patching (tests patch hackagent.router.router.LiteLLMAgent etc.) -# The actual heavy dependency (litellm) is lazy-loaded within LiteLLMAgent +# Adapter imports - these stay at module level for tests that still patch +# ``hackagent.router.router.LiteLLMAgent`` etc. As of Phase E.2 chat +# AgentTypes no longer instantiate these classes — they build a +# ``_ChatRegistration`` instead. ADK still uses ``ADKAgent`` because its +# CustomLLM registration is per-instance. 
from hackagent.router.adapters import ADKAgent from hackagent.router.adapters.litellm import LiteLLMAgent, _get_litellm from hackagent.router.adapters.openai import OpenAIAgent @@ -232,15 +235,34 @@ def _configure_and_instantiate_adapter( ) adapter_instance_config["user_id"] = self.user_id_str + provider_config = get_provider_config(agent_type) + try: + if provider_config is not None: + # Phase E.2b — chat AgentTypes no longer go through the + # heavy adapter classes; the router stores a lightweight + # ``_ChatRegistration`` that ``_dispatch_via_litellm`` reads + # off. Adapter classes remain importable for back-compat. + logger.debug( + f"ROUTER_DEBUG: Building _ChatRegistration for " + f"'{registration_key}' (Type: {agent_type.value}), " + f"config_keys={list(adapter_instance_config.keys())}" + ) + adapter_instance: Any = _ChatRegistration( + id=registration_key, + agent_type=agent_type, + provider_config=provider_config, + config=adapter_instance_config, + ) + else: + logger.debug( + f"ROUTER_DEBUG: About to call adapter_class(id='{registration_key}', config_keys={list(adapter_instance_config.keys())})" + ) + adapter_instance = adapter_class( + id=registration_key, config=adapter_instance_config + ) logger.debug( - f"ROUTER_DEBUG: About to call adapter_class(id='{registration_key}', config_keys={list(adapter_instance_config.keys())})" - ) - adapter_instance = adapter_class( - id=registration_key, config=adapter_instance_config - ) - logger.debug( - f"ROUTER_DEBUG: Called adapter_class. 
Resulting instance: {adapter_instance}, type: {type(adapter_instance)}" + f"ROUTER_DEBUG: Resulting instance: {adapter_instance}, type: {type(adapter_instance)}" ) self._agent_registry[registration_key] = adapter_instance self._agent_types[registration_key] = agent_type diff --git a/tests/integration/adapters/test_litellm.py b/tests/integration/adapters/test_litellm.py index 8a28c1d5..7a18a0c9 100644 --- a/tests/integration/adapters/test_litellm.py +++ b/tests/integration/adapters/test_litellm.py @@ -337,12 +337,12 @@ def test_router_creates_litellm_adapter( litellm_model: str, ollama_base_url: str, ): - """Test that AgentRouter correctly creates LiteLLMAgent adapter.""" + """Test that AgentRouter correctly creates the LITELLM registration.""" from hackagent.server.client import AuthenticatedClient from hackagent.server.storage.remote import RemoteBackend from hackagent.router.router import AgentRouter from hackagent.router.types import AgentTypeEnum - from hackagent.router.adapters.litellm import LiteLLMAgent + from hackagent.router._chat_registration import _ChatRegistration client = AuthenticatedClient( base_url=hackagent_api_base_url, @@ -364,12 +364,14 @@ def test_router_creates_litellm_adapter( endpoint=endpoint, ) - # Verify adapter was created + # Since #379 Phase E.2b the router stores a ``_ChatRegistration`` + # for chat AgentTypes; the adapter classes are no longer + # instantiated. 
agent_id = str(router.backend_agent.id) - adapter = router.get_agent_instance(registration_key=agent_id) - - assert isinstance(adapter, LiteLLMAgent) - logger.info(f"Router created LiteLLM adapter: {adapter.id}") + registration = router.get_agent_instance(registration_key=agent_id) + assert isinstance(registration, _ChatRegistration) + assert registration.ADAPTER_TYPE == "LiteLLMAgent" + logger.info(f"Router created LiteLLM registration: {registration.id}") def test_router_handles_litellm_request( self, diff --git a/tests/integration/adapters/test_openai.py b/tests/integration/adapters/test_openai.py index 4790bf8c..164e6d81 100644 --- a/tests/integration/adapters/test_openai.py +++ b/tests/integration/adapters/test_openai.py @@ -327,11 +327,11 @@ def test_router_creates_openai_adapter( openai_model: str, openai_base_url: str, ): - """Test that AgentRouter correctly creates OpenAIAgent adapter.""" + """Test that AgentRouter correctly creates the OpenAI registration.""" from hackagent.server.client import AuthenticatedClient from hackagent.router.router import AgentRouter from hackagent.router.types import AgentTypeEnum - from hackagent.router.adapters.openai import OpenAIAgent + from hackagent.router._chat_registration import _ChatRegistration client = AuthenticatedClient( base_url=hackagent_api_base_url, @@ -349,12 +349,14 @@ def test_router_creates_openai_adapter( endpoint=openai_base_url, ) - # Verify adapter was created + # Since #379 Phase E.2b the router stores a ``_ChatRegistration`` + # for chat AgentTypes; the adapter classes are no longer + # instantiated. 
agent_id = str(router.backend_agent.id) - adapter = router.get_agent_instance(registration_key=agent_id) - - assert isinstance(adapter, OpenAIAgent) - logger.info(f"Router created OpenAI adapter: {adapter.id}") + registration = router.get_agent_instance(registration_key=agent_id) + assert isinstance(registration, _ChatRegistration) + assert registration.ADAPTER_TYPE == "OpenAIAgent" + logger.info(f"Router created OpenAI registration: {registration.id}") def test_router_handles_openai_request( self, diff --git a/tests/unit/router/test_chat_registration.py b/tests/unit/router/test_chat_registration.py new file mode 100644 index 00000000..cea30587 --- /dev/null +++ b/tests/unit/router/test_chat_registration.py @@ -0,0 +1,116 @@ +# Copyright 2026 - AI4I. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for ``hackagent/router/_chat_registration.py``.""" + +import logging +import os +import unittest +from unittest.mock import patch + +from hackagent.router._chat_registration import _ChatRegistration +from hackagent.router.provider_config import get_provider_config +from hackagent.router.types import AgentTypeEnum + +logging.disable(logging.CRITICAL) + + +def _build(agent_type: AgentTypeEnum, config) -> _ChatRegistration: + return _ChatRegistration( + id="reg-id", + agent_type=agent_type, + provider_config=get_provider_config(agent_type), + config=config, + ) + + +class TestOpenAIRegistration(unittest.TestCase): + def test_basic_openai_attributes(self): + reg = _build(AgentTypeEnum.OPENAI_SDK, {"name": "gpt-4"}) + self.assertEqual(reg.model_name, "gpt-4") + self.assertEqual(reg.litellm_model, "openai/gpt-4") + self.assertEqual(reg.ADAPTER_TYPE, "OpenAIAgent") + # OpenAI's default temperature is historically 1.0. 
+ self.assertEqual(reg.default_temperature, 1.0) + + def test_custom_endpoint_without_api_key_uses_placeholder(self): + reg = _build( + AgentTypeEnum.OPENAI_SDK, + {"name": "gpt-4", "endpoint": "https://proxy/v1"}, + ) + self.assertEqual(reg.api_base_url, "https://proxy/v1") + self.assertEqual(reg.actual_api_key, "not-required") + + def test_custom_endpoint_defaults_model_name_to_default(self): + reg = _build(AgentTypeEnum.OPENAI_SDK, {"endpoint": "https://example.com/v1"}) + self.assertEqual(reg.model_name, "default") + + @patch.dict(os.environ, {"CUSTOM_API_KEY": "sk-test"}) + def test_api_key_resolved_from_env(self): + reg = _build( + AgentTypeEnum.OPENAI_SDK, + {"name": "gpt-4", "api_key": "CUSTOM_API_KEY"}, + ) + self.assertEqual(reg.actual_api_key, "sk-test") + + def test_preserves_existing_provider_prefix(self): + reg = _build(AgentTypeEnum.OPENAI_SDK, {"name": "openai/gpt-4"}) + self.assertEqual(reg.litellm_model, "openai/gpt-4") + + +class TestOllamaRegistration(unittest.TestCase): + def test_basic_ollama_attributes(self): + reg = _build(AgentTypeEnum.OLLAMA, {"name": "llama3"}) + self.assertEqual(reg.litellm_model, "ollama_chat/llama3") + self.assertEqual(reg.api_base_url, "http://localhost:11434") + self.assertEqual(reg.ADAPTER_TYPE, "OllamaAgent") + + def test_endpoint_normalisation_strips_api_suffix(self): + reg = _build( + AgentTypeEnum.OLLAMA, + {"name": "llama3", "endpoint": "http://host:11434/api/chat/"}, + ) + self.assertEqual(reg.api_base_url, "http://host:11434") + + @patch.dict(os.environ, {"OLLAMA_BASE_URL": "http://env-ollama:11434"}) + def test_env_var_endpoint_fallback(self): + reg = _build(AgentTypeEnum.OLLAMA, {"name": "llama3"}) + self.assertEqual(reg.api_base_url, "http://env-ollama:11434") + + def test_top_k_num_ctx_stream_recorded(self): + reg = _build( + AgentTypeEnum.OLLAMA, + { + "name": "llama3", + "top_k": 40, + "num_ctx": 8192, + "stream": True, + }, + ) + self.assertEqual(reg.default_top_k, 40) + 
self.assertEqual(reg.default_num_ctx, 8192) + self.assertTrue(reg.default_stream) + + +class TestLiteLLMRegistration(unittest.TestCase): + def test_no_provider_prefix_when_litellm_passthrough(self): + reg = _build(AgentTypeEnum.LITELLM, {"name": "ollama/llama3"}) + self.assertEqual(reg.litellm_model, "ollama/llama3") + self.assertEqual(reg.ADAPTER_TYPE, "LiteLLMAgent") + + def test_missing_name_raises(self): + with self.assertRaises(ValueError): + _build(AgentTypeEnum.LITELLM, {}) + + +class TestRegistrationMutability(unittest.TestCase): + """External code mutates ``adapter.default_max_tokens``; that must work.""" + + def test_default_max_tokens_is_mutable(self): + reg = _build(AgentTypeEnum.OPENAI_SDK, {"name": "gpt-4"}) + reg.default_max_tokens = 500 + self.assertEqual(reg.default_max_tokens, 500) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/router/test_router.py b/tests/unit/router/test_router.py index 027aeb37..29badbb9 100644 --- a/tests/unit/router/test_router.py +++ b/tests/unit/router/test_router.py @@ -284,19 +284,9 @@ def test_agent_router_init_existing_agent_metadata_differs_overwrite_false( self.assertEqual(router.backend_agent.metadata, existing_metadata) self.assertEqual(router.backend_agent.endpoint, existing_endpoint) - @patch("hackagent.router.router.LiteLLMAgent", autospec=True) - @patch("hackagent.router.router.ADKAgent", autospec=True) - @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) - def test_agent_router_init_creates_new_litellm_agent( - self, - MockAgentMap, - MockADKAdapter, - MockLiteLLMAdapter, - ): - MockAgentMap[AgentTypeEnum.LITELLM] = MockLiteLLMAdapter - MockAgentMap[AgentTypeEnum.GOOGLE_ADK] = MockADKAdapter - MockADKAdapter.__name__ = "ADKAgent" - MockLiteLLMAdapter.__name__ = "LiteLLMAgent" + def test_agent_router_init_creates_new_litellm_agent(self): + """Chat AgentTypes now register a ``_ChatRegistration`` (Phase E.2b).""" + from hackagent.router._chat_registration 
import _ChatRegistration mock_org_id = uuid.uuid4() mock_backend = _make_backend(org_id=mock_org_id, user_id="789") @@ -325,18 +315,18 @@ def test_agent_router_init_creates_new_litellm_agent( overwrite_metadata=True, ) - MockADKAdapter.assert_not_called() - MockLiteLLMAdapter.assert_called_once() - mock_litellm_instance = MockLiteLLMAdapter.return_value - adapter_kwargs = MockLiteLLMAdapter.call_args[1] - self.assertEqual(adapter_kwargs["id"], str(created_id)) - actual_config = adapter_kwargs["config"] - self.assertEqual(actual_config["name"], "gpt-3.5-turbo") - self.assertEqual(actual_config["endpoint"], agent_endpoint) - self.assertEqual(actual_config["api_key"], "env_var_for_llm_key") - self.assertEqual(actual_config["temperature"], 0.8) - self.assertIn(str(created_id), router._agent_registry) - self.assertEqual(router._agent_registry[str(created_id)], mock_litellm_instance) + reg_key = str(created_id) + self.assertIn(reg_key, router._agent_registry) + registration = router._agent_registry[reg_key] + self.assertIsInstance(registration, _ChatRegistration) + self.assertEqual(registration.id, reg_key) + self.assertEqual(registration.model_name, "gpt-3.5-turbo") + self.assertEqual(registration.api_base_url, agent_endpoint) + # ``api_key`` config value is also a valid env var name; when the + # env var doesn't exist it falls through as the literal value. 
+ self.assertEqual(registration.actual_api_key, "env_var_for_llm_key") + self.assertEqual(registration.default_temperature, 0.8) + self.assertEqual(registration.ADAPTER_TYPE, "LiteLLMAgent") class TestAnyUrlEndpointConversion(unittest.TestCase): @@ -368,89 +358,74 @@ def test_adk_adapter_receives_str_endpoint_when_backend_returns_anyurl( self.assertIsInstance(endpoint_value, str) self.assertEqual(endpoint_value, "http://adk-endpoint.com/") - @patch("hackagent.router.router.LiteLLMAgent", autospec=True) - @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) - def test_litellm_adapter_receives_str_endpoint_when_backend_returns_anyurl( - self, - MockAgentMap, - MockLiteLLMAdapter, - ): + def test_litellm_chat_registration_has_str_endpoint(self): + """Phase E.2b — chat AgentTypes store ``_ChatRegistration`` with str endpoint.""" from pydantic import AnyUrl - MockAgentMap[AgentTypeEnum.LITELLM] = MockLiteLLMAdapter - MockLiteLLMAdapter.__name__ = "LiteLLMAgent" mock_backend = _make_backend() + agent_id = uuid.uuid4() mock_backend.create_or_update_agent.return_value = _make_agent_rec( + agent_id=agent_id, agent_type_str="LITELLM", endpoint=AnyUrl("http://litellm-endpoint.com/"), metadata={"name": "gpt-4"}, ) - _ = AgentRouter( + router = AgentRouter( backend=mock_backend, name="TestLiteLLMAgent", agent_type=AgentTypeEnum.LITELLM, endpoint="http://litellm-endpoint.com/", metadata={"name": "gpt-4"}, ) - endpoint_value = MockLiteLLMAdapter.call_args[1]["config"]["endpoint"] - self.assertIsInstance(endpoint_value, str) - self.assertEqual(endpoint_value, "http://litellm-endpoint.com/") + registration = router._agent_registry[str(agent_id)] + self.assertIsInstance(registration.api_base_url, str) + self.assertEqual(registration.api_base_url, "http://litellm-endpoint.com/") - @patch("hackagent.router.router.OpenAIAgent", autospec=True) - @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) - def 
test_openai_adapter_receives_str_endpoint_when_backend_returns_anyurl( - self, - MockAgentMap, - MockOpenAIAdapter, - ): + def test_openai_chat_registration_has_str_endpoint(self): from pydantic import AnyUrl - MockAgentMap[AgentTypeEnum.OPENAI_SDK] = MockOpenAIAdapter - MockOpenAIAdapter.__name__ = "OpenAIAgent" mock_backend = _make_backend() + agent_id = uuid.uuid4() mock_backend.create_or_update_agent.return_value = _make_agent_rec( + agent_id=agent_id, agent_type_str="OPENAI_SDK", endpoint=AnyUrl("http://openai-endpoint.com/v1/"), metadata={"name": "gpt-4o"}, ) - _ = AgentRouter( + router = AgentRouter( backend=mock_backend, name="TestOpenAIAgent", agent_type=AgentTypeEnum.OPENAI_SDK, endpoint="http://openai-endpoint.com/v1/", metadata={"name": "gpt-4o"}, ) - endpoint_value = MockOpenAIAdapter.call_args[1]["config"]["endpoint"] - self.assertIsInstance(endpoint_value, str) - self.assertEqual(endpoint_value, "http://openai-endpoint.com/v1/") + registration = router._agent_registry[str(agent_id)] + self.assertIsInstance(registration.api_base_url, str) + self.assertEqual(registration.api_base_url, "http://openai-endpoint.com/v1/") - @patch("hackagent.router.router.OllamaAgent", autospec=True) - @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) - def test_ollama_adapter_receives_str_endpoint_when_backend_returns_anyurl( - self, - MockAgentMap, - MockOllamaAdapter, - ): + def test_ollama_chat_registration_has_str_endpoint(self): + """Ollama still applies its endpoint normalisation rules.""" from pydantic import AnyUrl - MockAgentMap[AgentTypeEnum.OLLAMA] = MockOllamaAdapter - MockOllamaAdapter.__name__ = "OllamaAgent" mock_backend = _make_backend() + agent_id = uuid.uuid4() mock_backend.create_or_update_agent.return_value = _make_agent_rec( + agent_id=agent_id, agent_type_str="OLLAMA", endpoint=AnyUrl("http://ollama-endpoint.com/"), metadata={"name": "llama3"}, ) - _ = AgentRouter( + router = AgentRouter( backend=mock_backend, 
name="TestOllamaAgent", agent_type=AgentTypeEnum.OLLAMA, endpoint="http://ollama-endpoint.com/", metadata={"name": "llama3"}, ) - endpoint_value = MockOllamaAdapter.call_args[1]["config"]["endpoint"] - self.assertIsInstance(endpoint_value, str) - self.assertEqual(endpoint_value, "http://ollama-endpoint.com/") + registration = router._agent_registry[str(agent_id)] + self.assertIsInstance(registration.api_base_url, str) + # Trailing slash is stripped by Ollama's normaliser. + self.assertEqual(registration.api_base_url, "http://ollama-endpoint.com") class TestMetadataNoneStripping(unittest.TestCase): From e3de58f426fc85a36dd5591214470cf221341500 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 17:34:18 +0200 Subject: [PATCH 10/23] refactor(router): delete chat adapter classes (#379 Phase E.2c) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase E.2c (final phase) of the LiteLLM router refactor. Deletes: - hackagent/router/adapters/litellm.py - hackagent/router/adapters/openai.py - hackagent/router/adapters/ollama.py - tests/unit/adapters/test_{litellm,openai,ollama}.py - tests/integration/adapters/test_{litellm,openai,ollama}.py These adapter classes haven't been on the hot path since Phase C and the router stopped instantiating them in Phase E.2b — every chat AgentType now goes through ``AgentRouter._dispatch_via_litellm`` with config supplied by ``_ChatRegistration``. Coverage moves to ``tests/unit/router/test_dispatch.py`` and ``tests/unit/router/test_chat_registration.py``. ``hackagent/router/adapters/__init__.py`` is reduced to the exception classes + an ``ADKAgent`` re-export. ``hackagent/router/__init__.py`` drops the public ``OllamaAgent`` symbol. ``AGENT_TYPE_TO_ADAPTER_MAP`` now only carries the ADK entry; chat-type validation goes through ``get_provider_config``. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- LITELLM_ROUTER_REFACTOR_PLAN.md | 34 +- hackagent/router/__init__.py | 8 +- hackagent/router/adapters/__init__.py | 47 +- hackagent/router/adapters/litellm.py | 512 --------------------- hackagent/router/adapters/ollama.py | 150 ------ hackagent/router/adapters/openai.py | 79 ---- hackagent/router/router.py | 68 +-- tests/integration/adapters/test_litellm.py | 509 -------------------- tests/integration/adapters/test_ollama.py | 346 -------------- tests/integration/adapters/test_openai.py | 460 ------------------ tests/unit/adapters/test_litellm.py | 274 ----------- tests/unit/adapters/test_ollama.py | 220 --------- tests/unit/adapters/test_openai.py | 236 ---------- tests/unit/router/test_router.py | 39 +- 14 files changed, 83 insertions(+), 2899 deletions(-) delete mode 100644 hackagent/router/adapters/litellm.py delete mode 100644 hackagent/router/adapters/ollama.py delete mode 100644 hackagent/router/adapters/openai.py delete mode 100644 tests/integration/adapters/test_litellm.py delete mode 100644 tests/integration/adapters/test_ollama.py delete mode 100644 tests/integration/adapters/test_openai.py delete mode 100644 tests/unit/adapters/test_litellm.py delete mode 100644 tests/unit/adapters/test_ollama.py delete mode 100644 tests/unit/adapters/test_openai.py diff --git a/LITELLM_ROUTER_REFACTOR_PLAN.md b/LITELLM_ROUTER_REFACTOR_PLAN.md index 675f4ec5..07c2cd13 100644 --- a/LITELLM_ROUTER_REFACTOR_PLAN.md +++ b/LITELLM_ROUTER_REFACTOR_PLAN.md @@ -336,18 +336,28 @@ Phases A–E (partial) landed in five commits on ### Remaining work (deferred) -- **Phase E.2 — full deletion of the chat adapter classes.** Requires - replacing the per-registration adapter instance with a lightweight - config dataclass (`_ChatRegistration`) and dropping the public - `LiteLLMAgent` / `OpenAIAgent` / `OllamaAgent` symbols. 
Hold off - until we know no downstream code in `hackagent-api`, - `hackagent-webapp`, or external consumers depends on those imports. -- **Phase F — optional follow-ups.** Adopt `litellm.Router` for - multi-deployment load balancing; surface `response_cost` and - `x-litellm-call-id` in the envelope; streaming support; full - `_build_error_response` shape unification (currently the - `AgentNotFound` envelope still uses ``raw_response_status`` while - the chat-dispatch envelope uses ``status_code``). +- **Phase F.2+ — optional follow-ups.** Adopt `litellm.Router` for + multi-deployment load balancing; streaming support + (`stream=True`/`CustomStreamWrapper`); richer caller-supplied + `metadata` plumbing for `org_id` / `attack_id` / `evaluator_id`. + +### Recent landings (post-original plan) + +- **Phase F.1** (`0b0a2e5`): `extract_response_cost` and + `extract_litellm_call_id` envelope helpers; chat-dispatch envelope + now carries `response_cost` and `litellm_call_id` in + `agent_specific_data`. The router-level `AgentNotFound` envelope now + sets `status_code` alongside the legacy `raw_response_status` alias. +- **Phase E.2a** (`c3090e9`): `ADKAgent` no longer inherits from + `LiteLLMAgent`; it extends `Agent` directly and implements + `handle_request` itself. +- **Phase E.2b** (`5b7b9af`): chat AgentTypes no longer instantiate + `LiteLLMAgent` / `OpenAIAgent` / `OllamaAgent` — the router builds + a `_ChatRegistration` config holder. +- **Phase E.2c**: the chat adapter classes and their unit/integration + tests are gone. `hackagent/router/adapters/` keeps only the + exception base classes plus an `ADKAgent` re-export. ADK lives at + `hackagent.router.providers.adk`. 
### Tests diff --git a/hackagent/router/__init__.py b/hackagent/router/__init__.py index c339ad9f..c241b285 100644 --- a/hackagent/router/__init__.py +++ b/hackagent/router/__init__.py @@ -3,17 +3,13 @@ """Main router logic for dispatching requests to appropriate agents.""" -from .adapters import ( - ADKAgent, - OllamaAgent, -) # This makes it easy to access agents via router module +from .adapters import ADKAgent from .router import AgentRouter from .tracking import StepTracker, TrackingContext, track_operation __all__ = [ "AgentRouter", - "ADKAgent", # Exporting specific agents for convenience - "OllamaAgent", # Ollama agent for local LLMs + "ADKAgent", "StepTracker", "TrackingContext", "track_operation", diff --git a/hackagent/router/adapters/__init__.py b/hackagent/router/adapters/__init__.py index 37e73fb7..bac6e458 100644 --- a/hackagent/router/adapters/__init__.py +++ b/hackagent/router/adapters/__init__.py @@ -2,25 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 """ -Legacy adapter package. - -Issue #379 routed every chat-completion AgentType through LiteLLM -(Phase A–B), hoisted the call path into ``AgentRouter`` (Phase C), and -wired a ``litellm.CustomLogger`` for I/O capture (Phase D). The -adapter classes ``LiteLLMAgent`` / ``OpenAIAgent`` / ``OllamaAgent`` -are no longer on the hot path — ``AgentRouter._dispatch_via_litellm`` -calls ``litellm.completion`` directly, reading the resolved config off -the instance only for its model name, endpoint, API key, and -generation defaults. - -The classes remain available for backwards compatibility with external -callers that import them. ADK has moved to -``hackagent.router.providers.adk``; this package re-exports -``ADKAgent`` from there so old imports keep working. +Adapter exception classes + ADKAgent re-export. + +Issue #379 completed: + - Phases A–D moved the chat-completion call path off the adapter + classes and onto ``AgentRouter._dispatch_via_litellm`` + ``litellm``. 
+ - Phase E.2 deleted ``LiteLLMAgent`` / ``OpenAIAgent`` / ``OllamaAgent`` + entirely; chat AgentTypes now use ``_ChatRegistration``. + - ADK lives at ``hackagent.router.providers.adk`` and is re-exported + here so old imports keep working. + +If you were importing ``LiteLLMAgent``, ``OpenAIAgent``, or +``OllamaAgent`` from this package, switch to driving requests through +``AgentRouter.route_request(...)`` instead. """ -# Lazy imports for adapters to improve startup time -# These adapters import heavy dependencies (litellm ~2s, google-adk ~0.1s) from .base import ( Agent, ChatCompletionsAgent, @@ -31,31 +27,16 @@ def __getattr__(name): - """Lazy load adapter classes on first access.""" + """Lazy load for the surviving adapter class.""" if name == "ADKAgent": from hackagent.router.providers.adk import ADKAgent return ADKAgent - elif name == "LiteLLMAgent": - from .litellm import LiteLLMAgent - - return LiteLLMAgent - elif name == "OpenAIAgent": - from .openai import OpenAIAgent - - return OpenAIAgent - elif name == "OllamaAgent": - from .ollama import OllamaAgent - - return OllamaAgent raise AttributeError(f"module {__name__!r} has no attribute {name!r}") __all__ = [ "ADKAgent", - "LiteLLMAgent", - "OpenAIAgent", - "OllamaAgent", "Agent", "ChatCompletionsAgent", "AdapterConfigurationError", diff --git a/hackagent/router/adapters/litellm.py b/hackagent/router/adapters/litellm.py deleted file mode 100644 index 0606076d..00000000 --- a/hackagent/router/adapters/litellm.py +++ /dev/null @@ -1,512 +0,0 @@ -# Copyright 2026 - AI4I. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - - -from hackagent.logger import get_logger -from typing import Any, Dict, List, Optional - -from hackagent.router import envelope as _envelope -from hackagent.router.provider_config import ProviderConfig - -from .base import ChatCompletionsAgent, AdapterConfigurationError - -# Lazy load litellm - only import when actually needed to avoid ~2s startup delay -# The actual import happens in _get_litellm() method -_litellm_module = None -_litellm_exceptions = None - - -def _get_litellm(): - """Lazily import litellm module. Returns (litellm_module, is_available).""" - global _litellm_module, _litellm_exceptions - if _litellm_module is not None: - return _litellm_module, True - - try: - import litellm - - _litellm_module = litellm - return litellm, True - except ImportError: - return None, False - - -def _get_litellm_exceptions(): - """Lazily import litellm exceptions. Returns dict of exception classes.""" - global _litellm_exceptions - if _litellm_exceptions is not None: - return _litellm_exceptions - - try: - from litellm.exceptions import ( - APIConnectionError, - APIError, - AuthenticationError, - BadRequestError, - ContextWindowExceededError, - NotFoundError, - PermissionDeniedError, - RateLimitError, - ServiceUnavailableError, - Timeout, - ) - - _litellm_exceptions = { - "APIConnectionError": APIConnectionError, - "APIError": APIError, - "AuthenticationError": AuthenticationError, - "BadRequestError": BadRequestError, - "ContextWindowExceededError": ContextWindowExceededError, - "NotFoundError": NotFoundError, - "PermissionDeniedError": PermissionDeniedError, - "RateLimitError": RateLimitError, - "ServiceUnavailableError": ServiceUnavailableError, - "Timeout": Timeout, - } - except ImportError: - # Define dummy exceptions if litellm is not available - _litellm_exceptions = { - "APIConnectionError": Exception, - "APIError": Exception, - "AuthenticationError": Exception, - "BadRequestError": Exception, - 
"ContextWindowExceededError": Exception, - "NotFoundError": Exception, - "PermissionDeniedError": Exception, - "RateLimitError": Exception, - "ServiceUnavailableError": Exception, - "Timeout": Exception, - } - return _litellm_exceptions - - -# --- Custom Exceptions (subclass from base) --- -class LiteLLMConfigurationError(AdapterConfigurationError): - """Custom exception for LiteLLM adapter configuration issues.""" - - pass - - -logger = get_logger(__name__) # Module-level logger - - -# Sourced from envelope.py — kept here as a module-level alias so any -# external code that imported it from this module still works. -_KNOWN_LITELLM_PROVIDER_PREFIXES = _envelope.KNOWN_LITELLM_PROVIDER_PREFIXES - - -class LiteLLMAgent(ChatCompletionsAgent): - """ - Unified adapter that routes every chat-completion request through LiteLLM. - - All chat-style adapters (OpenAI-SDK, Ollama, LangChain, plain LiteLLM) - subclass this class. Each subclass sets ``PROVIDER_PREFIX`` to declare - which LiteLLM provider the AgentType maps to (e.g. ``"openai"`` for an - OpenAI-compatible endpoint, ``"ollama_chat"`` for Ollama). The base class - handles model-string normalisation, endpoint plumbing, generation - parameters, tool calls, and the unified ``thinking`` knob. - - Thinking knob: - Any subclass can be asked to enable or disable provider reasoning by - setting ``thinking`` in the adapter config or per-request payload. - Accepted values: - - bool: ``True`` enables thinking with provider defaults, - ``False`` disables it explicitly. - - dict: passed through verbatim (e.g. - ``{"type": "enabled", "budget_tokens": 1024}`` for Anthropic). - - str: a reasoning effort level (``"low"``, ``"medium"``, - ``"high"``) — translated by the subclass as appropriate. - - int: budget tokens for providers that accept a budget. - Subclasses override ``_apply_thinking`` to translate the value into - the provider-specific request fields. 
- """ - - ADAPTER_TYPE = "LiteLLMAgent" - # When set, the model string passed to LiteLLM is prefixed with - # ``"{PROVIDER_PREFIX}/"`` unless it already starts with a known - # LiteLLM provider prefix. ``None`` means "let LiteLLM auto-detect". - # Subclasses can either set this class attribute or pass a - # ``ProviderConfig`` to ``__init__`` (the new Phase B path). - PROVIDER_PREFIX: Optional[str] = None - - def __init__( - self, - id: str, - config: Dict[str, Any], - provider_config: Optional[ProviderConfig] = None, - ): - """ - Initialise the adapter from configuration. - - Args: - id: Unique identifier for this adapter instance. - config: Configuration dict. Supported keys: - - ``name``: model string (e.g. ``"llama3"`` or - ``"gpt-4"``). Required. - - ``endpoint`` (optional): API base URL. - - ``api_key`` (optional): API key or environment variable name. - - ``max_tokens`` / ``temperature`` / ``top_p`` (optional). - - ``tools`` / ``tool_choice`` (optional): function-calling - definitions, passed through to LiteLLM. - - ``thinking`` (optional): see class docstring. - - ``extra_body`` (optional): provider-specific request body. - provider_config: Optional :class:`ProviderConfig` looked up - from ``hackagent.router.provider_config``. When supplied, - it takes precedence over the class-level - ``PROVIDER_PREFIX`` and ``_apply_thinking`` override. This - is the path Phase C will use to drive the adapter from - ``router.py`` without subclassing. - """ - super().__init__(id, config) - self._provider_config: Optional[ProviderConfig] = provider_config - - # Require model name - self.model_name = self._require_config_key("name", LiteLLMConfigurationError) - self.api_base_url: Optional[str] = self._get_config_key("endpoint") - - # Determine the effective LiteLLM model string (with provider prefix). 
- self.litellm_model = self._resolve_litellm_model(self.model_name) - - # Handle API key configuration using base class helper - env_var_fallback = self._default_api_key_env_var() - self.actual_api_key = self._resolve_api_key( - config_key="api_key", env_var_fallback=env_var_fallback - ) - - # When using a custom endpoint without credentials, rely on - # endpoint-side auth (common for local model servers). - if self.api_base_url and not self.actual_api_key: - self.logger.debug( - f"Using custom endpoint '{self.api_base_url}' without api_key - " - "endpoint handles its own auth" - ) - - self.logger.info( - f"{self.ADAPTER_TYPE} '{self.id}' initialised for LiteLLM model: " - f"'{self.litellm_model}'" - + (f" API Base: '{self.api_base_url}'" if self.api_base_url else "") - ) - - # Default generation parameters (max_tokens, temperature, top_p). - self._init_generation_params() - - # Pass-through fields commonly supplied via config. - self.default_tools = self._get_config_key("tools") - self.default_tool_choice = self._get_config_key("tool_choice") - self.default_extra_body = self._get_config_key("extra_body") - self.default_thinking = self._get_config_key("thinking") - - # ---- subclass extension points --------------------------------------- - - def _resolve_litellm_model(self, raw_model: str) -> str: - """Return the model string to pass to ``litellm.completion``. - - Honors ``self._provider_config.provider_prefix`` when a - :class:`ProviderConfig` was supplied, otherwise falls back to the - class-level ``PROVIDER_PREFIX``. Subclasses (notably - :class:`ADKAgent`) override this entirely to inject a per-instance - provider prefix. 
- """ - prefix = ( - self._provider_config.provider_prefix - if self._provider_config is not None - else self.PROVIDER_PREFIX - ) - return _envelope.resolve_litellm_model(raw_model, provider_prefix=prefix) - - def _default_api_key_env_var(self) -> Optional[str]: - """Return the env var used as a fallback when no API key is configured.""" - if self.api_base_url: - return None - if self.litellm_model.startswith(("openai/", "gpt-")): - return "OPENAI_API_KEY" - if self.litellm_model.startswith(("anthropic/", "claude-")): - return "ANTHROPIC_API_KEY" - return None - - def _apply_thinking(self, litellm_params: Dict[str, Any], thinking: Any) -> None: - """Translate the unified ``thinking`` value into LiteLLM params. - - When a :class:`ProviderConfig` was supplied, its - ``thinking_translator`` is consulted; otherwise the default - generic translator is used. Subclasses may still override this - method for backwards compatibility, but the recommended path is - to supply a ``ProviderConfig`` instead. - """ - if thinking is None: - return - if self._provider_config is not None: - payload = self._provider_config.thinking_translator( - thinking, model_name=self.litellm_model - ) - if payload: - litellm_params.update(payload) - return - - # Fallback path for adapters built without a ProviderConfig. - from hackagent.router.provider_config import default_thinking_translator - - payload = default_thinking_translator(thinking, model_name=self.litellm_model) - if payload: - litellm_params.update(payload) - - # ---- request preparation -------------------------------------------- - - def _prepare_litellm_params( - self, - messages: List[Dict[str, str]], - max_tokens: int, - temperature: float, - top_p: float, - **kwargs, - ) -> Dict[str, Any]: - """Build the kwargs dict for ``litellm.completion``. - - Delegates the bulk construction to - :func:`hackagent.router.envelope.build_litellm_kwargs`. 
The - thinking translation still goes through ``_apply_thinking`` so - subclasses can specialise it. - """ - thinking_payload: Dict[str, Any] = {} - thinking = kwargs.pop("thinking", self.default_thinking) - self._apply_thinking(thinking_payload, thinking) - - tools = kwargs.pop("tools", self.default_tools) - tool_choice = kwargs.pop("tool_choice", self.default_tool_choice) - extra_body = kwargs.pop("extra_body", self.default_extra_body) - - return _envelope.build_litellm_kwargs( - model=self.litellm_model, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - api_base=self.api_base_url, - api_key=self.actual_api_key, - tools=tools, - tool_choice=tool_choice, - extra_body=extra_body, - thinking_payload=thinking_payload, - extra_kwargs=kwargs, - ) - - def _extract_raw_response_content(self, response: Any, context: str = "") -> str: - """Extract content from a litellm response object. - - Delegates to - :func:`hackagent.router.envelope.extract_text_from_response`. The - ``context`` argument is preserved for backwards compatibility but - is only used for logging when the response is malformed. - """ - text = _envelope.extract_text_from_response( - response, model_name=self.litellm_model - ) - if text == "[GENERATION_ERROR: UNEXPECTED_RESPONSE]": - self.logger.warning( - f"LiteLLM received unexpected response structure for model " - f"'{self.litellm_model}'{context}. Response: {response}" - ) - elif text == "[GENERATION_ERROR: EMPTY_RESPONSE]": - self.logger.warning( - f"LiteLLM received empty content and no reasoning field for " - f"model '{self.litellm_model}'{context}." - ) - return text - - def _extract_tool_calls(self, response: Any) -> Optional[List[Dict[str, Any]]]: - """Return OpenAI-style ``tool_calls`` from a ``ModelResponse``, or ``None``. - - Delegates to :func:`hackagent.router.envelope.extract_tool_calls`. 
- """ - return _envelope.extract_tool_calls(response) - - def _get_excluded_request_keys(self) -> set: - """Return keys handled explicitly so they aren't re-passed as kwargs.""" - return { - "prompt", - "messages", - "max_tokens", - "temperature", - "top_p", - "tools", - "tool_choice", - "thinking", - "extra_body", - } - - def _get_completion_parameters( - self, request_data: Dict[str, Any] - ) -> Dict[str, Any]: - """Extract completion parameters with provider-agnostic passthroughs.""" - params = super()._get_completion_parameters(request_data) - # Carry passthrough fields when present in the request. - for key in ("tools", "tool_choice", "thinking", "extra_body"): - if key in request_data: - params[key] = request_data[key] - return params - - # ---- execution ------------------------------------------------------- - - def _execute_completion( - self, messages: List[Dict[str, str]], **parameters - ) -> Dict[str, Any]: - """Execute a completion via ``litellm.completion``.""" - litellm, is_available = _get_litellm() - if not is_available: - return { - "success": False, - "error_type": "configuration_error", - "error_message": "litellm is not installed", - } - - exceptions = _get_litellm_exceptions() - AuthenticationError = exceptions["AuthenticationError"] - - try: - if messages: - msg_preview = str(messages[-1].get("content", ""))[:100] - self.logger.info(f"🌐 Querying model {self.litellm_model}") - self.logger.debug(f" Message preview: {msg_preview}...") - - max_tokens = parameters.pop("max_tokens", self.default_max_tokens) - temperature = parameters.pop("temperature", self.default_temperature) - top_p = parameters.pop("top_p", self.default_top_p) - - litellm_params = self._prepare_litellm_params( - messages, max_tokens, temperature, top_p, **parameters - ) - response = litellm.completion(**litellm_params) - - content = self._extract_raw_response_content(response) - tool_calls = self._extract_tool_calls(response) - - self.logger.info(f"✅ Model responded 
({len(content)} chars)") - - result: Dict[str, Any] = { - "success": True, - "content": content, - "raw_response": response, - } - if tool_calls is not None: - result["tool_calls"] = tool_calls - # Surface useful diagnostics when available. - try: - result["finish_reason"] = response.choices[0].finish_reason - except (AttributeError, IndexError, TypeError): - pass - try: - if response.usage is not None: - result["usage"] = response.usage.model_dump() - except AttributeError: - pass - try: - result["provider_model"] = response.model - except AttributeError: - pass - - return result - - except AuthenticationError as e: - error_msg = f"Authentication failed for model '{self.litellm_model}': {e}" - self.logger.error(error_msg) - llm_provider = e.llm_provider if hasattr(e, "llm_provider") else "unknown" - raise AuthenticationError( - error_msg, llm_provider, self.litellm_model - ) from e - except Exception as e: - self.logger.error( - f"LiteLLM completion call failed for model '{self.litellm_model}': {e}", - exc_info=True, - ) - return { - "success": False, - "error_type": type(e).__name__, - "error_message": str(e), - } - - # ---- legacy convenience helpers ------------------------------------- - # (Response-shaping is handled by the base ``ChatCompletionsAgent`` via - # ``envelope.build_agent_specific_data`` since Phase A.) 
- - def _execute_litellm_completion_with_messages( - self, - messages: List[Dict[str, str]], - max_tokens: int, - temperature: float, - top_p: float, - **kwargs, - ) -> str: - """Single completion call returning the generated text only.""" - result = self._execute_completion( - messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - **kwargs, - ) - if result.get("success"): - return result.get("content", "") - return f"[GENERATION_ERROR: {result.get('error_type', 'UNKNOWN')}]" - - def _execute_litellm_completion( - self, - texts: List[str], - max_tokens: int, - temperature: float, - top_p: float, - **kwargs, - ) -> List[str]: - """Generate completions for a batch of prompt strings.""" - if not texts: - return [] - - litellm, is_available = _get_litellm() - if not is_available: - raise LiteLLMConfigurationError("litellm is not installed") - - exceptions = _get_litellm_exceptions() - AuthenticationError = exceptions["AuthenticationError"] - - completions: List[str] = [] - self.logger.info( - f"Sending {len(texts)} requests via LiteLLM to model " - f"'{self.litellm_model}'..." 
- ) - - for text_prompt in texts: - messages = [{"role": "user", "content": text_prompt}] - try: - litellm_params = self._prepare_litellm_params( - messages, max_tokens, temperature, top_p, **kwargs - ) - response = litellm.completion(**litellm_params) - completion_text = self._extract_raw_response_content( - response, context=f" for prompt '{text_prompt[:50]}...'" - ) - except AuthenticationError as e: - error_msg = ( - f"Authentication failed for model '{self.litellm_model}': {e}" - ) - self.logger.error(error_msg) - llm_provider = ( - e.llm_provider if hasattr(e, "llm_provider") else "unknown" - ) - raise AuthenticationError( - error_msg, llm_provider, self.litellm_model - ) from e - except Exception as e: - self.logger.error( - f"LiteLLM completion call failed for model " - f"'{self.litellm_model}' for prompt " - f"'{text_prompt[:50]}...': {e}", - exc_info=True, - ) - completion_text = f" [GENERATION_ERROR: {type(e).__name__}]" - - completions.append(text_prompt + completion_text) - - self.logger.info( - f"Finished LiteLLM requests for model '{self.litellm_model}'. " - f"Generated {len(completions)} responses." - ) - return completions diff --git a/hackagent/router/adapters/ollama.py b/hackagent/router/adapters/ollama.py deleted file mode 100644 index 5b4931fc..00000000 --- a/hackagent/router/adapters/ollama.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2026 - AI4I. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Ollama adapter built on top of LiteLLM. - -LiteLLM ships with an ``ollama_chat`` provider that targets a local or -remote Ollama server's ``/api/chat`` endpoint, so we no longer have to -hand-roll the HTTP calls. This adapter just pins the provider prefix, -normalises the endpoint URL the way the previous direct adapter did, and -translates the unified ``thinking`` knob into Ollama's ``think`` parameter. 
-""" - -import os -from hackagent.logger import get_logger -from typing import Any, Dict, List, Optional - -from hackagent.router.provider_config import get_provider_config -from hackagent.router.types import AgentTypeEnum - -from .base import AdapterConfigurationError -from .litellm import LiteLLMAgent - - -class OllamaConfigurationError(AdapterConfigurationError): - """Custom exception for Ollama adapter configuration issues.""" - - pass - - -logger = get_logger(__name__) - - -class OllamaAgent(LiteLLMAgent): - """ - Adapter for an Ollama server. - - Configuration: - - ``name``: Ollama model tag (e.g. ``"llama3"``, ``"mistral"``). - - ``endpoint`` (optional): Ollama base URL. Defaults to - ``$OLLAMA_BASE_URL`` if set, otherwise ``http://localhost:11434``. - API-path suffixes such as ``/api/chat`` are stripped automatically - so users can paste their browser URL. - - ``thinking`` (optional): see :class:`LiteLLMAgent` for the - accepted shapes. Translated into Ollama's native ``think`` field. - - ``top_k`` / ``num_ctx`` / ``stream`` (optional): forwarded as - Ollama generation options. - """ - - ADAPTER_TYPE = "OllamaAgent" - DEFAULT_ENDPOINT = "http://localhost:11434" - - def __init__(self, id: str, config: Dict[str, Any]): - # Resolve and normalise the endpoint before delegating to the base. - effective_endpoint = config.get("endpoint") or os.environ.get( - "OLLAMA_BASE_URL", self.DEFAULT_ENDPOINT - ) - effective_endpoint = self._normalize_endpoint(effective_endpoint) - config = {**config, "endpoint": effective_endpoint} - - try: - super().__init__( - id, - config, - provider_config=get_provider_config(AgentTypeEnum.OLLAMA), - ) - except AdapterConfigurationError as e: - raise OllamaConfigurationError(str(e)) from e - - # Ollama-specific generation options that LiteLLM forwards via - # ``optional_params`` (any extra kwarg passed to - # ``litellm.completion`` for the ``ollama_chat`` provider). 
- self.default_top_k = self._get_config_key("top_k") - self.default_num_ctx = self._get_config_key("num_ctx") - self.default_stream = self._get_config_key("stream", False) - - @staticmethod - def _normalize_endpoint(endpoint: str) -> str: - """Strip trailing slash and Ollama API path suffixes from ``endpoint``.""" - endpoint = endpoint.rstrip("/") - for suffix in ("/api/generate", "/api/chat", "/api/tags", "/api/show", "/api"): - if endpoint.endswith(suffix): - endpoint = endpoint[: -len(suffix)] - break - return endpoint - - # ---- request shaping ------------------------------------------------ - - def _get_completion_parameters( - self, request_data: Dict[str, Any] - ) -> Dict[str, Any]: - """Inject Ollama-specific defaults (top_k, num_ctx, stream).""" - params = super()._get_completion_parameters(request_data) - if "top_k" not in params and self.default_top_k is not None: - params["top_k"] = self.default_top_k - if "num_ctx" not in params and self.default_num_ctx is not None: - params["num_ctx"] = self.default_num_ctx - if "stream" not in params and self.default_stream: - params["stream"] = self.default_stream - return params - - def _get_excluded_request_keys(self) -> set: - base = super()._get_excluded_request_keys() - return base | {"top_k", "num_ctx", "stream", "system"} - - # Thinking translation is driven by the ``OLLAMA`` ``ProviderConfig`` - # (see ``hackagent/router/provider_config.py``); no override needed. 
- - # ---- diagnostics passthroughs (kept for callers/tests) -------------- - - def list_models(self) -> List[Dict[str, Any]]: - """Return models reported by ``GET {endpoint}/api/tags``.""" - import requests - - try: - response = requests.get(f"{self.api_base_url}/api/tags", timeout=30) - response.raise_for_status() - return response.json().get("models", []) - except Exception as e: - self.logger.error(f"Failed to list Ollama models: {e}") - return [] - - def model_info(self) -> Dict[str, Any]: - """Return ``POST {endpoint}/api/show`` payload for the current model.""" - import requests - - try: - response = requests.post( - f"{self.api_base_url}/api/show", - json={"name": self.model_name}, - timeout=30, - ) - response.raise_for_status() - return response.json() - except Exception as e: - self.logger.error(f"Failed to get model info for '{self.model_name}': {e}") - return {} - - def is_available(self) -> bool: - """True iff the configured model appears in ``/api/tags``.""" - try: - models = self.list_models() - if not self.model_name: - return False - base_model = self.model_name.split(":")[0] - names: List[Optional[str]] = [m.get("name") for m in models] - base_names = [(m.get("name") or "").split(":")[0] for m in models] - return base_model in base_names or self.model_name in names - except Exception: - return False diff --git a/hackagent/router/adapters/openai.py b/hackagent/router/adapters/openai.py deleted file mode 100644 index f87f6599..00000000 --- a/hackagent/router/adapters/openai.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2026 - AI4I. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -OpenAI-compatible adapter built on top of LiteLLM. - -The OpenAI agent type used to talk to the OpenAI SDK directly. As of -issue #379 every chat-completion adapter routes through LiteLLM, so this -class is now a thin specialisation of :class:`LiteLLMAgent` that pulls -its provider prefix + thinking translator out of the -:class:`ProviderConfig` table. 
-""" - -from hackagent.logger import get_logger -from typing import Any, Dict - -from hackagent.router.provider_config import get_provider_config -from hackagent.router.types import AgentTypeEnum - -from .base import AdapterConfigurationError -from .litellm import LiteLLMAgent - - -# Keep this exception public for backwards compatibility — downstream code -# (and several tests) import OpenAIConfigurationError from this module. -class OpenAIConfigurationError(AdapterConfigurationError): - """Custom exception for OpenAI adapter configuration issues.""" - - pass - - -logger = get_logger(__name__) - - -class OpenAIAgent(LiteLLMAgent): - """ - Adapter for OpenAI-compatible chat endpoints. - - Configured via the ``OPENAI_SDK`` agent type. Internally uses - LiteLLM, so any OpenAI-compatible server (the official API, a local - model server exposing ``/v1/chat/completions``, OpenRouter, etc.) - works as the endpoint. - - Reasoning / "thinking": - Driven by the :class:`ProviderConfig` for - ``AgentTypeEnum.OPENAI_SDK`` — for the o-series and newer GPT - reasoning models the unified ``thinking`` value is translated - to ``reasoning_effort``. - """ - - ADAPTER_TYPE = "OpenAIAgent" - DEFAULT_TEMPERATURE = 1.0 - - def __init__(self, id: str, config: Dict[str, Any]): - # Custom endpoints don't always require a model name; default to - # ``"default"`` (the server then decides) when an endpoint is set - # but no model is provided. - if "name" not in config and config.get("endpoint"): - config = {**config, "name": config.get("name", "default")} - - try: - super().__init__( - id, - config, - provider_config=get_provider_config(AgentTypeEnum.OPENAI_SDK), - ) - except AdapterConfigurationError as e: - # Re-raise as the OpenAI-flavoured subclass so legacy callers - # that catch OpenAIConfigurationError keep working. 
- raise OpenAIConfigurationError(str(e)) from e - - # For custom endpoints without an API key, use a placeholder so - # the OpenAI client (under LiteLLM's hood) doesn't error out. - if not self.actual_api_key and self.api_base_url: - self.actual_api_key = "not-required" - self.logger.info( - f"No API key configured for custom endpoint " - f"'{self.api_base_url}', using placeholder" - ) diff --git a/hackagent/router/router.py b/hackagent/router/router.py index 0d295f8a..94172313 100644 --- a/hackagent/router/router.py +++ b/hackagent/router/router.py @@ -9,30 +9,22 @@ from hackagent.router import tracking_logger as _tracking_logger from hackagent.router._chat_registration import _ChatRegistration from hackagent.router.adapters.base import Agent +from hackagent.router.providers.adk import ADKAgent, _get_litellm from hackagent.router.provider_config import ProviderConfig, get_provider_config from hackagent.router.types import AgentTypeEnum -# Adapter imports - these stay at module level for tests that still patch -# ``hackagent.router.router.LiteLLMAgent`` etc. As of Phase E.2 chat -# AgentTypes no longer instantiate these classes — they build a -# ``_ChatRegistration`` instead. ADK still uses ``ADKAgent`` because its -# CustomLLM registration is per-instance. -from hackagent.router.adapters import ADKAgent -from hackagent.router.adapters.litellm import LiteLLMAgent, _get_litellm -from hackagent.router.adapters.openai import OpenAIAgent -from hackagent.router.adapters.ollama import OllamaAgent - # Use explicit hierarchical logger name for clarity logger = logging.getLogger("hackagent.router") # --- Agent Type to Adapter Mapping --- +# Phase E.2c deleted the chat adapter classes. Chat AgentTypes +# (LITELLM, OPENAI_SDK, OLLAMA, LANGCHAIN) are now driven entirely by +# ``hackagent.router.provider_config.get_provider_config`` plus a +# ``_ChatRegistration``. 
The map only carries adapter classes for agent +# types that need a custom Python object (ADK has a per-instance +# CustomLLM registration side-effect). AGENT_TYPE_TO_ADAPTER_MAP: Dict[AgentTypeEnum, Type[Agent]] = { AgentTypeEnum.GOOGLE_ADK: ADKAgent, - AgentTypeEnum.LITELLM: LiteLLMAgent, - AgentTypeEnum.OPENAI_SDK: OpenAIAgent, - AgentTypeEnum.OLLAMA: OllamaAgent, - AgentTypeEnum.LANGCHAIN: LiteLLMAgent, # LangChain agents can use LiteLLM adapter - # Add other agent types and their corresponding adapters here } @@ -43,8 +35,9 @@ class AgentRouter: The `AgentRouter` is responsible for initializing an agent, which includes: 1. Resolving organizational context via the storage backend. 2. Ensuring the agent is registered in the storage backend. - 3. Instantiating the appropriate adapter (e.g., `ADKAgent`, `LiteLLMAgent`) - based on the `agent_type`. + 3. Building either an ``ADKAgent`` instance (for the GOOGLE_ADK + type, which needs a per-instance CustomLLM registration) or a + lightweight ``_ChatRegistration`` (for every chat AgentType). 4. Storing this adapter for subsequent request routing. Attributes: @@ -101,10 +94,18 @@ def __init__( f"User ID={self.user_id_str}" ) - if agent_type not in AGENT_TYPE_TO_ADAPTER_MAP: + # Either a chat AgentType (driven by ProviderConfig) or one of + # the explicit adapter classes (currently just ADK). + if ( + get_provider_config(agent_type) is None + and agent_type not in AGENT_TYPE_TO_ADAPTER_MAP + ): + supported = list(AGENT_TYPE_TO_ADAPTER_MAP.keys()) + from hackagent.router.provider_config import PROVIDER_CONFIGS as _PC + + supported.extend(_PC.keys()) raise ValueError( - f"Unsupported agent type: {agent_type}. " - f"Supported types: {list(AGENT_TYPE_TO_ADAPTER_MAP.keys())}" + f"Unsupported agent type: {agent_type}. 
Supported types: {supported}" ) actual_metadata = {k: v for k, v in (metadata or {}).items() if v is not None} @@ -171,7 +172,7 @@ def _configure_and_instantiate_adapter( ValueError: If essential configuration for an adapter type is missing (e.g., model name for LiteLLM) or if adapter instantiation fails. """ - adapter_class = AGENT_TYPE_TO_ADAPTER_MAP[agent_type] + adapter_class = AGENT_TYPE_TO_ADAPTER_MAP.get(agent_type) logger.debug( f"ROUTER_DEBUG: adapter_class is: {adapter_class}, type: {type(adapter_class)}, id: {id(adapter_class)}" @@ -181,9 +182,10 @@ def _configure_and_instantiate_adapter( adapter_operational_config.copy() if adapter_operational_config else {} ) - # Every adapter now subclasses LiteLLMAgent, so the same set of - # config fields applies (with ADK adding a required user_id). - # ``name`` is the model string, ``endpoint`` is the API base URL. + # ``_ChatRegistration`` for chat AgentTypes and ``ADKAgent`` for + # ADK take the same config shape (with ADK adding a required + # user_id). ``name`` is the model string, ``endpoint`` is the + # API base URL. if "name" not in adapter_instance_config: metadata = self.backend_agent.metadata if isinstance(metadata, dict) and "name" in metadata: @@ -261,6 +263,15 @@ def _configure_and_instantiate_adapter( adapter_instance = adapter_class( id=registration_key, config=adapter_instance_config ) + adapter_label = ( + provider_config.adapter_label + if provider_config is not None + else ( + adapter_class.__name__ + if adapter_class + else type(adapter_instance).__name__ + ) + ) logger.debug( f"ROUTER_DEBUG: Resulting instance: {adapter_instance}, type: {type(adapter_instance)}" ) @@ -268,16 +279,21 @@ def _configure_and_instantiate_adapter( self._agent_types[registration_key] = agent_type logger.info( f"Agent '{name}' (Backend ID: {registration_key}, Type: {agent_type.value}) " - f"successfully initialized and registered with adapter {adapter_class.__name__}. 
" + f"successfully initialized and registered as {adapter_label}. " f"Adapter config keys: {list(adapter_instance_config.keys())}" ) except Exception as e: + adapter_label_for_error = ( + provider_config.adapter_label + if provider_config is not None + else (adapter_class.__name__ if adapter_class else "adapter") + ) logger.error( f"Failed to instantiate adapter for agent '{name}' (Backend ID: {registration_key}): {e}", exc_info=True, ) raise ValueError( - f"Failed to instantiate adapter {adapter_class.__name__}: {e}" + f"Failed to instantiate adapter {adapter_label_for_error}: {e}" ) from e def get_agent_instance(self, registration_key: str) -> Optional[Agent]: diff --git a/tests/integration/adapters/test_litellm.py b/tests/integration/adapters/test_litellm.py deleted file mode 100644 index 7a18a0c9..00000000 --- a/tests/integration/adapters/test_litellm.py +++ /dev/null @@ -1,509 +0,0 @@ -# Copyright 2026 - AI4I. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Integration tests for LiteLLM adapter. - -These tests verify end-to-end functionality with LiteLLM's multi-provider support: -- Adapter initialization with various providers -- Chat completions through different backends (Ollama, OpenAI, etc.) -- Model identifier parsing and routing -- Error handling for unavailable providers -- Full HackAgent integration with LiteLLM - -LiteLLM supports 100+ LLMs via a unified interface: -- ollama/tinyllama - Ollama local models -- openai/gpt-4 - OpenAI models -- anthropic/claude-3 - Anthropic models -- And many more... 
- -Prerequisites: - - At least one supported backend must be available - - For Ollama: Ollama must be running - - For OpenAI: OPENAI_API_KEY must be set - -Run with: - pytest tests/integration/test_litellm_integration.py --run-integration --run-litellm - -Environment Variables: - LITELLM_MODEL: Model identifier (default: ollama/tinyllama) - OLLAMA_BASE_URL: For Ollama-backed models - OPENAI_API_KEY: For OpenAI-backed models -""" - -import logging -from typing import Any, Dict - -import pytest - -logger = logging.getLogger(__name__) - - -@pytest.mark.integration -@pytest.mark.litellm -class TestLiteLLMAdapterIntegration: - """Integration tests for LiteLLMAgent adapter.""" - - def test_adapter_initialization_with_ollama_model( - self, - skip_if_ollama_unavailable, - ollama_base_url: str, - ): - """Test LiteLLM adapter initialization with Ollama model.""" - from hackagent.router.adapters.litellm import LiteLLMAgent - - config = { - "name": "ollama/tinyllama", - "endpoint": ollama_base_url, - "max_tokens": 20, - } - - adapter = LiteLLMAgent(id="test_litellm_ollama", config=config) - - assert adapter.id == "test_litellm_ollama" - assert adapter.model_name == "ollama/tinyllama" - logger.info(f"LiteLLM adapter initialized with Ollama: {adapter.model_name}") - - def test_adapter_initialization_with_openai_model( - self, - skip_if_openai_unavailable, - openai_api_key: str, - ): - """Test LiteLLM adapter initialization with OpenAI model.""" - from hackagent.router.adapters.litellm import LiteLLMAgent - - config = { - "name": "gpt-4o-mini", - "api_key": openai_api_key, - "max_tokens": 20, - } - - adapter = LiteLLMAgent(id="test_litellm_openai", config=config) - - assert adapter.id == "test_litellm_openai" - assert adapter.model_name == "gpt-4o-mini" - logger.info(f"LiteLLM adapter initialized with OpenAI: {adapter.model_name}") - - def test_chat_completion_with_ollama( - self, - skip_if_ollama_unavailable, - ollama_base_url: str, - ): - """Test chat completion through LiteLLM 
with Ollama backend.""" - from hackagent.router.adapters.litellm import LiteLLMAgent - - config = { - "name": "ollama/tinyllama", - "endpoint": ollama_base_url, - "max_tokens": 15, - } - - adapter = LiteLLMAgent(id="test_litellm_chat_ollama", config=config) - - request = { - "messages": [ - {"role": "user", "content": "What is 2 + 2? Answer in one word."} - ], - "max_tokens": 20, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - logger.info(f"LiteLLM/Ollama response: {response['processed_response']}") - - def test_chat_completion_with_openai( - self, - skip_if_openai_unavailable, - openai_api_key: str, - ): - """Test chat completion through LiteLLM with OpenAI backend.""" - from hackagent.router.adapters.litellm import LiteLLMAgent - - config = { - "name": "gpt-4o-mini", - "api_key": openai_api_key, - "max_tokens": 15, - } - - adapter = LiteLLMAgent(id="test_litellm_chat_openai", config=config) - - request = { - "messages": [ - {"role": "user", "content": "What is 2 + 2? 
Answer in one word."} - ], - "max_tokens": 20, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - logger.info(f"LiteLLM/OpenAI response: {response['processed_response']}") - - def test_generation_with_custom_parameters( - self, - skip_if_litellm_unavailable, - litellm_config: Dict[str, Any], - ): - """Test generation with custom temperature and parameters.""" - from hackagent.router.adapters.litellm import LiteLLMAgent - - adapter = LiteLLMAgent(id="test_litellm_params", config=litellm_config) - - request = { - "messages": [ - {"role": "user", "content": "Generate a creative one-word response."} - ], - "max_tokens": 20, - "temperature": 1.2, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - logger.info( - f"LiteLLM response with custom temp: {response['processed_response']}" - ) - - def test_multi_turn_conversation( - self, - skip_if_litellm_unavailable, - litellm_config: Dict[str, Any], - ): - """Test multi-turn conversation handling.""" - from hackagent.router.adapters.litellm import LiteLLMAgent - - adapter = LiteLLMAgent(id="test_litellm_multi", config=litellm_config) - - request = { - "messages": [ - {"role": "user", "content": "My name is Bob."}, - {"role": "assistant", "content": "Nice to meet you, Bob!"}, - {"role": "user", "content": "What is my name?"}, - ], - "max_tokens": 30, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - # Should remember context - assert ( - "Bob" in response["processed_response"] - or "bob" in response["processed_response"].lower() - ) - logger.info(f"LiteLLM multi-turn response: {response['processed_response']}") - - def test_system_message_handling( - self, - skip_if_litellm_unavailable, - litellm_config: Dict[str, Any], - ): - """Test system message handling with LiteLLM.""" - from 
hackagent.router.adapters.litellm import LiteLLMAgent - - adapter = LiteLLMAgent(id="test_litellm_system", config=litellm_config) - - request = { - "messages": [ - { - "role": "system", - "content": "You are a helpful math tutor. Be brief.", - }, - {"role": "user", "content": "What is the square root of 16?"}, - ], - "max_tokens": 30, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - logger.info(f"LiteLLM system msg response: {response['processed_response']}") - - -@pytest.mark.integration -@pytest.mark.litellm -@pytest.mark.hackagent_backend -class TestLiteLLMHackAgentIntegration: - """End-to-end tests for HackAgent with LiteLLM backend.""" - - def test_hackagent_with_litellm_initialization( - self, - skip_if_litellm_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - litellm_model: str, - ollama_base_url: str, - ): - """Test HackAgent initialization with LiteLLM agent type.""" - from hackagent import AgentTypeEnum - - # Determine endpoint based on model - endpoint = ( - ollama_base_url - if litellm_model.startswith("ollama/") - else "https://api.openai.com/v1" - ) - - agent = hackagent_client_factory( - name=litellm_model, - endpoint=endpoint, - agent_type=AgentTypeEnum.LITELLM, - ) - - assert agent is not None - assert agent.router is not None - logger.info(f"HackAgent initialized with LiteLLM: {agent.router.backend_agent}") - - def test_hackagent_litellm_baseline_attack( - self, - skip_if_litellm_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - litellm_model: str, - ollama_base_url: str, - basic_attack_config: Dict[str, Any], - ): - """Test running a baseline attack against LiteLLM agent.""" - from hackagent import AgentTypeEnum - - endpoint = ( - ollama_base_url - if litellm_model.startswith("ollama/") - else "https://api.openai.com/v1" - ) - - agent = hackagent_client_factory( - name=litellm_model, - endpoint=endpoint, - 
agent_type=AgentTypeEnum.LITELLM, - ) - - logger.info("Starting baseline attack against LiteLLM agent...") - results = agent.hack(attack_config=basic_attack_config) - - assert results is not None - logger.info(f"Baseline attack completed: {results}") - - @pytest.mark.slow - def test_hackagent_litellm_advprefix_attack( - self, - skip_if_litellm_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - litellm_model: str, - ollama_base_url: str, - advprefix_attack_config: Dict[str, Any], - ): - """Test running an advprefix attack against LiteLLM agent.""" - from hackagent import AgentTypeEnum - - endpoint = ( - ollama_base_url - if litellm_model.startswith("ollama/") - else "https://api.openai.com/v1" - ) - - agent = hackagent_client_factory( - name=litellm_model, - endpoint=endpoint, - agent_type=AgentTypeEnum.LITELLM, - ) - - logger.info("Starting advprefix attack against LiteLLM agent...") - results = agent.hack(attack_config=advprefix_attack_config) - - assert results is not None - logger.info(f"Advprefix attack completed: {results}") - - -@pytest.mark.integration -@pytest.mark.litellm -@pytest.mark.hackagent_backend -class TestLiteLLMRouterIntegration: - """Integration tests for AgentRouter with LiteLLM.""" - - def test_router_creates_litellm_adapter( - self, - skip_if_litellm_unavailable, - skip_if_no_hackagent_key, - hackagent_api_base_url: str, - hackagent_api_key: str, - litellm_model: str, - ollama_base_url: str, - ): - """Test that AgentRouter correctly creates the LITELLM registration.""" - from hackagent.server.client import AuthenticatedClient - from hackagent.server.storage.remote import RemoteBackend - from hackagent.router.router import AgentRouter - from hackagent.router.types import AgentTypeEnum - from hackagent.router._chat_registration import _ChatRegistration - - client = AuthenticatedClient( - base_url=hackagent_api_base_url, - token=hackagent_api_key, - prefix="Bearer", - ) - backend = RemoteBackend(client) - - endpoint = ( - 
ollama_base_url - if litellm_model.startswith("ollama/") - else "https://api.openai.com/v1" - ) - - router = AgentRouter( - backend=backend, - name=litellm_model, - agent_type=AgentTypeEnum.LITELLM, - endpoint=endpoint, - ) - - # Since #379 Phase E.2b the router stores a ``_ChatRegistration`` - # for chat AgentTypes; the adapter classes are no longer - # instantiated. - agent_id = str(router.backend_agent.id) - registration = router.get_agent_instance(registration_key=agent_id) - assert isinstance(registration, _ChatRegistration) - assert registration.ADAPTER_TYPE == "LiteLLMAgent" - logger.info(f"Router created LiteLLM registration: {registration.id}") - - def test_router_handles_litellm_request( - self, - skip_if_litellm_unavailable, - skip_if_no_hackagent_key, - hackagent_api_base_url: str, - hackagent_api_key: str, - litellm_model: str, - ollama_base_url: str, - ): - """Test that router can handle requests through LiteLLM adapter.""" - from hackagent.server.client import AuthenticatedClient - from hackagent.server.storage.remote import RemoteBackend - from hackagent.router.router import AgentRouter - from hackagent.router.types import AgentTypeEnum - - client = AuthenticatedClient( - base_url=hackagent_api_base_url, - token=hackagent_api_key, - prefix="Bearer", - ) - backend = RemoteBackend(client) - - endpoint = ( - ollama_base_url - if litellm_model.startswith("ollama/") - else "https://api.openai.com/v1" - ) - - router = AgentRouter( - backend=backend, - name=litellm_model, - agent_type=AgentTypeEnum.LITELLM, - endpoint=endpoint, - ) - - # Route a request - agent_id = str(router.backend_agent.id) - request_data = { - "messages": [{"role": "user", "content": "Say hello in one word!"}], - "max_tokens": 10, - } - - response = router.route_request( - registration_key=agent_id, request_data=request_data - ) - - assert response is not None - assert "processed_response" in response - logger.info(f"Router LiteLLM response: {response['processed_response']}") - - 
-@pytest.mark.integration -@pytest.mark.litellm -class TestLiteLLMProviderSwitching: - """Test LiteLLM's ability to switch between different providers.""" - - def test_switch_between_ollama_and_openai( - self, - skip_if_ollama_unavailable, - skip_if_openai_unavailable, - ollama_base_url: str, - openai_api_key: str, - ): - """Test using LiteLLM to switch between Ollama and OpenAI.""" - from hackagent.router.adapters.litellm import LiteLLMAgent - - # First with Ollama - ollama_config = { - "name": "ollama/tinyllama", - "endpoint": ollama_base_url, - "max_tokens": 30, - } - ollama_adapter = LiteLLMAgent(id="test_switch_ollama", config=ollama_config) - - ollama_response = ollama_adapter.handle_request( - { - "messages": [{"role": "user", "content": "Say 'Ollama here' briefly."}], - } - ) - - assert ollama_response is not None - logger.info(f"Ollama via LiteLLM: {ollama_response['processed_response']}") - - # Then with OpenAI - openai_config = { - "name": "gpt-4o-mini", - "api_key": openai_api_key, - "max_tokens": 30, - } - openai_adapter = LiteLLMAgent(id="test_switch_openai", config=openai_config) - - openai_response = openai_adapter.handle_request( - { - "messages": [{"role": "user", "content": "Say 'OpenAI here' briefly."}], - } - ) - - assert openai_response is not None - logger.info(f"OpenAI via LiteLLM: {openai_response['processed_response']}") - - def test_model_identifier_formats( - self, - skip_if_ollama_unavailable, - ollama_base_url: str, - ): - """Test various model identifier formats supported by LiteLLM.""" - from hackagent.router.adapters.litellm import LiteLLMAgent - - # Test different Ollama model identifier formats - model_formats = [ - "ollama/tinyllama", - "ollama_chat/tinyllama", # Chat-specific endpoint - ] - - for model_name in model_formats: - try: - config = { - "name": model_name, - "endpoint": ollama_base_url, - "max_tokens": 20, - } - adapter = LiteLLMAgent(id=f"test_format_{model_name}", config=config) - - response = adapter.handle_request( 
- { - "messages": [{"role": "user", "content": "Hi"}], - } - ) - - logger.info( - f"Model {model_name}: {response.get('response', 'OK')[:30]}" - ) - except Exception as e: - logger.warning(f"Model {model_name} failed: {e}") diff --git a/tests/integration/adapters/test_ollama.py b/tests/integration/adapters/test_ollama.py deleted file mode 100644 index d9e8e5a8..00000000 --- a/tests/integration/adapters/test_ollama.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright 2026 - AI4I. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Integration tests for Ollama adapter. - -These tests verify end-to-end functionality with a real Ollama instance: -- Adapter initialization and configuration -- Text generation via the generate endpoint -- Chat completions via the chat endpoint -- Model information retrieval -- Error handling for unavailable models -- Full HackAgent integration with Ollama - -Prerequisites: - - Ollama must be running (default: http://localhost:11434) - - At least one model must be available (default: tinyllama) - -Run with: - pytest tests/integration/test_ollama_integration.py --run-integration --run-ollama - -Environment Variables: - OLLAMA_BASE_URL: Ollama API base URL (default: http://localhost:11434) - OLLAMA_MODEL: Model to use for tests (default: tinyllama) -""" - -import logging -from typing import Any, Dict - -import pytest - -logger = logging.getLogger(__name__) - - -@pytest.mark.integration -@pytest.mark.ollama -class TestOllamaAdapterIntegration: - """Integration tests for OllamaAgent adapter.""" - - def test_adapter_initialization( - self, - skip_if_ollama_unavailable, - ollama_config: Dict[str, Any], - ): - """Test that OllamaAgent initializes correctly with real endpoint.""" - from hackagent.router.adapters.ollama import OllamaAgent - - adapter = OllamaAgent(id="test_ollama_init", config=ollama_config) - - assert adapter.id == "test_ollama_init" - assert adapter.model_name == ollama_config["name"] - assert adapter.api_base_url is 
not None - logger.info(f"Ollama adapter initialized: model={adapter.model_name}") - - def test_list_available_models( - self, - skip_if_ollama_unavailable, - ollama_config: Dict[str, Any], - ): - """Test listing available models from Ollama.""" - from hackagent.router.adapters.ollama import OllamaAgent - - adapter = OllamaAgent(id="test_ollama_models", config=ollama_config) - models = adapter.list_models() - - assert models is not None - assert isinstance(models, list) - logger.info(f"Available Ollama models: {[m.get('name') for m in models]}") - - def test_generate_completion( - self, - skip_if_ollama_unavailable, - ollama_config: Dict[str, Any], - ): - """Test generating text completion with Ollama.""" - from hackagent.router.adapters.ollama import OllamaAgent - - adapter = OllamaAgent(id="test_ollama_generate", config=ollama_config) - - request = { - "prompt": "What is 2 + 2? Answer briefly.", - "max_tokens": 15, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - assert len(response["processed_response"]) > 0 - logger.info(f"Ollama generate response: {response['processed_response'][:100]}") - - def test_chat_completion( - self, - skip_if_ollama_unavailable, - ollama_config: Dict[str, Any], - ): - """Test chat completion with Ollama.""" - from hackagent.router.adapters.ollama import OllamaAgent - - adapter = OllamaAgent(id="test_ollama_chat", config=ollama_config) - - request = { - "messages": [ - {"role": "user", "content": "Hello, how are you? 
Answer briefly."} - ], - "max_tokens": 15, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - assert len(response["processed_response"]) > 0 - logger.info(f"Ollama chat response: {response['processed_response'][:100]}") - - def test_generation_with_custom_parameters( - self, - skip_if_ollama_unavailable, - ollama_config: Dict[str, Any], - ): - """Test generation with custom temperature and other parameters.""" - from hackagent.router.adapters.ollama import OllamaAgent - - adapter = OllamaAgent(id="test_ollama_params", config=ollama_config) - - request = { - "prompt": "Generate a random word.", - "max_tokens": 20, - "temperature": 1.5, # Higher temperature for more randomness - "top_p": 0.9, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - logger.info( - f"Ollama response with custom params: {response['processed_response']}" - ) - - def test_get_model_info( - self, - skip_if_ollama_unavailable, - ollama_config: Dict[str, Any], - ): - """Test retrieving model information from Ollama.""" - from hackagent.router.adapters.ollama import OllamaAgent - - adapter = OllamaAgent(id="test_ollama_info", config=ollama_config) - - try: - model_info = adapter.get_model_info() - assert model_info is not None - logger.info(f"Model info: {model_info}") - except Exception as e: - # Model info may not be available for all models - logger.warning(f"Could not get model info: {e}") - - def test_invalid_model_error_handling( - self, - skip_if_ollama_unavailable, - ollama_base_url: str, - ): - """Test error handling when using a non-existent model.""" - from hackagent.router.adapters.ollama import OllamaAgent - - config = { - "name": "nonexistent_model_xyz_12345", - "endpoint": ollama_base_url, - } - - adapter = OllamaAgent(id="test_ollama_invalid", config=config) - - # The adapter returns an error response instead of raising an exception 
- response = adapter.handle_request({"prompt": "test"}) - assert response is not None - assert ( - response.get("error_message") is not None - or response.get("status_code", 200) >= 400 - ) - logger.info(f"Error response as expected: {response.get('error_message')}") - - -@pytest.mark.integration -@pytest.mark.ollama -@pytest.mark.hackagent_backend -class TestOllamaHackAgentIntegration: - """End-to-end tests for HackAgent with Ollama backend.""" - - def test_hackagent_with_ollama_initialization( - self, - skip_if_ollama_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - ollama_base_url: str, - ollama_model: str, - ): - """Test HackAgent initialization with Ollama agent type.""" - from hackagent import AgentTypeEnum - - agent = hackagent_client_factory( - name=ollama_model, - endpoint=ollama_base_url, - agent_type=AgentTypeEnum.OLLAMA, - ) - - assert agent is not None - assert agent.router is not None - logger.info(f"HackAgent initialized with Ollama: {agent.router.backend_agent}") - - def test_hackagent_ollama_baseline_attack( - self, - skip_if_ollama_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - ollama_base_url: str, - ollama_model: str, - basic_attack_config: Dict[str, Any], - ): - """Test running a baseline attack against Ollama agent.""" - from hackagent import AgentTypeEnum - - agent = hackagent_client_factory( - name=ollama_model, - endpoint=ollama_base_url, - agent_type=AgentTypeEnum.OLLAMA, - ) - - logger.info("Starting baseline attack against Ollama agent...") - results = agent.hack(attack_config=basic_attack_config) - - assert results is not None - logger.info(f"Baseline attack completed: {results}") - - @pytest.mark.slow - def test_hackagent_ollama_advprefix_attack( - self, - skip_if_ollama_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - ollama_base_url: str, - ollama_model: str, - advprefix_attack_config: Dict[str, Any], - ): - """Test running an advprefix attack against Ollama 
agent.""" - from hackagent import AgentTypeEnum - - agent = hackagent_client_factory( - name=ollama_model, - endpoint=ollama_base_url, - agent_type=AgentTypeEnum.OLLAMA, - ) - - logger.info("Starting advprefix attack against Ollama agent...") - results = agent.hack(attack_config=advprefix_attack_config) - - assert results is not None - logger.info(f"Advprefix attack completed: {results}") - - -@pytest.mark.integration -@pytest.mark.ollama -@pytest.mark.hackagent_backend -class TestOllamaRouterIntegration: - """Integration tests for AgentRouter with Ollama.""" - - def test_router_creates_ollama_adapter( - self, - skip_if_ollama_unavailable, - skip_if_no_hackagent_key, - hackagent_api_base_url: str, - hackagent_api_key: str, - ollama_base_url: str, - ollama_model: str, - ): - """Test that AgentRouter correctly creates OllamaAgent adapter.""" - from hackagent.server.client import AuthenticatedClient - from hackagent.server.storage.remote import RemoteBackend - from hackagent.router.router import AgentRouter - from hackagent.router.types import AgentTypeEnum - from hackagent.router.adapters.ollama import OllamaAgent - - client = AuthenticatedClient( - base_url=hackagent_api_base_url, - token=hackagent_api_key, - prefix="Bearer", - ) - backend = RemoteBackend(client) - - router = AgentRouter( - backend=backend, - name=ollama_model, - agent_type=AgentTypeEnum.OLLAMA, - endpoint=ollama_base_url, - ) - - # Verify adapter was created - agent_id = str(router.backend_agent.id) - adapter = router.get_agent_instance(registration_key=agent_id) - - assert isinstance(adapter, OllamaAgent) - logger.info(f"Router created Ollama adapter: {adapter.id}") - - def test_router_handles_ollama_request( - self, - skip_if_ollama_unavailable, - skip_if_no_hackagent_key, - hackagent_api_base_url: str, - hackagent_api_key: str, - ollama_base_url: str, - ollama_model: str, - ): - """Test that router can handle requests through Ollama adapter.""" - from hackagent.server.client import 
AuthenticatedClient - from hackagent.server.storage.remote import RemoteBackend - from hackagent.router.router import AgentRouter - from hackagent.router.types import AgentTypeEnum - - client = AuthenticatedClient( - base_url=hackagent_api_base_url, - token=hackagent_api_key, - prefix="Bearer", - ) - backend = RemoteBackend(client) - - router = AgentRouter( - backend=backend, - name=ollama_model, - agent_type=AgentTypeEnum.OLLAMA, - endpoint=ollama_base_url, - ) - - # Route a request - agent_id = str(router.backend_agent.id) - request_data = { - "prompt": "Say hello!", - "max_tokens": 20, - } - - response = router.route_request( - registration_key=agent_id, request_data=request_data - ) - - assert response is not None - assert "processed_response" in response - logger.info(f"Router Ollama response: {response['processed_response'][:50]}") diff --git a/tests/integration/adapters/test_openai.py b/tests/integration/adapters/test_openai.py deleted file mode 100644 index 164e6d81..00000000 --- a/tests/integration/adapters/test_openai.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright 2026 - AI4I. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Integration tests for OpenAI SDK adapter. 
- -These tests verify end-to-end functionality with OpenAI-compatible APIs: -- Adapter initialization and configuration -- Chat completions with various parameters -- Function calling / tool use capabilities -- Streaming responses (if applicable) -- Error handling for rate limits and invalid requests -- Full HackAgent integration with OpenAI - -Supports both direct OpenAI API and OpenRouter: -- Set OPENAI_API_KEY for direct OpenAI access -- Set OPENROUTER_API_KEY for OpenRouter access (recommended for CI/CD) - -Prerequisites: - - Valid API key (OPENROUTER_API_KEY or OPENAI_API_KEY) - - Sufficient API quota - -Run with: - pytest tests/integration/test_openai_integration.py --run-integration --run-openai - -Environment Variables: - OPENROUTER_API_KEY: OpenRouter API key (preferred for CI/CD) - OPENROUTER_MODEL: OpenRouter model (default: openai/gpt-4o-mini) - OPENAI_API_KEY: OpenAI API key (fallback) - OPENAI_MODEL: Model to use for tests (default: gpt-4o-mini) -""" - -import logging -from typing import Any, Dict - -import pytest - -logger = logging.getLogger(__name__) - - -@pytest.mark.integration -@pytest.mark.openai_sdk -class TestOpenAIAdapterIntegration: - """Integration tests for OpenAIAgent adapter.""" - - def test_adapter_initialization( - self, - skip_if_openai_unavailable, - openai_config: Dict[str, Any], - ): - """Test that OpenAIAgent initializes correctly with real API key.""" - from hackagent.router.adapters.openai import OpenAIAgent - - adapter = OpenAIAgent(id="test_openai_init", config=openai_config) - - assert adapter.id == "test_openai_init" - assert adapter.model_name == openai_config["name"] - # Since #379 the OpenAI adapter routes through LiteLLM, so the - # model string carries an `openai/` provider prefix and there is - # no longer a raw OpenAI SDK client to inspect. 
- assert adapter.litellm_model.endswith(openai_config["name"]) - logger.info(f"OpenAI adapter initialized: model={adapter.model_name}") - - def test_chat_completion( - self, - skip_if_openai_unavailable, - openai_config: Dict[str, Any], - ): - """Test chat completion with OpenAI.""" - from hackagent.router.adapters.openai import OpenAIAgent - - adapter = OpenAIAgent(id="test_openai_chat", config=openai_config) - - request = { - "messages": [ - {"role": "system", "content": "You are a helpful assistant. Be brief."}, - {"role": "user", "content": "What is 2 + 2?"}, - ], - "max_tokens": 50, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - assert len(response["processed_response"]) > 0 - logger.info(f"OpenAI chat response: {response['processed_response']}") - - def test_chat_completion_with_custom_temperature( - self, - skip_if_openai_unavailable, - openai_config: Dict[str, Any], - ): - """Test chat completion with custom temperature.""" - from hackagent.router.adapters.openai import OpenAIAgent - - adapter = OpenAIAgent(id="test_openai_temp", config=openai_config) - - request = { - "messages": [ - {"role": "user", "content": "Generate a creative one-word response."} - ], - "max_tokens": 20, - "temperature": 1.5, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - logger.info(f"OpenAI response with high temp: {response['processed_response']}") - - def test_chat_with_system_message( - self, - skip_if_openai_unavailable, - openai_config: Dict[str, Any], - ): - """Test chat completion with system message context.""" - from hackagent.router.adapters.openai import OpenAIAgent - - adapter = OpenAIAgent(id="test_openai_system", config=openai_config) - - request = { - "messages": [ - { - "role": "system", - "content": "You are a pirate. 
Respond in pirate speak.", - }, - {"role": "user", "content": "Hello!"}, - ], - "max_tokens": 50, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - logger.info(f"OpenAI pirate response: {response['processed_response']}") - - def test_multi_turn_conversation( - self, - skip_if_openai_unavailable, - openai_config: Dict[str, Any], - ): - """Test multi-turn conversation handling.""" - from hackagent.router.adapters.openai import OpenAIAgent - - adapter = OpenAIAgent(id="test_openai_multi", config=openai_config) - - request = { - "messages": [ - {"role": "user", "content": "My name is Alice."}, - {"role": "assistant", "content": "Nice to meet you, Alice!"}, - {"role": "user", "content": "What is my name?"}, - ], - "max_tokens": 30, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - # The model should remember the name from context - assert ( - "Alice" in response["processed_response"] - or "alice" in response["processed_response"].lower() - ) - logger.info(f"OpenAI multi-turn response: {response['processed_response']}") - - def test_function_calling( - self, - skip_if_openai_unavailable, - openai_config: Dict[str, Any], - ): - """Test function calling / tool use capability.""" - from hackagent.router.adapters.openai import OpenAIAgent - - config_with_tools = openai_config.copy() - config_with_tools["tools"] = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather in a location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. 
San Francisco, CA", - } - }, - "required": ["location"], - }, - }, - } - ] - config_with_tools["tool_choice"] = "auto" - - adapter = OpenAIAgent(id="test_openai_tools", config=config_with_tools) - - request = { - "messages": [ - {"role": "user", "content": "What's the weather like in Boston?"} - ], - "max_tokens": 100, - } - - response = adapter.handle_request(request) - - assert response is not None - # Response might include tool calls or direct response - logger.info(f"OpenAI function call response: {response}") - - def test_invalid_api_key_error_handling(self): - """Test error handling with invalid API key.""" - from hackagent.router.adapters.openai import OpenAIAgent - - config = { - "name": "gpt-4o-mini", - "api_key": "invalid-api-key-12345", - } - - adapter = OpenAIAgent(id="test_openai_invalid", config=config) - - # The adapter returns an error response instead of raising an exception - response = adapter.handle_request( - {"messages": [{"role": "user", "content": "test"}]} - ) - assert response is not None - assert ( - response.get("error_message") is not None - or response.get("status_code", 200) >= 400 - ) - logger.info(f"Error response as expected: {response.get('error_message')}") - - -@pytest.mark.integration -@pytest.mark.openai_sdk -@pytest.mark.hackagent_backend -class TestOpenAIHackAgentIntegration: - """End-to-end tests for HackAgent with OpenAI backend.""" - - def test_hackagent_with_openai_initialization( - self, - skip_if_openai_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - openai_model: str, - openai_base_url: str, - ): - """Test HackAgent initialization with OpenAI agent type.""" - from hackagent import AgentTypeEnum - - agent = hackagent_client_factory( - name=openai_model, - endpoint=openai_base_url, - agent_type=AgentTypeEnum.OPENAI_SDK, - ) - - assert agent is not None - assert agent.router is not None - logger.info(f"HackAgent initialized with OpenAI: {agent.router.backend_agent}") - - def 
test_hackagent_openai_baseline_attack( - self, - skip_if_openai_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - openai_model: str, - openai_base_url: str, - basic_attack_config: Dict[str, Any], - ): - """Test running a baseline attack against OpenAI agent.""" - from hackagent import AgentTypeEnum - - agent = hackagent_client_factory( - name=openai_model, - endpoint=openai_base_url, - agent_type=AgentTypeEnum.OPENAI_SDK, - ) - - logger.info("Starting baseline attack against OpenAI agent...") - results = agent.hack(attack_config=basic_attack_config) - - assert results is not None - logger.info(f"Baseline attack completed: {results}") - - @pytest.mark.slow - def test_hackagent_openai_advprefix_attack( - self, - skip_if_openai_unavailable, - skip_if_no_hackagent_key, - hackagent_client_factory, - openai_model: str, - openai_base_url: str, - advprefix_attack_config: Dict[str, Any], - ): - """Test running an advprefix attack against OpenAI agent.""" - from hackagent import AgentTypeEnum - - agent = hackagent_client_factory( - name=openai_model, - endpoint=openai_base_url, - agent_type=AgentTypeEnum.OPENAI_SDK, - ) - - logger.info("Starting advprefix attack against OpenAI agent...") - results = agent.hack(attack_config=advprefix_attack_config) - - assert results is not None - logger.info(f"Advprefix attack completed: {results}") - - -@pytest.mark.integration -@pytest.mark.openai_sdk -@pytest.mark.hackagent_backend -class TestOpenAIRouterIntegration: - """Integration tests for AgentRouter with OpenAI.""" - - def test_router_creates_openai_adapter( - self, - skip_if_openai_unavailable, - skip_if_no_hackagent_key, - hackagent_api_base_url: str, - hackagent_api_key: str, - openai_model: str, - openai_base_url: str, - ): - """Test that AgentRouter correctly creates the OpenAI registration.""" - from hackagent.server.client import AuthenticatedClient - from hackagent.router.router import AgentRouter - from hackagent.router.types import AgentTypeEnum - 
from hackagent.router._chat_registration import _ChatRegistration - - client = AuthenticatedClient( - base_url=hackagent_api_base_url, - token=hackagent_api_key, - prefix="Bearer", - ) - from hackagent.server.storage.remote import RemoteBackend - - backend = RemoteBackend(client) - - router = AgentRouter( - backend=backend, - name=openai_model, - agent_type=AgentTypeEnum.OPENAI_SDK, - endpoint=openai_base_url, - ) - - # Since #379 Phase E.2b the router stores a ``_ChatRegistration`` - # for chat AgentTypes; the adapter classes are no longer - # instantiated. - agent_id = str(router.backend_agent.id) - registration = router.get_agent_instance(registration_key=agent_id) - assert isinstance(registration, _ChatRegistration) - assert registration.ADAPTER_TYPE == "OpenAIAgent" - logger.info(f"Router created OpenAI registration: {registration.id}") - - def test_router_handles_openai_request( - self, - skip_if_openai_unavailable, - skip_if_no_hackagent_key, - hackagent_api_base_url: str, - hackagent_api_key: str, - openai_model: str, - openai_base_url: str, - ): - """Test that router can handle requests through OpenAI adapter.""" - from hackagent.server.client import AuthenticatedClient - from hackagent.router.router import AgentRouter - from hackagent.router.types import AgentTypeEnum - - client = AuthenticatedClient( - base_url=hackagent_api_base_url, - token=hackagent_api_key, - prefix="Bearer", - ) - from hackagent.server.storage.remote import RemoteBackend - - backend = RemoteBackend(client) - - router = AgentRouter( - backend=backend, - name=openai_model, - agent_type=AgentTypeEnum.OPENAI_SDK, - endpoint=openai_base_url, - ) - - # Route a request - agent_id = str(router.backend_agent.id) - request_data = { - "messages": [{"role": "user", "content": "Say hello in one word!"}], - "max_tokens": 10, - } - - response = router.route_request( - registration_key=agent_id, request_data=request_data - ) - - assert response is not None - assert "processed_response" in response 
- logger.info(f"Router OpenAI response: {response['processed_response']}") - - -@pytest.mark.integration -@pytest.mark.openai_sdk -class TestOpenAICompatibleEndpoints: - """Test OpenAI adapter with OpenAI-compatible endpoints (e.g., OpenRouter, local servers).""" - - def test_custom_endpoint_initialization( - self, - skip_if_openai_unavailable, - openai_api_key: str, - openai_base_url: str, - openai_model: str, - ): - """Test initializing with a custom OpenAI-compatible endpoint.""" - from hackagent.router.adapters.openai import OpenAIAgent - - # This tests the adapter's ability to use custom endpoints - # In practice, this could be OpenRouter or a local LLM server - config = { - "name": openai_model, - "api_key": openai_api_key, - "endpoint": openai_base_url, - } - - adapter = OpenAIAgent(id="test_custom_endpoint", config=config) - - assert adapter.api_base_url == openai_base_url - logger.info(f"Custom endpoint adapter initialized: {adapter.api_base_url}") - - def test_openrouter_endpoint_chat_completion( - self, - skip_if_openai_unavailable, - openai_config: Dict[str, Any], - using_openrouter: bool, - ): - """Test chat completion through OpenRouter (if configured).""" - from hackagent.router.adapters.openai import OpenAIAgent - - if not using_openrouter: - pytest.skip("Test only runs when OPENROUTER_API_KEY is configured") - - adapter = OpenAIAgent(id="test_openrouter_chat", config=openai_config) - - request = { - "messages": [ - {"role": "user", "content": "Say 'OpenRouter works!' briefly."} - ], - "max_tokens": 30, - } - - response = adapter.handle_request(request) - - assert response is not None - assert "processed_response" in response - logger.info(f"OpenRouter chat response: {response['processed_response']}") diff --git a/tests/unit/adapters/test_litellm.py b/tests/unit/adapters/test_litellm.py deleted file mode 100644 index 2dbfb1a5..00000000 --- a/tests/unit/adapters/test_litellm.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright 2026 - AI4I. 
All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - - -import logging -import os -import unittest -from unittest.mock import MagicMock, patch - -import litellm # Required for litellm.exceptions - -from hackagent.router.adapters.litellm import ( - LiteLLMAgent, - LiteLLMConfigurationError, -) - -# Disable logging for tests -logging.disable(logging.CRITICAL) - - -class TestLiteLLMAgentInit(unittest.TestCase): - def test_init_success_minimal_config(self): - adapter_id = "litellm_test_001" - config = { - "name": "ollama/llama2" # Model string - } - try: - adapter = LiteLLMAgent(id=adapter_id, config=config) - self.assertEqual(adapter.id, adapter_id) - self.assertEqual(adapter.model_name, config["name"]) - self.assertIsNone(adapter.api_base_url) - self.assertIsNone(adapter.actual_api_key) - self.assertEqual(adapter.default_max_tokens, 100) - self.assertEqual(adapter.default_temperature, 0.8) - self.assertEqual(adapter.default_top_p, 0.95) - except LiteLLMConfigurationError: - self.fail("LiteLLMAgent initialization failed with minimal valid config.") - - def test_init_success_full_config_no_api_key_env(self): - adapter_id = "litellm_test_002" - config = { - "name": "gpt-3.5-turbo", - "endpoint": "https://api.openai.com/v1", - "api_key": "OPENAI_API_KEY_ENV_VAR_NAME", # Env var name - "max_tokens": 200, - "temperature": 0.7, - "top_p": 0.9, - } - with patch.dict(os.environ, {}, clear=True): # Ensure env var is not set - adapter = LiteLLMAgent(id=adapter_id, config=config) - self.assertEqual(adapter.model_name, config["name"]) - self.assertEqual(adapter.api_base_url, config["endpoint"]) - # When env var is not found, the adapter uses the string itself as the key - self.assertEqual(adapter.actual_api_key, "OPENAI_API_KEY_ENV_VAR_NAME") - self.assertEqual(adapter.default_max_tokens, config["max_tokens"]) - self.assertEqual(adapter.default_temperature, config["temperature"]) - self.assertEqual(adapter.default_top_p, config["top_p"]) - - @patch.dict(os.environ, 
{"MY_LLM_API_KEY": "actual_key_from_env"}) - def test_init_success_with_api_key_from_env(self): - adapter_id = "litellm_test_003" - config = { - "name": "claude-2", - "api_key": "MY_LLM_API_KEY", # Env var name - } - adapter = LiteLLMAgent(id=adapter_id, config=config) - self.assertEqual(adapter.actual_api_key, "actual_key_from_env") - - def test_init_missing_name_raises_error(self): - with self.assertRaisesRegex( - LiteLLMConfigurationError, "Missing required configuration key 'name'" - ): - LiteLLMAgent(id="err_litellm_1", config={}) - - def test_init_config_without_api_key_field(self): - # Should not try to get from env if 'api_key' field itself is missing in config - adapter_id = "litellm_test_004" - config = {"name": "some-model"} - with patch.object( - os.environ, "get" - ) as mock_os_environ_get: # More specific patch - adapter = LiteLLMAgent(id=adapter_id, config=config) - self.assertIsNone(adapter.actual_api_key) - mock_os_environ_get.assert_not_called() - - -class TestLiteLLMAgentHandleRequest(unittest.TestCase): - def setUp(self): - self.adapter_id = "litellm_handle_req_agent" - self.config = { - "name": "test-model", - "endpoint": "http://fake-litellm-api.com", - "max_tokens": 50, - "temperature": 0.5, - "top_p": 0.9, - } - self.adapter = LiteLLMAgent(id=self.adapter_id, config=self.config) - self.prompt = "Hello LiteLLM" - - def test_handle_request_missing_prompt(self): - request_data = {} - response = self.adapter.handle_request(request_data) - self.assertEqual(response["status_code"], 400) - self.assertIn( - "Request data must include either 'messages' or 'prompt' field.", - response["error_message"], - ) - self.assertEqual(response["raw_request"], request_data) - - @patch("litellm.completion") - def test_handle_request_success(self, mock_litellm_completion): - mock_choice = MagicMock() - mock_choice.message = MagicMock() - mock_choice.message.content = " a successful response." 
- mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_litellm_completion.return_value = mock_response - - request_data = {"prompt": self.prompt, "max_tokens": 150} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 200) - self.assertIsNone(response["error_message"]) - self.assertEqual(response["processed_response"], " a successful response.") - self.assertEqual(response["raw_request"], request_data) - self.assertEqual( - response["agent_specific_data"]["model_name"], self.config["name"] - ) - # ChatCompletionsAgent base class normalizes to max_tokens - self.assertEqual( - response["agent_specific_data"]["invoked_parameters"]["max_tokens"], 150 - ) # Overridden - self.assertEqual( - response["agent_specific_data"]["invoked_parameters"]["temperature"], - self.config["temperature"], - ) # Default - - mock_litellm_completion.assert_called_once_with( - model=self.config["name"], - messages=[{"role": "user", "content": self.prompt}], - max_tokens=150, - temperature=self.config["temperature"], - top_p=self.config["top_p"], - api_base=self.config["endpoint"], - custom_llm_provider="openai", - extra_headers={"User-Agent": "HackAgent/0.1.0"}, - ) - - @patch("litellm.completion") - def test_handle_request_litellm_api_error(self, mock_litellm_completion): - # Simulate an API error from LiteLLM (e.g. 
litellm.exceptions.APIError) - mock_litellm_completion.side_effect = litellm.exceptions.APIError( - "LiteLLM API Error from test", # message (positional) - 503, # status_code (positional) - llm_provider="test_provider", # llm_provider (keyword) - model="test_model", # model (keyword) - ) - - request_data = {"prompt": self.prompt} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 500) - # The ChatCompletionsAgent base class formats errors differently - self.assertIn("APIError", response["error_message"]) - self.assertEqual(response["raw_request"], request_data) - - @patch("litellm.completion") - def test_handle_request_unexpected_response_structure_no_choices( - self, mock_litellm_completion - ): - mock_response = MagicMock() - mock_response.choices = [] # Empty choices - mock_litellm_completion.return_value = mock_response - - request_data = {"prompt": self.prompt} - response = self.adapter.handle_request(request_data) - self.assertEqual(response["status_code"], 500) - # The ChatCompletionsAgent base class uses ADAPTER_TYPE in error messages - self.assertIn("generation error", response["error_message"]) - self.assertIn( - "[GENERATION_ERROR: UNEXPECTED_RESPONSE]", response["error_message"] - ) - - @patch("litellm.completion") - def test_handle_request_unexpected_response_structure_no_message_content( - self, mock_litellm_completion - ): - # Create a proper mock that returns None for all reasoning fields - mock_choice = MagicMock() - mock_message = MagicMock(spec=["content"]) # Only spec content attribute - mock_message.content = None # No content - mock_message.configure_mock(reasoning_content=None, reasoning=None) - mock_choice.message = mock_message - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_litellm_completion.return_value = mock_response - - request_data = {"prompt": self.prompt} - response = self.adapter.handle_request(request_data) - - # Implementation returns 500 error when 
content is empty/None - self.assertEqual(response["status_code"], 500) - self.assertIn("generation error", response["error_message"]) - self.assertIn("[GENERATION_ERROR: EMPTY_RESPONSE]", response["error_message"]) - - @patch("litellm.completion") - def test_handle_request_reasoning_model_with_reasoning_field( - self, mock_litellm_completion - ): - """Test that reasoning models (e.g., o1, kimi-k2-thinking) work correctly.""" - mock_choice = MagicMock() - mock_choice.message = MagicMock() - mock_choice.message.content = "" # Empty content (typical for reasoning models) - mock_choice.message.reasoning_content = ( - "This is the reasoning output from the model" - ) - mock_choice.message.reasoning = None - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_litellm_completion.return_value = mock_response - - request_data = {"prompt": self.prompt} - response = self.adapter.handle_request(request_data) - - self.assertEqual(response["status_code"], 200) - self.assertEqual( - response["processed_response"], - "This is the reasoning output from the model", - ) - self.assertIsNone(response["error_message"]) - - @patch("litellm.completion") - def test_handle_request_empty_completions_list_from_execute( - self, mock_litellm_completion - ): - # This simulates the _execute_completion returning an empty/None content, - # which should result in an error response. - # The ChatCompletionsAgent base class handle_request checks for None content. 
- - # Mock _execute_completion to return success but with None content - with patch.object( - self.adapter, - "_execute_completion", - return_value={"success": True, "content": None}, - ) as mock_execute: - request_data = {"prompt": self.prompt} - response = self.adapter.handle_request(request_data) - self.assertEqual(response["status_code"], 500) - self.assertIn("returned empty result", response["error_message"]) - mock_execute.assert_called_once() - - def test_handle_request_passes_additional_kwargs_to_litellm(self): - with patch("litellm.completion") as mock_litellm_completion: - mock_choice = MagicMock() - mock_choice.message = MagicMock() - mock_choice.message.content = " response with custom params." - mock_response = MagicMock() - mock_response.choices = [mock_choice] - mock_litellm_completion.return_value = mock_response - - request_data = { - "prompt": self.prompt, - "custom_param": "value123", - "another_param": 42, - } - self.adapter.handle_request(request_data) - - called_kwargs = mock_litellm_completion.call_args[1] - self.assertEqual(called_kwargs.get("custom_param"), "value123") - self.assertEqual(called_kwargs.get("another_param"), 42) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/adapters/test_ollama.py b/tests/unit/adapters/test_ollama.py deleted file mode 100644 index 86273980..00000000 --- a/tests/unit/adapters/test_ollama.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2026 - AI4I. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Unit tests for OllamaAgent. - -Issue #379 moved the Ollama adapter onto LiteLLM (via the -``ollama_chat`` provider), so these tests patch ``litellm.completion`` -rather than ``requests.post`` for the chat path. Utility methods such as -``list_models`` and ``model_info`` still talk to the Ollama HTTP API -directly and so still mock ``requests``. 
-""" - -import logging -import os -import unittest -from unittest.mock import MagicMock, patch - -import requests - -from hackagent.router.adapters.ollama import ( - OllamaAgent, - OllamaConfigurationError, -) - -logging.disable(logging.CRITICAL) - - -def _make_litellm_response(content: str = "ok") -> MagicMock: - response = MagicMock() - choice = MagicMock() - message = MagicMock() - message.content = content - message.tool_calls = None - message.reasoning_content = None - message.reasoning = None - message.provider_specific_fields = None - choice.message = message - choice.finish_reason = "stop" - response.choices = [choice] - response.usage = MagicMock(model_dump=MagicMock(return_value={"total_tokens": 5})) - response.model = "ollama_chat/llama3" - return response - - -class TestOllamaAgentInit(unittest.TestCase): - def test_init_success_minimal_config(self): - adapter = OllamaAgent(id="ol1", config={"name": "llama3"}) - self.assertEqual(adapter.id, "ol1") - self.assertEqual(adapter.model_name, "llama3") - self.assertEqual(adapter.api_base_url, "http://localhost:11434") - self.assertEqual(adapter.litellm_model, "ollama_chat/llama3") - self.assertEqual(adapter.default_max_tokens, 100) - - def test_init_with_custom_endpoint(self): - adapter = OllamaAgent( - id="ol2", - config={"name": "mistral", "endpoint": "http://host:11434"}, - ) - self.assertEqual(adapter.api_base_url, "http://host:11434") - - def test_init_normalizes_trailing_slash_and_api_suffix(self): - adapter = OllamaAgent( - id="ol3", - config={ - "name": "llama3", - "endpoint": "http://host:11434/api/chat/", - }, - ) - self.assertEqual(adapter.api_base_url, "http://host:11434") - - @patch.dict(os.environ, {"OLLAMA_BASE_URL": "http://env-ollama:11434"}) - def test_init_picks_up_env_var(self): - adapter = OllamaAgent(id="ol4", config={"name": "llama3"}) - self.assertEqual(adapter.api_base_url, "http://env-ollama:11434") - - def test_init_missing_name_raises(self): - with 
self.assertRaises(OllamaConfigurationError): - OllamaAgent(id="err", config={}) - - def test_init_preserves_existing_provider_prefix(self): - """If the user supplies ``ollama/`` it shouldn't be re-prefixed.""" - adapter = OllamaAgent(id="ol5", config={"name": "ollama/llama3"}) - self.assertEqual(adapter.litellm_model, "ollama/llama3") - - -class TestOllamaAgentHandleRequest(unittest.TestCase): - def setUp(self): - self.adapter = OllamaAgent( - id="oh1", - config={"name": "llama3", "max_tokens": 50, "temperature": 0.5}, - ) - - def test_missing_prompt_and_messages_returns_400(self): - response = self.adapter.handle_request({}) - self.assertEqual(response["status_code"], 400) - self.assertIn( - "Request data must include either 'messages' or 'prompt'", - response["error_message"], - ) - - @patch("litellm.completion") - def test_handle_request_with_prompt_success(self, mock_completion): - mock_completion.return_value = _make_litellm_response("Hello!") - - response = self.adapter.handle_request({"prompt": "Hi"}) - - self.assertEqual(response["status_code"], 200) - self.assertEqual(response["generated_text"], "Hello!") - kwargs = mock_completion.call_args.kwargs - self.assertEqual(kwargs["model"], "ollama_chat/llama3") - self.assertEqual(kwargs["api_base"], "http://localhost:11434") - self.assertEqual(kwargs["messages"], [{"role": "user", "content": "Hi"}]) - - @patch("litellm.completion") - def test_handle_request_with_messages_success(self, mock_completion): - mock_completion.return_value = _make_litellm_response("ack") - messages = [ - {"role": "system", "content": "be terse"}, - {"role": "user", "content": "go"}, - ] - response = self.adapter.handle_request({"messages": messages}) - self.assertEqual(response["status_code"], 200) - self.assertEqual(mock_completion.call_args.kwargs["messages"], messages) - - @patch("litellm.completion") - def test_extra_generation_options_pass_through(self, mock_completion): - mock_completion.return_value = 
_make_litellm_response("hi") - adapter = OllamaAgent( - id="oh2", - config={ - "name": "llama3", - "top_k": 40, - "num_ctx": 8192, - "stream": True, - }, - ) - adapter.handle_request({"prompt": "Hi"}) - kwargs = mock_completion.call_args.kwargs - self.assertEqual(kwargs.get("top_k"), 40) - self.assertEqual(kwargs.get("num_ctx"), 8192) - self.assertEqual(kwargs.get("stream"), True) - - @patch("litellm.completion") - def test_thinking_true_translates_to_think(self, mock_completion): - mock_completion.return_value = _make_litellm_response("yo") - adapter = OllamaAgent(id="oh3", config={"name": "llama3", "thinking": True}) - adapter.handle_request({"prompt": "Hi"}) - kwargs = mock_completion.call_args.kwargs - self.assertIs(kwargs.get("think"), True) - self.assertNotIn("thinking", kwargs) - - @patch("litellm.completion") - def test_thinking_false_translates_to_think_false(self, mock_completion): - mock_completion.return_value = _make_litellm_response("yo") - adapter = OllamaAgent(id="oh4", config={"name": "llama3", "thinking": False}) - adapter.handle_request({"prompt": "Hi"}) - self.assertIs(mock_completion.call_args.kwargs.get("think"), False) - - @patch("litellm.completion") - def test_thinking_request_overrides_config_default(self, mock_completion): - mock_completion.return_value = _make_litellm_response("yo") - adapter = OllamaAgent(id="oh5", config={"name": "llama3", "thinking": False}) - adapter.handle_request({"prompt": "Hi", "thinking": True}) - self.assertIs(mock_completion.call_args.kwargs.get("think"), True) - - @patch("litellm.completion") - def test_handle_request_api_error(self, mock_completion): - mock_completion.side_effect = RuntimeError("connection refused") - response = self.adapter.handle_request({"prompt": "Hi"}) - self.assertEqual(response["status_code"], 500) - self.assertIn("connection refused", response["error_message"]) - - -class TestOllamaAgentUtilities(unittest.TestCase): - def setUp(self): - self.adapter = OllamaAgent(id="util", 
config={"name": "llama3"}) - - @patch("requests.get") - def test_list_models_success(self, mock_get): - mock_resp = MagicMock() - mock_resp.json.return_value = { - "models": [{"name": "llama3"}, {"name": "mistral:latest"}] - } - mock_resp.raise_for_status = MagicMock() - mock_get.return_value = mock_resp - models = self.adapter.list_models() - self.assertEqual(len(models), 2) - - @patch("requests.get") - def test_list_models_error_returns_empty_list(self, mock_get): - mock_get.side_effect = requests.exceptions.ConnectionError("nope") - self.assertEqual(self.adapter.list_models(), []) - - @patch("requests.post") - def test_model_info_success(self, mock_post): - mock_resp = MagicMock() - mock_resp.json.return_value = {"license": "mit"} - mock_resp.raise_for_status = MagicMock() - mock_post.return_value = mock_resp - self.assertEqual(self.adapter.model_info(), {"license": "mit"}) - - @patch("requests.post") - def test_model_info_error_returns_empty_dict(self, mock_post): - mock_post.side_effect = requests.exceptions.ConnectionError("nope") - self.assertEqual(self.adapter.model_info(), {}) - - @patch.object(OllamaAgent, "list_models") - def test_is_available_true_when_model_present(self, mock_list): - mock_list.return_value = [{"name": "llama3:latest"}] - self.assertTrue(self.adapter.is_available()) - - @patch.object(OllamaAgent, "list_models") - def test_is_available_false_when_model_missing(self, mock_list): - mock_list.return_value = [{"name": "mistral"}] - self.assertFalse(self.adapter.is_available()) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/adapters/test_openai.py b/tests/unit/adapters/test_openai.py deleted file mode 100644 index a6190949..00000000 --- a/tests/unit/adapters/test_openai.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2026 - AI4I. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Unit tests for the OpenAI agent adapter. 
- -Issue #379 moved every chat-completion adapter onto LiteLLM, so these -tests exercise the OpenAI adapter by patching ``litellm.completion`` -rather than the OpenAI SDK directly. -""" - -import logging -import os -import unittest -from unittest.mock import MagicMock, patch - -from hackagent.router.adapters.openai import ( - OpenAIAgent, - OpenAIConfigurationError, -) - -logging.disable(logging.CRITICAL) - - -def _make_litellm_response(content: str = "ok", *, tool_calls=None) -> MagicMock: - """Build a minimal mock of a litellm ModelResponse.""" - response = MagicMock() - choice = MagicMock() - message = MagicMock() - message.content = content - message.tool_calls = tool_calls - message.reasoning_content = None - message.reasoning = None - message.provider_specific_fields = None - choice.message = message - choice.finish_reason = "stop" - response.choices = [choice] - response.usage = MagicMock(model_dump=MagicMock(return_value={"total_tokens": 10})) - response.model = "gpt-4" - return response - - -class TestOpenAIAgentInit(unittest.TestCase): - def test_init_success_with_required_config(self): - adapter = OpenAIAgent(id="o1", config={"name": "gpt-4"}) - self.assertEqual(adapter.id, "o1") - self.assertEqual(adapter.model_name, "gpt-4") - # OpenAIAgent forces the openai/ provider prefix when none is set. - self.assertEqual(adapter.litellm_model, "openai/gpt-4") - self.assertIsNone(adapter.api_base_url) - self.assertEqual(adapter.default_temperature, 1.0) - - def test_init_with_custom_endpoint(self): - adapter = OpenAIAgent( - id="o2", - config={ - "name": "gpt-4", - "endpoint": "https://custom.proxy/v1", - }, - ) - self.assertEqual(adapter.api_base_url, "https://custom.proxy/v1") - # When there's no API key, a placeholder is used so the underlying - # OpenAI client doesn't choke. 
- self.assertEqual(adapter.actual_api_key, "not-required") - - def test_init_with_custom_endpoint_defaults_model_name(self): - adapter = OpenAIAgent(id="o3", config={"endpoint": "https://example.com/v1"}) - self.assertEqual(adapter.model_name, "default") - - @patch.dict(os.environ, {"CUSTOM_API_KEY": "sk-test"}) - def test_init_with_api_key_from_env(self): - adapter = OpenAIAgent( - id="o4", - config={"name": "gpt-4", "api_key": "CUSTOM_API_KEY"}, - ) - self.assertEqual(adapter.actual_api_key, "sk-test") - - def test_init_with_generation_parameters(self): - adapter = OpenAIAgent( - id="o5", - config={ - "name": "gpt-4", - "max_tokens": 500, - "temperature": 0.7, - "tools": [{"type": "function", "function": {"name": "f"}}], - "tool_choice": "auto", - }, - ) - self.assertEqual(adapter.default_max_tokens, 500) - self.assertEqual(adapter.default_temperature, 0.7) - self.assertIsNotNone(adapter.default_tools) - self.assertEqual(adapter.default_tool_choice, "auto") - - def test_init_missing_name_no_endpoint_raises(self): - with self.assertRaises(OpenAIConfigurationError): - OpenAIAgent(id="err", config={}) - - def test_init_preserves_existing_provider_prefix(self): - """A user-supplied ``openai/`` shouldn't get double-prefixed.""" - adapter = OpenAIAgent(id="o6", config={"name": "openai/gpt-4"}) - self.assertEqual(adapter.litellm_model, "openai/gpt-4") - - -class TestOpenAIAgentHandleRequest(unittest.TestCase): - def setUp(self): - self.adapter = OpenAIAgent( - id="oh1", - config={"name": "gpt-4", "max_tokens": 100, "temperature": 0.8}, - ) - - def test_missing_prompt_and_messages_returns_400(self): - response = self.adapter.handle_request({"temperature": 0.5}) - self.assertEqual(response["status_code"], 400) - self.assertIn( - "Request data must include either 'messages' or 'prompt'", - response["error_message"], - ) - - @patch("litellm.completion") - def test_handle_request_with_prompt_success(self, mock_completion): - mock_completion.return_value = 
_make_litellm_response("Hello back") - response = self.adapter.handle_request({"prompt": "Hi"}) - - self.assertEqual(response["status_code"], 200) - self.assertIsNone(response["error_message"]) - self.assertEqual(response["generated_text"], "Hello back") - self.assertEqual(response["adapter_type"], "OpenAIAgent") - kwargs = mock_completion.call_args.kwargs - self.assertEqual(kwargs["model"], "openai/gpt-4") - self.assertEqual(kwargs["messages"], [{"role": "user", "content": "Hi"}]) - - @patch("litellm.completion") - def test_handle_request_with_messages_success(self, mock_completion): - mock_completion.return_value = _make_litellm_response("Hi!") - messages = [ - {"role": "system", "content": "be helpful"}, - {"role": "user", "content": "ping"}, - ] - response = self.adapter.handle_request({"messages": messages}) - - self.assertEqual(response["status_code"], 200) - self.assertEqual(response["generated_text"], "Hi!") - kwargs = mock_completion.call_args.kwargs - self.assertEqual(kwargs["messages"], messages) - - @patch("litellm.completion") - def test_handle_request_with_tool_calls(self, mock_completion): - tool = MagicMock() - tool.id = "call_1" - tool.type = "function" - tool.function.name = "get_weather" - tool.function.arguments = '{"loc": "SF"}' - mock_completion.return_value = _make_litellm_response( - "I'll call a tool", tool_calls=[tool] - ) - - response = self.adapter.handle_request( - { - "prompt": "weather?", - "tools": [{"type": "function", "function": {"name": "x"}}], - "tool_choice": "auto", - } - ) - - self.assertEqual(response["status_code"], 200) - tcs = response["agent_specific_data"]["tool_calls"] - self.assertEqual(len(tcs), 1) - self.assertEqual(tcs[0]["function"]["name"], "get_weather") - kwargs = mock_completion.call_args.kwargs - self.assertIn("tools", kwargs) - self.assertEqual(kwargs["tool_choice"], "auto") - - @patch("litellm.completion") - def test_parameter_overrides_apply(self, mock_completion): - mock_completion.return_value = 
_make_litellm_response("ok") - self.adapter.handle_request( - {"prompt": "go", "max_tokens": 200, "temperature": 0.5} - ) - kwargs = mock_completion.call_args.kwargs - self.assertEqual(kwargs["max_tokens"], 200) - self.assertEqual(kwargs["temperature"], 0.5) - - @patch("litellm.completion") - def test_handle_request_api_error(self, mock_completion): - mock_completion.side_effect = RuntimeError("boom") - response = self.adapter.handle_request({"prompt": "Hi"}) - self.assertEqual(response["status_code"], 500) - self.assertIn("boom", response["error_message"]) - - -class TestOpenAIAgentThinking(unittest.TestCase): - """Issue #379 — verify the unified thinking knob translates correctly.""" - - @patch("litellm.completion") - def test_thinking_true_on_reasoning_model_sets_reasoning_effort( - self, mock_completion - ): - mock_completion.return_value = _make_litellm_response("hi") - adapter = OpenAIAgent(id="r1", config={"name": "o1-mini", "thinking": True}) - adapter.handle_request({"prompt": "hello"}) - kwargs = mock_completion.call_args.kwargs - self.assertEqual(kwargs.get("reasoning_effort"), "medium") - self.assertNotIn("thinking", kwargs) - - @patch("litellm.completion") - def test_thinking_false_on_reasoning_model_omits_effort(self, mock_completion): - mock_completion.return_value = _make_litellm_response("hi") - adapter = OpenAIAgent(id="r2", config={"name": "o3", "thinking": False}) - adapter.handle_request({"prompt": "hello"}) - kwargs = mock_completion.call_args.kwargs - self.assertNotIn("reasoning_effort", kwargs) - self.assertNotIn("thinking", kwargs) - - @patch("litellm.completion") - def test_thinking_string_passes_through_as_effort(self, mock_completion): - mock_completion.return_value = _make_litellm_response("hi") - adapter = OpenAIAgent(id="r3", config={"name": "o1"}) - adapter.handle_request({"prompt": "hello", "thinking": "high"}) - self.assertEqual( - mock_completion.call_args.kwargs.get("reasoning_effort"), "high" - ) - - @patch("litellm.completion") 
- def test_thinking_on_non_reasoning_model_passes_through_generically( - self, mock_completion - ): - mock_completion.return_value = _make_litellm_response("hi") - adapter = OpenAIAgent(id="r4", config={"name": "gpt-4"}) - adapter.handle_request({"prompt": "hello", "thinking": True}) - kwargs = mock_completion.call_args.kwargs - # Non-reasoning OpenAI models get the generic LiteLLM thinking dict. - self.assertEqual(kwargs.get("thinking"), {"type": "enabled"}) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/router/test_router.py b/tests/unit/router/test_router.py index 29badbb9..1f780e19 100644 --- a/tests/unit/router/test_router.py +++ b/tests/unit/router/test_router.py @@ -42,19 +42,15 @@ def _make_backend(org_id=None, user_id="test_user"): class TestAgentRouterInitialization(unittest.TestCase): - @patch("hackagent.router.router.LiteLLMAgent", autospec=True) @patch("hackagent.router.router.ADKAgent", autospec=True) @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) def test_agent_router_init_creates_new_agent_if_not_exists( self, MockAgentMap, MockADKAdapter, - MockLiteLLMAdapter, ): MockAgentMap[AgentTypeEnum.GOOGLE_ADK] = MockADKAdapter - MockAgentMap[AgentTypeEnum.LITELLM] = MockLiteLLMAdapter MockADKAdapter.__name__ = "ADKAgent" - MockLiteLLMAdapter.__name__ = "LiteLLMAgent" mock_org_id = uuid.uuid4() mock_backend = _make_backend(org_id=mock_org_id, user_id="123") @@ -103,7 +99,6 @@ def test_agent_router_init_creates_new_agent_if_not_exists( "endpoint": agent_endpoint, }, ) - MockLiteLLMAdapter.assert_not_called() self.assertEqual(router.backend, mock_backend) self.assertIsNotNone(router.backend_agent) self.assertEqual(router.backend_agent.id, mock_created_agent_id) @@ -112,19 +107,15 @@ def test_agent_router_init_creates_new_agent_if_not_exists( router._agent_registry[str(mock_created_agent_id)], mock_adk_instance ) - @patch("hackagent.router.router.LiteLLMAgent", autospec=True) 
@patch("hackagent.router.router.ADKAgent", autospec=True) @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) def test_agent_router_init_updates_existing_agent_if_metadata_differs( self, MockAgentMap, MockADKAdapter, - MockLiteLLMAdapter, ): MockAgentMap[AgentTypeEnum.GOOGLE_ADK] = MockADKAdapter - MockAgentMap[AgentTypeEnum.LITELLM] = MockLiteLLMAdapter MockADKAdapter.__name__ = "ADKAgent" - MockLiteLLMAdapter.__name__ = "LiteLLMAgent" mock_org_id = uuid.uuid4() mock_backend = _make_backend(org_id=mock_org_id, user_id="456") @@ -429,15 +420,7 @@ def test_ollama_chat_registration_has_str_endpoint(self): class TestMetadataNoneStripping(unittest.TestCase): - @patch("hackagent.router.router.OllamaAgent", autospec=True) - @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) - def test_none_values_stripped_from_metadata_on_create( - self, - MockAgentMap, - MockOllamaAdapter, - ): - MockAgentMap[AgentTypeEnum.OLLAMA] = MockOllamaAdapter - MockOllamaAdapter.__name__ = "OllamaAgent" + def test_none_values_stripped_from_metadata_on_create(self): mock_backend = _make_backend() mock_backend.create_or_update_agent.return_value = _make_agent_rec( agent_type_str="OLLAMA", @@ -471,15 +454,7 @@ def test_none_values_stripped_from_metadata_on_create( self.assertNotIn("api_key", sent_metadata) self.assertNotIn("max_tokens", sent_metadata) - @patch("hackagent.router.router.OllamaAgent", autospec=True) - @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) - def test_none_values_stripped_from_metadata_on_update( - self, - MockAgentMap, - MockOllamaAdapter, - ): - MockAgentMap[AgentTypeEnum.OLLAMA] = MockOllamaAdapter - MockOllamaAdapter.__name__ = "OllamaAgent" + def test_none_values_stripped_from_metadata_on_update(self): mock_backend = _make_backend() mock_backend.create_or_update_agent.return_value = _make_agent_rec( agent_type_str="OLLAMA", @@ -510,15 +485,7 @@ def 
test_none_values_stripped_from_metadata_on_update( class TestAgentPagination(unittest.TestCase): - @patch("hackagent.router.router.LiteLLMAgent", autospec=True) - @patch("hackagent.router.router.AGENT_TYPE_TO_ADAPTER_MAP", new_callable=dict) - def test_agent_found_on_page_two_is_not_recreated( - self, - MockAgentMap, - MockLiteLLMAdapter, - ): - MockAgentMap[AgentTypeEnum.LITELLM] = MockLiteLLMAdapter - MockLiteLLMAdapter.__name__ = "LiteLLMAgent" + def test_agent_found_on_page_two_is_not_recreated(self): mock_backend = _make_backend() target_agent_id = uuid.uuid4() agent_name = "llama2-uncensored" From fe5c624bcdf08bd3dc6a71b0618475951dff3337 Mon Sep 17 00:00:00 2001 From: Nicola Franco Date: Sat, 23 May 2026 17:47:24 +0200 Subject: [PATCH 11/23] refactor(router): namespace metadata under metadata['hackagent'] (#379 Phase F.2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase F.2 cleans up the LiteLLM ``metadata`` correlation keys. Instead of flat ``"hackagent_agent_id"`` / ``"hackagent_adapter_type"`` keys that sit alongside whatever the caller (Langfuse, OTEL, user code…) also stuffs into ``metadata``, the router now writes a nested block: metadata = { "hackagent": {"id": "", "adapter_type": "