From d6eb8066c8de36b1330281f76dedbfcac4a8d7e1 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Mon, 30 Mar 2026 16:32:57 -0400 Subject: [PATCH 1/2] feat(models): add OCI Generative AI provider for Google Gemini on OCI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds first-class support for Google Gemini models hosted on Oracle Cloud Infrastructure (OCI) Generative AI service — a native Google × OCI model partnership that makes Gemini available directly through OCI's inference endpoints. Key design points: - Subclasses BaseLlm following the anthropic_llm.py pattern - Uses the OCI Python SDK directly (no LangChain dependency) - Optional dependency: pip install google-adk[oci] - Supports API_KEY, INSTANCE_PRINCIPAL, and RESOURCE_PRINCIPAL auth - Both non-streaming (_call_oci) and streaming (_call_oci_stream) paths share setup code via _build_chat_details(); streaming collects OCI's OpenAI-compatible SSE events in a thread pool (asyncio.to_thread) and yields partial then final LlmResponse - Registers google.gemini-* (and other OCI-hosted) model patterns in LLMRegistry via optional try/except in models/__init__.py - 37 unit tests (fully mocked, no OCI account needed) - 10 integration tests (skipped when OCI_COMPARTMENT_ID is unset) Supported models: google.gemini-*, google.gemma-*, meta.llama-*, mistralai.*, xai.grok-*, nvidia.* --- pyproject.toml | 2 + src/google/adk/models/__init__.py | 14 + src/google/adk/models/oci_genai_llm.py | 633 +++++++++ .../integration/models/test_oci_genai_llm.py | 623 ++++++++ tests/unittests/models/test_oci_genai_llm.py | 1259 +++++++++++++++++ 5 files changed, 2531 insertions(+) create mode 100644 src/google/adk/models/oci_genai_llm.py create mode 100644 tests/integration/models/test_oci_genai_llm.py create mode 100644 tests/unittests/models/test_oci_genai_llm.py diff --git a/pyproject.toml b/pyproject.toml index d7bacd8a10..75b8c6fb8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-191,6 +191,8 @@ strict = true disable_error_code = [ "import-not-found", "import-untyped", "unused-ignore" ] follow_imports = "skip" +oci = ["oci>=2.126.0"] # For OCI Generative AI model support + [tool.pyink] # Format py files following Google style-guide line-length = 80 diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index 1f42993c88..a394b6feee 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -31,6 +31,7 @@ from .gemma_llm import Gemma3Ollama from .google_llm import Gemini from .lite_llm import LiteLlm + from .oci_genai_llm import OCIGenAILlm __all__ = [ 'ApigeeLlm', @@ -41,6 +42,7 @@ 'Gemma3Ollama', 'LLMRegistry', 'LiteLlm', + 'OCIGenAILlm', ] _LAZY_PROVIDERS: dict[str, tuple[list[str], str]] = { @@ -78,6 +80,18 @@ ], 'lite_llm', ), + 'OCIGenAILlm': ( + [ + r'meta\.llama-.*', + r'google\.gemini-.*', + r'google\.gemma-.*', + r'xai\.grok-.*', + r'mistralai\.mistral-.*', + r'mistralai\.mixtral-.*', + r'nvidia\..*', + ], + 'oci_genai_llm', + ), } for _name, (_patterns, _module) in _LAZY_PROVIDERS.items(): diff --git a/src/google/adk/models/oci_genai_llm.py b/src/google/adk/models/oci_genai_llm.py new file mode 100644 index 0000000000..779b7ed5f1 --- /dev/null +++ b/src/google/adk/models/oci_genai_llm.py @@ -0,0 +1,633 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""OCI Generative AI integration for ADK models.""" + +from __future__ import annotations + +import asyncio +import base64 +import importlib.util +import json +import logging +import os +import threading +from typing import Any +from typing import AsyncGenerator +from typing import Optional +from typing import TYPE_CHECKING + +from google.genai import types +from typing_extensions import override + +if not TYPE_CHECKING and importlib.util.find_spec("oci") is None: + raise ImportError( + "OCI Generative AI support requires: pip install google-adk[oci]" + "\nOr: pip install oci" + ) + +from .base_llm import BaseLlm +from .llm_response import LlmResponse + +if TYPE_CHECKING: + from .llm_request import LlmRequest + +__all__ = ["OCIGenAILlm"] + +logger = logging.getLogger("google_adk." + __name__) + + +def _to_oci_role(role: Optional[str]) -> str: + """Map ADK content role to OCI GenAI role string.""" + if role in ("model", "assistant"): + return "ASSISTANT" + return "USER" + + +def _build_response_format( + cfg: types.GenerateContentConfig, oci_models: Any +) -> Optional[Any]: + """Map google.genai response config to OCI ResponseFormat. + + - ``response_schema`` (Pydantic class, dict, or genai Schema) → + ``JsonSchemaResponseFormat`` (strict structured output). + - ``response_mime_type == "application/json"`` only → + ``JsonObjectResponseFormat``. + - ``response_mime_type == "text/plain"`` → ``TextResponseFormat`` (default + behaviour; only emitted when explicitly requested). 
+ """ + schema = cfg.response_schema + mime = cfg.response_mime_type or "" + + if schema is not None: + schema_dict: dict[str, Any] + if hasattr(schema, "model_json_schema"): + # Pydantic v2 model class + schema_dict = schema.model_json_schema() + elif hasattr(schema, "to_json_dict"): + # google.genai Schema instance + schema_dict = schema.to_json_dict() + elif isinstance(schema, dict): + schema_dict = schema + else: + return None + return oci_models.JsonSchemaResponseFormat( + type="JSON_SCHEMA", + json_schema=oci_models.ResponseJsonSchema( + name=schema_dict.get("title", "response"), + description=schema_dict.get("description"), + schema=schema_dict, + is_strict=True, + ), + ) + + if mime == "application/json": + return oci_models.JsonObjectResponseFormat(type="JSON_OBJECT") + if mime == "text/plain": + return oci_models.TextResponseFormat(type="TEXT") + return None + + +def _media_blocks_for_part(part: types.Part) -> list[Any]: + """Map a multimodal Part (inline_data / file_data) to OCI ChatContent blocks. + + OCI Generative AI Inference (/20231130/) accepts ImageContent / AudioContent / + VideoContent / DocumentContent, each carrying a URL in ``{kind}_url.url``. + Inline bytes are wrapped as ``data:;base64,<...>``; file_data passes + the ``file_uri`` through. + + Returns an empty list for parts that have no media payload. 
+ """ + import oci.generative_ai_inference.models as oci_models + + url: Optional[str] = None + mime: Optional[str] = None + + if part.inline_data and part.inline_data.data is not None: + mime = part.inline_data.mime_type or "application/octet-stream" + raw = part.inline_data.data + if isinstance(raw, (bytes, bytearray)): + encoded = base64.b64encode(bytes(raw)).decode("ascii") + else: + encoded = str(raw) + url = f"data:{mime};base64,{encoded}" + elif part.file_data and part.file_data.file_uri: + url = part.file_data.file_uri + mime = part.file_data.mime_type + + if not url: + return [] + + category = (mime or "").split("/", 1)[0].lower() + if category == "image": + return [oci_models.ImageContent( + type="IMAGE", image_url=oci_models.ImageUrl(url=url) + )] + if category == "audio": + return [oci_models.AudioContent( + type="AUDIO", audio_url=oci_models.AudioUrl(url=url) + )] + if category == "video": + return [oci_models.VideoContent( + type="VIDEO", video_url=oci_models.VideoUrl(url=url) + )] + # Documents (application/pdf, text/*, etc.) and any other mime + return [oci_models.DocumentContent( + type="DOCUMENT", document_url=oci_models.DocumentUrl(url=url) + )] + + +def _content_to_oci_message(content: types.Content) -> Any: + """Convert an ADK Content object to an OCI GenAI message. 
+ + OCI GenAI uses: + - ``UserMessage`` for user turns + - ``AssistantMessage`` for model turns (may include ``FunctionCall`` items + in ``tool_calls``) + - ``ToolMessage`` for tool results (function_response parts) + """ + import oci.generative_ai_inference.models as oci_models + + text_parts: list[str] = [] + media_blocks: list[Any] = [] + tool_calls: list[Any] = [] + tool_results: list[tuple[str, str]] = [] # (tool_call_id, result_text) + + for part in content.parts or []: + if part.text: + text_parts.append(part.text) + elif part.function_call: + # FunctionCall is the OCI subtype of ToolCall that carries name+arguments + tool_calls.append( + oci_models.FunctionCall( + id=part.function_call.id or "", + type=oci_models.FunctionCall.TYPE_FUNCTION, + name=part.function_call.name, + arguments=json.dumps(part.function_call.args or {}), + ) + ) + elif part.function_response: + result = part.function_response.response or {} + tool_results.append(( + part.function_response.id or "", + json.dumps(result) if isinstance(result, dict) else str(result), + )) + elif part.inline_data or part.file_data: + media_blocks.extend(_media_blocks_for_part(part)) + + role = _to_oci_role(content.role) + + # Tool results map to ToolMessage (one per result) + if tool_results: + call_id, result_text = tool_results[0] + return oci_models.ToolMessage( + role=oci_models.ToolMessage.ROLE_TOOL, + tool_call_id=call_id, + content=[oci_models.TextContent(type="TEXT", text=result_text)], + ) + + if role == "ASSISTANT": + oci_content: list[Any] = [] + if text_parts: + oci_content.append( + oci_models.TextContent(type="TEXT", text="\n".join(text_parts)) + ) + return oci_models.AssistantMessage( + role=oci_models.AssistantMessage.ROLE_ASSISTANT, + content=oci_content, + tool_calls=tool_calls or None, + ) + + user_content: list[Any] = [] + if text_parts: + user_content.append( + oci_models.TextContent(type="TEXT", text="\n".join(text_parts)) + ) + user_content.extend(media_blocks) + return 
oci_models.UserMessage( + role=oci_models.UserMessage.ROLE_USER, + content=user_content, + ) + + +def _oci_response_to_llm_response(response: Any) -> LlmResponse: + """Convert an OCI GenAI chat response to an LlmResponse.""" + chat_response = response.data.chat_response + parts: list[types.Part] = [] + input_tokens = 0 + output_tokens = 0 + reasoning_tokens = 0 + + if hasattr(chat_response, "usage"): + usage = chat_response.usage + input_tokens = getattr(usage, "prompt_tokens", 0) or 0 + output_tokens = getattr(usage, "completion_tokens", 0) or 0 + details = getattr(usage, "completion_tokens_details", None) + if details is not None: + reasoning_tokens = getattr(details, "reasoning_tokens", 0) or 0 + + if hasattr(chat_response, "choices") and chat_response.choices: + choice = chat_response.choices[0] + message = getattr(choice, "message", None) + if message: + # Text content + for block in getattr(message, "content", None) or []: + if hasattr(block, "text") and block.text: + parts.append(types.Part.from_text(text=block.text)) + + # Tool calls — OCI returns FunctionCall objects directly in tool_calls + for fc in getattr(message, "tool_calls", None) or []: + args: dict[str, Any] = {} + try: + args = json.loads(fc.arguments) if fc.arguments else {} + except (json.JSONDecodeError, TypeError): + args = {} + part = types.Part.from_function_call( + name=fc.name, + args=args, + ) + part.function_call.id = getattr(fc, "id", "") or "" + parts.append(part) + + return LlmResponse( + content=types.Content(role="model", parts=parts), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=input_tokens, + candidates_token_count=output_tokens, + total_token_count=input_tokens + output_tokens, + thoughts_token_count=reasoning_tokens or None, + ), + ) + + +def _function_declaration_to_oci_tool( + fn: types.FunctionDeclaration, +) -> Any: + """Convert an ADK FunctionDeclaration to an OCI GenAI Tool.""" + import oci.generative_ai_inference.models as oci_models 
+ + parameters: dict[str, Any] = {"type": "object", "properties": {}} + if fn.parameters_json_schema: + parameters = fn.parameters_json_schema + elif fn.parameters and fn.parameters.properties: + props = {} + for k, v in fn.parameters.properties.items(): + props[k] = v.model_dump(by_alias=True, exclude_none=True) + parameters = { + "type": "object", + "properties": props, + } + if fn.parameters.required: + parameters["required"] = fn.parameters.required + + return oci_models.FunctionDefinition( + type=oci_models.FunctionDefinition.TYPE_FUNCTION, + name=fn.name, + description=fn.description or "", + parameters=parameters, + ) + + +class OCIGenAILlm(BaseLlm): + """Integration with OCI Generative AI models. + + Supports models hosted on Oracle Cloud Infrastructure Generative AI service, + including Meta Llama, Google Gemini, Google Gemma, and other GenericChat + compatible models. + + Example usage:: + + from google.adk.models.oci_genai_llm import OCIGenAILlm + from google.adk.agents import LlmAgent + + agent = LlmAgent( + model=OCIGenAILlm( + model="google.gemini-2.0-flash-001", + compartment_id="ocid1.compartment.oc1...", + ), + ... + ) + + Attributes: + model: OCI model ID (e.g. ``google.gemini-2.0-flash-001``). Used as the + ``model_id`` for on-demand serving. For dedicated serving, set + ``endpoint_id`` instead; ``model`` is then informational only. + endpoint_id: Dedicated endpoint OCID (``ocid1.generativeaiendpoint...``). + When set, requests use ``DedicatedServingMode``; otherwise on-demand + mode is used. Falls back to ``OCI_ENDPOINT_ID`` env var when not set. + compartment_id: OCI compartment OCID. Falls back to the + ``OCI_COMPARTMENT_ID`` environment variable when not set. + service_endpoint: OCI Generative AI service endpoint URL. Defaults to + the us-chicago-1 endpoint or ``OCI_SERVICE_ENDPOINT`` env var. + auth_type: OCI authentication type. One of ``API_KEY`` (default), + ``INSTANCE_PRINCIPAL``, or ``RESOURCE_PRINCIPAL``. 
+ auth_profile: Config profile to use for ``API_KEY`` auth (default: + ``DEFAULT``). + auth_file_location: Path to the OCI config file used for ``API_KEY`` + auth (default: ``~/.oci/config``). + max_tokens: Maximum number of tokens to generate (default: 2048). + """ + + model: str = "google.gemini-2.5-flash" + endpoint_id: Optional[str] = None + compartment_id: Optional[str] = None + service_endpoint: Optional[str] = None + auth_type: str = "API_KEY" + auth_profile: str = "DEFAULT" + auth_file_location: str = "~/.oci/config" + max_tokens: int = 2048 + + @classmethod + @override + def supported_models(cls) -> list[str]: + return [ + r"meta\.llama-.*", + r"google\.gemini-.*", + r"google\.gemma-.*", + r"xai\.grok-.*", + r"mistralai\.mistral-.*", + r"mistralai\.mixtral-.*", + r"nvidia\..*", + ] + + @override + async def generate_content_async( + self, + llm_request: LlmRequest, + stream: bool = False, + ) -> AsyncGenerator[LlmResponse, None]: + if stream: + async for response in self._generate_content_streaming(llm_request): + yield response + else: + response = await asyncio.to_thread(self._call_oci, llm_request) + yield _oci_response_to_llm_response(response) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _resolve_compartment_id(self) -> str: + compartment_id = self.compartment_id or os.environ.get("OCI_COMPARTMENT_ID") + if not compartment_id: + raise ValueError( + "compartment_id must be set on OCIGenAILlm or via the" + " OCI_COMPARTMENT_ID environment variable." 
+ ) + return compartment_id + + def _resolve_service_endpoint(self) -> str: + return ( + self.service_endpoint + or os.environ.get("OCI_SERVICE_ENDPOINT") + or "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com" + ) + + def _build_client(self, service_endpoint: str) -> Any: + """Create an OCI GenerativeAiInferenceClient from auth config.""" + import oci + import oci.auth.signers + import oci.generative_ai_inference + + if self.auth_type == "INSTANCE_PRINCIPAL": + signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + return oci.generative_ai_inference.GenerativeAiInferenceClient( + config={}, + signer=signer, + service_endpoint=service_endpoint, + ) + elif self.auth_type == "RESOURCE_PRINCIPAL": + signer = oci.auth.signers.get_resource_principals_signer() + return oci.generative_ai_inference.GenerativeAiInferenceClient( + config={}, + signer=signer, + service_endpoint=service_endpoint, + ) + else: # API_KEY (default) + config = oci.config.from_file( + file_location=self.auth_file_location, + profile_name=self.auth_profile, + ) + return oci.generative_ai_inference.GenerativeAiInferenceClient( + config=config, + service_endpoint=service_endpoint, + ) + + def _build_chat_details( + self, llm_request: LlmRequest, is_stream: bool = False + ) -> Any: + """Build OCI ChatDetails from an LlmRequest.""" + import oci.generative_ai_inference.models as oci_models + + messages = [ + _content_to_oci_message(c) for c in llm_request.contents or [] + ] + + # Prepend SystemMessage when a system instruction is present + if llm_request.config and llm_request.config.system_instruction: + si = llm_request.config.system_instruction + if isinstance(si, str) and si: + messages = [ + oci_models.SystemMessage( + role=oci_models.SystemMessage.ROLE_SYSTEM, + content=[oci_models.TextContent(type="TEXT", text=si)], + ) + ] + messages + + # Convert tool declarations if present + oci_tools: Optional[list[Any]] = None + if ( + llm_request.config + and 
llm_request.config.tools + and llm_request.config.tools[0].function_declarations + ): + oci_tools = [ + _function_declaration_to_oci_tool(fn) + for fn in llm_request.config.tools[0].function_declarations + ] + + chat_request_kwargs: dict[str, Any] = dict( + api_format=oci_models.BaseChatRequest.API_FORMAT_GENERIC, + messages=messages, + max_tokens=self.max_tokens, + ) + + # Sampling and decoding parameters from llm_request.config. + cfg = getattr(llm_request, "config", None) + if cfg is not None: + if cfg.max_output_tokens is not None: + chat_request_kwargs["max_tokens"] = cfg.max_output_tokens + if cfg.temperature is not None: + chat_request_kwargs["temperature"] = cfg.temperature + if cfg.top_p is not None: + chat_request_kwargs["top_p"] = cfg.top_p + if cfg.top_k is not None: + chat_request_kwargs["top_k"] = int(cfg.top_k) + if cfg.frequency_penalty is not None: + chat_request_kwargs["frequency_penalty"] = cfg.frequency_penalty + if cfg.presence_penalty is not None: + chat_request_kwargs["presence_penalty"] = cfg.presence_penalty + if cfg.seed is not None: + chat_request_kwargs["seed"] = cfg.seed + if cfg.stop_sequences: + chat_request_kwargs["stop"] = list(cfg.stop_sequences) + + # Structured-output: response_schema (Pydantic / dict / google.genai + # Schema) → JsonSchemaResponseFormat. response_mime_type alone (without + # a schema) → JsonObjectResponseFormat (free-form JSON). 
+ response_format = _build_response_format(cfg, oci_models) + if response_format is not None: + chat_request_kwargs["response_format"] = response_format + + if oci_tools: + chat_request_kwargs["tools"] = oci_tools + if is_stream: + chat_request_kwargs["is_stream"] = True + chat_request_kwargs["stream_options"] = oci_models.StreamOptions( + is_include_usage=True + ) + + return oci_models.ChatDetails( + compartment_id=self._resolve_compartment_id(), + serving_mode=self._build_serving_mode(oci_models), + chat_request=oci_models.GenericChatRequest(**chat_request_kwargs), + ) + + def _build_serving_mode(self, oci_models: Any) -> Any: + endpoint_id = self.endpoint_id or os.environ.get("OCI_ENDPOINT_ID") + if endpoint_id: + return oci_models.DedicatedServingMode(endpoint_id=endpoint_id) + return oci_models.OnDemandServingMode(model_id=self.model) + + def _call_oci(self, llm_request: LlmRequest) -> Any: + """Synchronous non-streaming OCI GenAI call, run in a thread pool.""" + client = self._build_client(self._resolve_service_endpoint()) + chat_details = self._build_chat_details(llm_request, is_stream=False) + logger.debug("Sending request to OCI GenAI: model=%s", self.model) + return client.chat(chat_details) + + def _call_oci_stream(self, llm_request: LlmRequest) -> list[dict[str, Any]]: + """Synchronous streaming call — collects all SSE event dicts in a thread. + + The OCI SDK wraps an SSE response in ``oci._vendor.sseclient.SSEClient`` + when ``is_stream=True`` is set on the request body. Each event's + ``data`` field is an OpenAI-compatible JSON chunk or the sentinel + ``[DONE]``. 
+ """ + client = self._build_client(self._resolve_service_endpoint()) + chat_details = self._build_chat_details(llm_request, is_stream=True) + logger.debug( + "Sending streaming request to OCI GenAI: model=%s", self.model + ) + response = client.chat(chat_details) + + chunks: list[dict[str, Any]] = [] + try: + for event in response.data.events(): + raw = getattr(event, "data", None) + if not raw or raw.strip() == "[DONE]": + break + try: + chunks.append(json.loads(raw)) + except (json.JSONDecodeError, TypeError): + logger.debug("Could not parse SSE event data: %r", raw) + finally: + close = getattr(response.data, "close", None) + if callable(close): + close() + return chunks + + async def _generate_content_streaming( + self, llm_request: LlmRequest + ) -> AsyncGenerator[LlmResponse, None]: + """Yield partial then final LlmResponse from an OCI SSE stream. + + The OCI SDK is fully synchronous, so we collect all SSE chunks in a + background thread via ``asyncio.to_thread`` and then yield responses + from the accumulated data — matching the pattern used by + ``AnthropicLlm._generate_content_streaming``. + """ + chunks = await asyncio.to_thread(self._call_oci_stream, llm_request) + + text_acc: str = "" + tool_acc: dict[int, dict[str, Any]] = {} + input_tokens: int = 0 + output_tokens: int = 0 + reasoning_tokens: int = 0 + + for chunk in chunks: + # Usage chunk (camelCase per OCI GenAI /20231130/ schema). + usage = chunk.get("usage") + if usage: + input_tokens = usage.get("promptTokens", 0) or 0 + output_tokens = usage.get("completionTokens", 0) or 0 + details = usage.get("completionTokensDetails") or {} + reasoning_tokens = details.get("reasoningTokens", 0) or 0 + continue + + message = chunk.get("message") + if not message: + continue + + # Text content: list of {type: TEXT, text: ...} blocks. 
+ for block in message.get("content") or []: + if block.get("type") == "TEXT" and block.get("text"): + delta_text = block["text"] + text_acc += delta_text + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text=delta_text)], + ), + partial=True, + ) + + # Tool calls: OCI emits the whole call in one chunk for Gemini, but + # accumulate name/arguments defensively in case other providers split + # them across events. + for tc_idx, tc in enumerate(message.get("toolCalls") or []): + idx = tc.get("index", tc_idx) + if idx not in tool_acc: + tool_acc[idx] = {"id": "", "name": "", "arguments": ""} + if tc.get("id"): + tool_acc[idx]["id"] = tc["id"] + if tc.get("name"): + tool_acc[idx]["name"] = tc["name"] + if tc.get("arguments"): + tool_acc[idx]["arguments"] += tc["arguments"] + + # Build final aggregated response + all_parts: list[types.Part] = [] + if text_acc: + all_parts.append(types.Part.from_text(text=text_acc)) + for tc in sorted(tool_acc.values(), key=lambda x: x.get("name", "")): + args: dict[str, Any] = {} + try: + args = json.loads(tc["arguments"]) if tc["arguments"] else {} + except (json.JSONDecodeError, TypeError): + args = {} + part = types.Part.from_function_call(name=tc["name"], args=args) + part.function_call.id = tc["id"] + all_parts.append(part) + + yield LlmResponse( + content=types.Content(role="model", parts=all_parts), + usage_metadata=types.GenerateContentResponseUsageMetadata( + prompt_token_count=input_tokens, + candidates_token_count=output_tokens, + total_token_count=input_tokens + output_tokens, + thoughts_token_count=reasoning_tokens or None, + ), + partial=False, + ) diff --git a/tests/integration/models/test_oci_genai_llm.py b/tests/integration/models/test_oci_genai_llm.py new file mode 100644 index 0000000000..cd533625de --- /dev/null +++ b/tests/integration/models/test_oci_genai_llm.py @@ -0,0 +1,623 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for OCIGenAILlm against live OCI Generative AI service. + +Required environment variables: + OCI_COMPARTMENT_ID — OCI compartment OCID + OCI_REGION — OCI region (default: us-chicago-1) + +Optional: + OCI_AUTH_TYPE — API_KEY | INSTANCE_PRINCIPAL | RESOURCE_PRINCIPAL + (default: API_KEY) + OCI_AUTH_PROFILE — OCI config profile (default: DEFAULT) + OCI_AUTH_FILE — path to OCI config file (default: ~/.oci/config) +""" + +import json +import os + +from google.adk.models.llm_request import LlmRequest +from google.adk.models.oci_genai_llm import OCIGenAILlm +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import pytest + + +# --------------------------------------------------------------------------- +# Skip the entire module when required env vars are absent +# --------------------------------------------------------------------------- + +# OCI tests do not use any Google backend (GOOGLE_AI / Vertex AI). +# Override the autouse llm_backend fixture from the integration conftest so +# these tests are not duplicated across backends. +@pytest.fixture(autouse=True) +def llm_backend(): + yield + + +pytestmark = pytest.mark.skipif( + not os.environ.get("OCI_COMPARTMENT_ID"), + reason=( + "OCI integration tests require OCI_COMPARTMENT_ID to be set. " + "Set OCI_COMPARTMENT_ID (and optionally OCI_REGION) to run." 
+ ), +) + +_COMPARTMENT_ID = os.environ.get("OCI_COMPARTMENT_ID", "") +_REGION = os.environ.get("OCI_REGION", "us-chicago-1") +_SERVICE_ENDPOINT = ( + f"https://inference.generativeai.{_REGION}.oci.oraclecloud.com" +) +_AUTH_TYPE = os.environ.get("OCI_AUTH_TYPE", "API_KEY") +_AUTH_PROFILE = os.environ.get("OCI_AUTH_PROFILE", "DEFAULT") +_AUTH_FILE = os.environ.get("OCI_AUTH_FILE", "~/.oci/config") + +_GEMINI_MODEL = "google.gemini-2.5-flash" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def gemini_llm() -> OCIGenAILlm: + return OCIGenAILlm( + model=_GEMINI_MODEL, + compartment_id=_COMPARTMENT_ID, + service_endpoint=_SERVICE_ENDPOINT, + auth_type=_AUTH_TYPE, + auth_profile=_AUTH_PROFILE, + auth_file_location=_AUTH_FILE, + max_tokens=512, + ) + + + +def _simple_request(model: str, text: str = "Reply with one word: hello.") -> LlmRequest: + return LlmRequest( + model=model, + contents=[Content(role="user", parts=[Part.from_text(text=text)])], + ) + + +def _request_with_system(model: str) -> LlmRequest: + return LlmRequest( + model=model, + contents=[ + Content( + role="user", + parts=[Part.from_text(text="What is your name?")], + ) + ], + config=types.GenerateContentConfig( + system_instruction="Your name is Oracle. 
Always introduce yourself as Oracle.", + ), + ) + + +def _request_with_tool(model: str) -> LlmRequest: + return LlmRequest( + model=model, + contents=[ + Content( + role="user", + parts=[Part.from_text(text="What is the weather in Chicago?")], + ) + ], + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="get_weather", + description="Get the current weather for a city.", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "city": types.Schema( + type=types.Type.STRING, + description="The city name.", + ) + }, + required=["city"], + ), + ) + ] + ) + ] + ), + ) + + +# --------------------------------------------------------------------------- +# Gemini (google.gemini-2.0-flash-001) tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_gemini_generate_content_text(gemini_llm): + """Gemini on OCI returns a non-empty text response.""" + responses = [ + r + async for r in gemini_llm.generate_content_async( + _simple_request(_GEMINI_MODEL), stream=False + ) + ] + assert len(responses) == 1 + assert responses[0].content.role == "model" + assert responses[0].content.parts + assert responses[0].content.parts[0].text.strip() + + +@pytest.mark.asyncio +async def test_gemini_generate_content_usage_metadata(gemini_llm): + """Response includes token usage metadata.""" + responses = [ + r + async for r in gemini_llm.generate_content_async( + _simple_request(_GEMINI_MODEL), stream=False + ) + ] + usage = responses[0].usage_metadata + assert usage.prompt_token_count > 0 + assert usage.candidates_token_count > 0 + assert usage.total_token_count == ( + usage.prompt_token_count + usage.candidates_token_count + ) + + +@pytest.mark.asyncio +async def test_gemini_generate_content_with_system_instruction(gemini_llm): + """System instruction is respected.""" + responses = [ + r + async for r in gemini_llm.generate_content_async( + 
_request_with_system(_GEMINI_MODEL), stream=False + ) + ] + text = responses[0].content.parts[0].text.lower() + assert "oracle" in text + + +@pytest.mark.asyncio +async def test_gemini_generate_content_tool_call(gemini_llm): + """Gemini returns a function call when a tool is provided.""" + responses = [ + r + async for r in gemini_llm.generate_content_async( + _request_with_tool(_GEMINI_MODEL), stream=False + ) + ] + parts = responses[0].content.parts + function_calls = [p for p in parts if p.function_call] + assert function_calls, "Expected at least one function call in the response" + fc = function_calls[0].function_call + assert fc.name == "get_weather" + assert "city" in fc.args + + +@pytest.mark.asyncio +async def test_gemini_generate_content_streaming_text(gemini_llm): + """Streaming returns partial chunks followed by a final non-partial response.""" + responses = [ + r + async for r in gemini_llm.generate_content_async( + _simple_request(_GEMINI_MODEL), stream=True + ) + ] + assert responses, "Expected at least one response chunk" + partial_responses = [r for r in responses if r.partial] + final_responses = [r for r in responses if not r.partial] + assert partial_responses, "Expected at least one partial (streaming) chunk" + assert len(final_responses) == 1, "Expected exactly one final (non-partial) response" + full_text = "".join( + p.text + for r in partial_responses + for p in (r.content.parts or []) + if p.text + ) + assert full_text.strip(), "Streamed text should be non-empty" + + +@pytest.mark.asyncio +async def test_gemini_generate_content_streaming_usage_metadata(gemini_llm): + """Final streaming response includes token usage metadata.""" + responses = [ + r + async for r in gemini_llm.generate_content_async( + _simple_request(_GEMINI_MODEL), stream=True + ) + ] + final = next(r for r in responses if not r.partial) + usage = final.usage_metadata + assert usage is not None + assert usage.prompt_token_count > 0 + assert usage.candidates_token_count > 0 
+ assert usage.total_token_count == ( + usage.prompt_token_count + usage.candidates_token_count + ) + + +@pytest.mark.asyncio +async def test_gemini_generate_content_streaming_tool_call(gemini_llm): + """Streaming returns a function call when a tool is provided.""" + responses = [ + r + async for r in gemini_llm.generate_content_async( + _request_with_tool(_GEMINI_MODEL), stream=True + ) + ] + final = next(r for r in responses if not r.partial) + parts = final.content.parts or [] + function_calls = [p for p in parts if p.function_call] + assert function_calls, "Expected at least one function call in the streaming response" + fc = function_calls[0].function_call + assert fc.name == "get_weather" + assert "city" in fc.args + + +@pytest.mark.asyncio +async def test_gemini_generate_content_concurrent(gemini_llm): + """Multiple concurrent non-streaming requests complete independently.""" + import asyncio + + async def single_call(text: str) -> str: + responses = [ + r + async for r in gemini_llm.generate_content_async( + _simple_request(_GEMINI_MODEL, text=text), stream=False + ) + ] + return responses[0].content.parts[0].text + + results = await asyncio.gather(*[ + single_call(f"Reply with the number {i} only.") + for i in range(3) + ]) + assert len(results) == 3 + for result in results: + assert result.strip(), "Each concurrent response should be non-empty" + + +@pytest.mark.asyncio +async def test_gemini_multi_turn(gemini_llm): + """Multi-turn conversation passes history correctly.""" + history = [ + Content(role="user", parts=[Part.from_text(text="My favourite colour is blue.")]), + Content(role="model", parts=[Part.from_text(text="Got it, blue is a great colour!")]), + ] + follow_up = Content( + role="user", + parts=[Part.from_text(text="What is my favourite colour?")], + ) + request = LlmRequest( + model=_GEMINI_MODEL, + contents=history + [follow_up], + ) + responses = [r async for r in gemini_llm.generate_content_async(request)] + text = 
responses[0].content.parts[0].text.lower() + assert "blue" in text + + +# --------------------------------------------------------------------------- +# Cross-provider on-demand smoke tests +# +# Skipped unless the corresponding model env var is set so cost stays opt-in. +# Set OCI_LLAMA_MODEL / OCI_MISTRAL_MODEL / OCI_GROK_MODEL / OCI_NVIDIA_MODEL +# to a model id available in your tenancy/region (e.g. "meta.llama-3.3-70b-instruct"). +# --------------------------------------------------------------------------- + + +def _provider_llm(env_var: str) -> "OCIGenAILlm | None": + model_id = os.environ.get(env_var) + if not model_id: + return None + return OCIGenAILlm( + model=model_id, + compartment_id=_COMPARTMENT_ID, + service_endpoint=_SERVICE_ENDPOINT, + auth_type=_AUTH_TYPE, + auth_profile=_AUTH_PROFILE, + auth_file_location=_AUTH_FILE, + max_tokens=256, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not os.environ.get("OCI_LLAMA_MODEL"), + reason="Set OCI_LLAMA_MODEL= to enable.", +) +async def test_llama_on_demand_generate_text(): + llm = _provider_llm("OCI_LLAMA_MODEL") + responses = [ + r + async for r in llm.generate_content_async( + _simple_request(llm.model), stream=False + ) + ] + assert len(responses) == 1 + assert responses[0].content.parts[0].text.strip() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not os.environ.get("OCI_MISTRAL_MODEL"), + reason="Set OCI_MISTRAL_MODEL= to enable.", +) +async def test_mistral_on_demand_generate_text(): + llm = _provider_llm("OCI_MISTRAL_MODEL") + responses = [ + r + async for r in llm.generate_content_async( + _simple_request(llm.model), stream=False + ) + ] + assert responses[0].content.parts[0].text.strip() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not os.environ.get("OCI_GROK_MODEL"), + reason="Set OCI_GROK_MODEL= to enable.", +) +async def test_grok_on_demand_generate_text(): + llm = _provider_llm("OCI_GROK_MODEL") + responses = [ + r + async for r in llm.generate_content_async( + 
_simple_request(llm.model), stream=False + ) + ] + assert responses[0].content.parts[0].text.strip() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not os.environ.get("OCI_NVIDIA_MODEL"), + reason="Set OCI_NVIDIA_MODEL= to enable.", +) +async def test_nvidia_on_demand_generate_text(): + llm = _provider_llm("OCI_NVIDIA_MODEL") + responses = [ + r + async for r in llm.generate_content_async( + _simple_request(llm.model), stream=False + ) + ] + assert responses[0].content.parts[0].text.strip() + + +# --------------------------------------------------------------------------- +# Dedicated serving mode +# +# Set OCI_DEDICATED_ENDPOINT_ID=ocid1.generativeaiendpoint.oc1... to enable. +# OCI_DEDICATED_MODEL is informational; defaults to the dedicated endpoint's +# bound model (the SDK ignores `model` when serving_mode is dedicated). +# --------------------------------------------------------------------------- + + +_DEDICATED_ENDPOINT_ID = os.environ.get("OCI_DEDICATED_ENDPOINT_ID", "") +_DEDICATED_MODEL = os.environ.get( + "OCI_DEDICATED_MODEL", "meta.llama-3.3-70b-instruct" +) + + +@pytest.fixture +def dedicated_llm() -> OCIGenAILlm: + return OCIGenAILlm( + model=_DEDICATED_MODEL, + endpoint_id=_DEDICATED_ENDPOINT_ID, + compartment_id=_COMPARTMENT_ID, + service_endpoint=_SERVICE_ENDPOINT, + auth_type=_AUTH_TYPE, + auth_profile=_AUTH_PROFILE, + auth_file_location=_AUTH_FILE, + max_tokens=256, + ) + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not _DEDICATED_ENDPOINT_ID, + reason="Set OCI_DEDICATED_ENDPOINT_ID to a dedicated endpoint OCID to enable.", +) +async def test_dedicated_generate_content_text(dedicated_llm): + responses = [ + r + async for r in dedicated_llm.generate_content_async( + _simple_request(_DEDICATED_MODEL), stream=False + ) + ] + assert len(responses) == 1 + assert responses[0].content.parts[0].text.strip() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + not _DEDICATED_ENDPOINT_ID, + reason="Set OCI_DEDICATED_ENDPOINT_ID to a dedicated endpoint 
OCID to enable.", +) +async def test_dedicated_generate_content_streaming(dedicated_llm): + chunks = [] + async for r in dedicated_llm.generate_content_async( + _simple_request(_DEDICATED_MODEL, text="Count from 1 to 3."), + stream=True, + ): + chunks.append(r) + assert len(chunks) >= 2 # at least one partial + one final + final = chunks[-1] + assert final.usage_metadata is not None + + +# --------------------------------------------------------------------------- +# Sampling parameters (live) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_gemini_max_output_tokens_caps_response(gemini_llm): + """max_output_tokens is honoured: completion tokens never exceed the budget. + + Note: Gemini 2.5 spends part of the budget on reasoning tokens before any + visible output. We pick a budget large enough to leave some text but small + enough to clearly cap an alphabet-recitation response, and we assert on the + reported token count rather than character count (which is flaky). 
+ """ + budget = 64 + request = LlmRequest( + model=_GEMINI_MODEL, + contents=[Content(role="user", parts=[ + Part.from_text(text="Recite the alphabet, A through Z, comma separated.") + ])], + config=types.GenerateContentConfig(max_output_tokens=budget), + ) + responses = [r async for r in gemini_llm.generate_content_async(request)] + um = responses[0].usage_metadata + assert um.candidates_token_count is not None + assert um.candidates_token_count <= budget + + +@pytest.mark.asyncio +async def test_gemini_low_temperature_deterministic_with_seed(gemini_llm): + """temperature=0 + seed should yield consistent answers across two calls.""" + request = LlmRequest( + model=_GEMINI_MODEL, + contents=[Content(role="user", parts=[ + Part.from_text(text="Reply with exactly: 'green'") + ])], + config=types.GenerateContentConfig(temperature=0.0, seed=12345), + ) + call_a = [r async for r in gemini_llm.generate_content_async(request)] + call_b = [r async for r in gemini_llm.generate_content_async(request)] + assert "green" in call_a[0].content.parts[0].text.lower() + assert "green" in call_b[0].content.parts[0].text.lower() + + +@pytest.mark.asyncio +async def test_gemini_stop_sequences_terminate_output(gemini_llm): + request = LlmRequest( + model=_GEMINI_MODEL, + contents=[Content(role="user", parts=[ + Part.from_text(text="Print: APPLE | BANANA | CHERRY") + ])], + config=types.GenerateContentConfig( + temperature=0.0, stop_sequences=["BANANA"] + ), + ) + responses = [r async for r in gemini_llm.generate_content_async(request)] + text = responses[0].content.parts[0].text + assert "BANANA" not in text + + +# --------------------------------------------------------------------------- +# Multimodal: inline image (live) +# +# Uses a tiny 1x1 red PNG so the request is cheap. Gemini 2.5 Flash on OCI +# supports image inputs via ImageContent. 
+# --------------------------------------------------------------------------- + + +def _make_red_png_1x1() -> bytes: + """Generate a guaranteed-valid 1x1 red PNG with correct CRCs.""" + import struct, zlib + sig = b"\x89PNG\r\n\x1a\n" + def chunk(t: bytes, d: bytes) -> bytes: + return struct.pack(">I", len(d)) + t + d + struct.pack(">I", zlib.crc32(t + d)) + ihdr = struct.pack(">IIBBBBB", 1, 1, 8, 2, 0, 0, 0) # 1x1 RGB + idat = zlib.compress(b"\x00\xff\x00\x00") # filter byte + RGB(255,0,0) + return sig + chunk(b"IHDR", ihdr) + chunk(b"IDAT", idat) + chunk(b"IEND", b"") + + +_TINY_RED_PNG = _make_red_png_1x1() + + +@pytest.mark.asyncio +async def test_gemini_inline_image_input(gemini_llm): + request = LlmRequest( + model=_GEMINI_MODEL, + contents=[Content(role="user", parts=[ + Part.from_text( + text="What is the dominant colour of this image? " + "Reply with just the colour name." + ), + Part(inline_data=types.Blob(mime_type="image/png", data=_TINY_RED_PNG)), + ])], + config=types.GenerateContentConfig(temperature=0.0, max_output_tokens=256), + ) + responses = [r async for r in gemini_llm.generate_content_async(request)] + parts = responses[0].content.parts + assert parts, "Expected the model to produce a visible answer" + text = parts[0].text.lower() + assert "red" in text + + +# --------------------------------------------------------------------------- +# Structured output: response_schema (live) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_gemini_response_schema_returns_valid_json(gemini_llm): + schema = { + "title": "CityFact", + "type": "object", + "properties": { + "city": {"type": "string"}, + "country": {"type": "string"}, + }, + "required": ["city", "country"], + "additionalProperties": False, + } + request = LlmRequest( + model=_GEMINI_MODEL, + contents=[Content(role="user", parts=[ + Part.from_text(text="Give me a fact about Paris.") + ])], + 
config=types.GenerateContentConfig( + response_mime_type="application/json", + response_schema=schema, + temperature=0.0, + ), + ) + responses = [r async for r in gemini_llm.generate_content_async(request)] + raw = responses[0].content.parts[0].text + payload = json.loads(raw) + assert "city" in payload + assert "country" in payload + + +# --------------------------------------------------------------------------- +# Reasoning-token surfacing (live) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_gemini_reasoning_tokens_reported(gemini_llm): + """Gemini 2.5 emits reasoningTokens in completionTokensDetails — surface them.""" + request = LlmRequest( + model=_GEMINI_MODEL, + contents=[Content(role="user", parts=[ + Part.from_text(text="If a train travels 60km in 30 minutes, what is its speed?") + ])], + config=types.GenerateContentConfig(temperature=0.0), + ) + responses = [r async for r in gemini_llm.generate_content_async(request)] + um = responses[0].usage_metadata + assert um is not None + # Reasoning tokens are optional; assert it's an int when present + assert um.thoughts_token_count is None or um.thoughts_token_count > 0 diff --git a/tests/unittests/models/test_oci_genai_llm.py b/tests/unittests/models/test_oci_genai_llm.py new file mode 100644 index 0000000000..a2374d9e7e --- /dev/null +++ b/tests/unittests/models/test_oci_genai_llm.py @@ -0,0 +1,1259 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the OCI Generative AI LLM integration.""" + +import asyncio +import json +import os +from typing import Any +from unittest import mock +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.models.oci_genai_llm import _content_to_oci_message +from google.adk.models.oci_genai_llm import _function_declaration_to_oci_tool +from google.adk.models.oci_genai_llm import _oci_response_to_llm_response +from google.adk.models.oci_genai_llm import OCIGenAILlm +from google.genai import types +from google.genai.types import Content +from google.genai.types import Part +import pytest + + +# --------------------------------------------------------------------------- +# Helpers: build fake OCI SDK response objects without importing oci +# --------------------------------------------------------------------------- + + +def _make_oci_response( + text: str = "Hello from OCI.", + tool_calls: list = None, + prompt_tokens: int = 10, + completion_tokens: int = 5, +) -> MagicMock: + """Build a minimal MagicMock that mirrors the OCI GenAI chat response.""" + usage = MagicMock() + usage.prompt_tokens = prompt_tokens + usage.completion_tokens = completion_tokens + + content_block = MagicMock() + content_block.text = text + + message = MagicMock() + message.content = [content_block] + message.tool_calls = tool_calls or [] + + choice = MagicMock() + choice.message = message + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_tool_call_response(name: str, args: dict) -> MagicMock: + """Build a fake OCI tool-call response using FunctionCall (OCI SDK subtype).""" + import 
oci.generative_ai_inference.models as oci_models + + fc = oci_models.FunctionCall( + id="call_abc123", + type=oci_models.FunctionCall.TYPE_FUNCTION, + name=name, + arguments=json.dumps(args), + ) + + usage = MagicMock() + usage.prompt_tokens = 20 + usage.completion_tokens = 15 + + message = MagicMock() + message.content = [] + message.tool_calls = [fc] + + choice = MagicMock() + choice.message = message + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def oci_llm(): + return OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + + +@pytest.fixture +def llm_request(): + return LlmRequest( + model="google.gemini-2.5-flash", + contents=[ + Content(role="user", parts=[Part.from_text(text="Hello")]) + ], + config=types.GenerateContentConfig( + system_instruction="You are a helpful assistant.", + ), + ) + + +# --------------------------------------------------------------------------- +# supported_models +# --------------------------------------------------------------------------- + + +def test_supported_models_gemini(): + assert any( + "gemini" in p for p in OCIGenAILlm.supported_models() + ) + + +def test_supported_models_llama(): + assert any("llama" in p for p in OCIGenAILlm.supported_models()) + + +def test_supported_models_gemma(): + assert any("gemma" in p for p in OCIGenAILlm.supported_models()) + + +def test_supported_models_registry(): + from google.adk.models.registry import LLMRegistry + + assert LLMRegistry.resolve("google.gemini-2.0-flash-001") is OCIGenAILlm + assert 
LLMRegistry.resolve("meta.llama-3.1-8b-instruct") is OCIGenAILlm + assert LLMRegistry.resolve("google.gemma-3-27b-it") is OCIGenAILlm + + +# --------------------------------------------------------------------------- +# _content_to_oci_message +# --------------------------------------------------------------------------- + + +def test_content_to_oci_message_user_text(): + import oci.generative_ai_inference.models as oci_models + + content = Content(role="user", parts=[Part.from_text(text="Hi there")]) + msg = _content_to_oci_message(content) + assert isinstance(msg, oci_models.UserMessage) + assert msg.role == oci_models.UserMessage.ROLE_USER + assert msg.content[0].text == "Hi there" + + +def test_content_to_oci_message_assistant_text(): + import oci.generative_ai_inference.models as oci_models + + content = Content(role="model", parts=[Part.from_text(text="I can help.")]) + msg = _content_to_oci_message(content) + assert isinstance(msg, oci_models.AssistantMessage) + assert msg.role == oci_models.AssistantMessage.ROLE_ASSISTANT + assert msg.content[0].text == "I can help." 
def test_content_to_oci_message_multi_part_text():
  import oci.generative_ai_inference.models as oci_models

  # Two text parts on one Content must land in a single user message.
  parts = [Part.from_text(text="First"), Part.from_text(text="Second")]
  converted = _content_to_oci_message(Content(role="user", parts=parts))
  assert isinstance(converted, oci_models.UserMessage)
  merged_text = converted.content[0].text
  assert "First" in merged_text
  assert "Second" in merged_text


def test_content_to_oci_message_function_call():
  import oci.generative_ai_inference.models as oci_models

  # A model-authored function call maps to an AssistantMessage carrying
  # exactly one OCI FunctionCall with JSON-encoded arguments.
  call_part = Part.from_function_call(
      name="get_weather", args={"city": "Toronto"}
  )
  converted = _content_to_oci_message(Content(role="model", parts=[call_part]))
  assert isinstance(converted, oci_models.AssistantMessage)
  assert converted.tool_calls is not None
  assert len(converted.tool_calls) == 1
  tool_call = converted.tool_calls[0]
  assert isinstance(tool_call, oci_models.FunctionCall)
  assert tool_call.name == "get_weather"
  assert json.loads(tool_call.arguments) == {"city": "Toronto"}


def test_content_to_oci_message_function_response():
  import oci.generative_ai_inference.models as oci_models

  # A tool result becomes a ToolMessage keyed by the originating call id.
  result_part = Part.from_function_response(
      name="get_weather", response={"result": "Sunny, 22°C"}
  )
  result_part.function_response.id = "call_xyz"
  converted = _content_to_oci_message(
      Content(role="user", parts=[result_part])
  )
  assert isinstance(converted, oci_models.ToolMessage)
  assert converted.tool_call_id == "call_xyz"
  assert converted.content[0].text


# ---------------------------------------------------------------------------
# _oci_response_to_llm_response
# ---------------------------------------------------------------------------


def test_oci_response_to_llm_response_text():
  oci_response = _make_oci_response(
      text="Here is your answer.", prompt_tokens=8, completion_tokens=4
  )
  result = _oci_response_to_llm_response(oci_response)

  assert isinstance(result, LlmResponse)
  assert result.content.role == "model"
  assert result.content.parts[0].text == "Here is your answer."
  usage = result.usage_metadata
  assert usage.prompt_token_count == 8
  assert usage.candidates_token_count == 4
  assert usage.total_token_count == 12


def test_oci_response_to_llm_response_tool_call():
  oci_response = _make_tool_call_response(
      name="get_weather", args={"city": "Chicago"}
  )
  result = _oci_response_to_llm_response(oci_response)

  assert result.content.role == "model"
  call = result.content.parts[0].function_call
  assert call.name == "get_weather"
  assert call.args == {"city": "Chicago"}
  assert call.id == "call_abc123"


def test_oci_response_to_llm_response_empty_text():
  # An empty content list must map to an empty parts list, not crash.
  oci_response = _make_oci_response(text="")
  oci_response.data.chat_response.choices[0].message.content = []
  result = _oci_response_to_llm_response(oci_response)
  assert result.content.parts == []


# ---------------------------------------------------------------------------
# _function_declaration_to_oci_tool
# ---------------------------------------------------------------------------


def test_function_declaration_to_oci_tool_no_parameters():
  import oci.generative_ai_inference.models as oci_models

  declaration = types.FunctionDeclaration(
      name="ping",
      description="Check if the service is alive.",
  )
  oci_tool = _function_declaration_to_oci_tool(declaration)
  assert isinstance(oci_tool, oci_models.FunctionDefinition)
  assert oci_tool.name == "ping"
  assert oci_tool.description == "Check if the service is alive."
  # Parameter-less tools still need a valid (empty) JSON-schema object.
  assert oci_tool.parameters["type"] == "object"
  assert oci_tool.parameters["properties"] == {}


def test_function_declaration_to_oci_tool_with_parameters():
  import oci.generative_ai_inference.models as oci_models

  declaration = types.FunctionDeclaration(
      name="get_weather",
      description="Get weather for a city.",
      parameters=types.Schema(
          type=types.Type.OBJECT,
          properties={
              "city": types.Schema(
                  type=types.Type.STRING,
                  description="City name",
              )
          },
          required=["city"],
      ),
  )
  oci_tool = _function_declaration_to_oci_tool(declaration)
  assert isinstance(oci_tool, oci_models.FunctionDefinition)
  assert oci_tool.name == "get_weather"
  assert "city" in oci_tool.parameters["properties"]
  assert oci_tool.parameters["required"] == ["city"]


def test_function_declaration_to_oci_tool_json_schema():
  import oci.generative_ai_inference.models as oci_models

  declaration = types.FunctionDeclaration(
      name="validate",
      description="Validates a payload.",
      parameters_json_schema={
          "type": "object",
          "properties": {"value": {"type": "string"}},
          "required": ["value"],
      },
  )
  oci_tool = _function_declaration_to_oci_tool(declaration)
  assert isinstance(oci_tool, oci_models.FunctionDefinition)
  assert oci_tool.parameters["required"] == ["value"]


# ---------------------------------------------------------------------------
# OCIGenAILlm.generate_content_async
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_generate_content_async_text(oci_llm, llm_request):
  canned = _make_oci_response(text="Hi! I am Gemini on OCI.")

  with patch.object(oci_llm, "_call_oci", return_value=canned):
    responses = [
        r async for r in oci_llm.generate_content_async(llm_request)
    ]

  assert len(responses) == 1
  assert responses[0].content.parts[0].text == "Hi! I am Gemini on OCI."
@pytest.mark.asyncio
async def test_generate_content_async_yields_llm_response(oci_llm, llm_request):
  """Every item yielded by generate_content_async is an LlmResponse."""
  with patch.object(oci_llm, "_call_oci", return_value=_make_oci_response()):
    responses = [
        r async for r in oci_llm.generate_content_async(llm_request)
    ]
  assert all(isinstance(r, LlmResponse) for r in responses)


@pytest.mark.asyncio
async def test_generate_content_async_with_tools(oci_llm):
  """An OCI tool-call response surfaces as a function_call part."""
  request = LlmRequest(
      model="google.gemini-2.0-flash-001",
      contents=[
          Content(
              role="user",
              parts=[Part.from_text(text="What is the weather in Chicago?")],
          )
      ],
      config=types.GenerateContentConfig(
          tools=[
              types.Tool(
                  function_declarations=[
                      types.FunctionDeclaration(
                          name="get_weather",
                          description="Get weather for a city.",
                          parameters=types.Schema(
                              type=types.Type.OBJECT,
                              properties={
                                  "city": types.Schema(type=types.Type.STRING)
                              },
                              required=["city"],
                          ),
                      )
                  ]
              )
          ]
      ),
  )
  tool_response = _make_tool_call_response("get_weather", {"city": "Chicago"})

  with patch.object(oci_llm, "_call_oci", return_value=tool_response):
    responses = [r async for r in oci_llm.generate_content_async(request)]

  fc = responses[0].content.parts[0].function_call
  assert fc.name == "get_weather"
  assert fc.args["city"] == "Chicago"


# ---------------------------------------------------------------------------
# OCIGenAILlm — streaming (stream=True)
# ---------------------------------------------------------------------------


def _make_sse_chunks(
    text_tokens: list[str],
    tool_calls: list[dict] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 5,
) -> list[dict[str, Any]]:
  """Build SSE chunks matching the real OCI GenAI /20231130/ streaming schema.

  Schema (verified against live OCI Gemini stream):
    text:   {"index": 0, "message": {"role": "ASSISTANT",
            "content": [{"type": "TEXT", "text": "..."}]}}
    tools:  {"index": 0, "message": {"role": "ASSISTANT",
            "toolCalls": [{"type": "FUNCTION", "name": "...",
            "arguments": "{...}"}]}}
    finish: {"finishReason": "stop"}
    usage:  {"usage": {"promptTokens": N, "completionTokens": N,
            "totalTokens": N}}  # camelCase!

  Args:
    text_tokens: one streamed text delta per entry.
    tool_calls: optional list of {"id", "name", "args"} dicts, each emitted
      as a complete toolCalls chunk.
    prompt_tokens / completion_tokens: values for the trailing usage chunk.

  Returns:
    The ordered chunk list: text chunks, tool-call chunks, a finish chunk,
    then a usage chunk.
  """
  chunks = []

  for token in text_tokens:
    chunks.append({
        "index": 0,
        "message": {
            "role": "ASSISTANT",
            "content": [{"type": "TEXT", "text": token}],
        },
    })

  # Fix: the original loop bound an enumerate() index that was never used;
  # iterate the calls directly.
  for tc in tool_calls or []:
    chunks.append({
        "index": 0,
        "message": {
            "role": "ASSISTANT",
            "toolCalls": [{
                "type": "FUNCTION",
                "id": tc["id"],
                "name": tc["name"],
                "arguments": json.dumps(tc["args"]),
            }],
        },
    })

  chunks.append({"finishReason": "stop"})
  chunks.append({
      "usage": {
          "promptTokens": prompt_tokens,
          "completionTokens": completion_tokens,
          "totalTokens": prompt_tokens + completion_tokens,
      },
  })
  return chunks


@pytest.mark.asyncio
async def test_streaming_yields_partial_then_final(oci_llm, llm_request):
  """stream=True yields partial=True chunks then a final partial=False response."""
  chunks = _make_sse_chunks(["Hello", " world", "!"])

  with patch.object(oci_llm, "_call_oci_stream", return_value=chunks):
    responses = [
        r async for r in oci_llm.generate_content_async(llm_request, stream=True)
    ]

  partial = [r for r in responses if r.partial]
  final = [r for r in responses if not r.partial]

  assert len(partial) == 3  # one per text token
  assert len(final) == 1
  assert partial[0].content.parts[0].text == "Hello"
  assert partial[1].content.parts[0].text == " world"
  assert partial[2].content.parts[0].text == "!"
  # Final aggregates all text
  assert final[0].content.parts[0].text == "Hello world!"
@pytest.mark.asyncio
async def test_streaming_final_has_usage_metadata(oci_llm, llm_request):
  """Final streaming response includes token usage."""
  sse_chunks = _make_sse_chunks(["Hi"], prompt_tokens=8, completion_tokens=3)

  with patch.object(oci_llm, "_call_oci_stream", return_value=sse_chunks):
    responses = [
        r async for r in oci_llm.generate_content_async(llm_request, stream=True)
    ]

  last = responses[-1]
  assert not last.partial
  usage = last.usage_metadata
  assert usage.prompt_token_count == 8
  assert usage.candidates_token_count == 3
  assert usage.total_token_count == 11


@pytest.mark.asyncio
async def test_streaming_tool_call(oci_llm):
  """Streaming assembles tool call arguments from delta chunks."""
  request = LlmRequest(
      model="google.gemini-2.5-flash",
      contents=[
          Content(
              role="user", parts=[Part.from_text(text="Weather in Chicago?")]
          )
      ],
  )
  sse_chunks = _make_sse_chunks(
      text_tokens=[],
      tool_calls=[{
          "id": "call_stream_1",
          "name": "get_weather",
          "args": {"city": "Chicago"},
      }],
  )

  with patch.object(oci_llm, "_call_oci_stream", return_value=sse_chunks):
    responses = [
        r async for r in oci_llm.generate_content_async(request, stream=True)
    ]

  last = responses[-1]
  assert not last.partial
  call = last.content.parts[0].function_call
  assert call.name == "get_weather"
  assert call.args == {"city": "Chicago"}
  assert call.id == "call_stream_1"


@pytest.mark.asyncio
async def test_streaming_empty_chunks(oci_llm, llm_request):
  """Empty SSE chunk list yields a single empty final response."""
  with patch.object(oci_llm, "_call_oci_stream", return_value=[]):
    responses = [
        r async for r in oci_llm.generate_content_async(llm_request, stream=True)
    ]

  assert len(responses) == 1
  assert not responses[0].partial


@pytest.mark.asyncio
async def test_nonstreaming_uses_call_oci_not_call_oci_stream(oci_llm, llm_request):
  """stream=False path calls _call_oci, not _call_oci_stream."""
  with patch.object(
      oci_llm, "_call_oci", return_value=_make_oci_response()
  ) as mock_call:
    with patch.object(oci_llm, "_call_oci_stream") as mock_stream:
      responses = [
          r
          async for r in oci_llm.generate_content_async(llm_request, stream=False)
      ]

  mock_call.assert_called_once()
  mock_stream.assert_not_called()
  assert len(responses) == 1


@pytest.mark.asyncio
async def test_streaming_uses_call_oci_stream_not_call_oci(oci_llm, llm_request):
  """stream=True path calls _call_oci_stream, not _call_oci."""
  sse_chunks = _make_sse_chunks(["hi"])

  with patch.object(
      oci_llm, "_call_oci_stream", return_value=sse_chunks
  ) as mock_stream:
    with patch.object(oci_llm, "_call_oci") as mock_call:
      responses = [
          r
          async for r in oci_llm.generate_content_async(llm_request, stream=True)
      ]

  mock_stream.assert_called_once()
  mock_call.assert_not_called()


@patch("oci.config.from_file", return_value={})
@patch("oci.generative_ai_inference.GenerativeAiInferenceClient")
def test_call_oci_stream_iterates_sse_via_events_method(
    mock_client_cls, _mock_cfg
):
  """_call_oci_stream must use response.data.events(), not iterate response.data.

  Regression guard: OCI's SDK returns an SSEClient that exposes events() and
  close() but is not directly iterable. Iterating response.data raises
  TypeError at runtime against real OCI.
  """

  class StubEvent:

    def __init__(self, data: str):
      self.data = data

  class StubSSEClient:
    """Mimics OCI's SSEClient: exposes events() + close(), not __iter__."""

    def __init__(self, events: list):
      self._events = events
      self.closed = False

    def events(self):
      return iter(self._events)

    def close(self):
      self.closed = True

    def __iter__(self):  # pragma: no cover — must NOT be reached
      raise TypeError("'SSEClient' object is not iterable")

  event_stream = [
      StubEvent(json.dumps({
          "index": 0,
          "message": {
              "role": "ASSISTANT",
              "content": [{"type": "TEXT", "text": "Hi"}],
          },
      })),
      StubEvent(json.dumps({"finishReason": "stop"})),
      StubEvent(json.dumps({
          "usage": {
              "promptTokens": 4,
              "completionTokens": 1,
              "totalTokens": 5,
          },
      })),
      StubEvent("[DONE]"),
  ]
  sse_client = StubSSEClient(event_stream)

  oci_client = MagicMock()
  mock_client_cls.return_value = oci_client
  chat_result = MagicMock()
  chat_result.data = sse_client
  oci_client.chat.return_value = chat_result

  llm = OCIGenAILlm(
      model="google.gemini-2.5-flash",
      compartment_id="ocid1.compartment.oc1..example",
      service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
  )
  request = LlmRequest(
      model="google.gemini-2.5-flash",
      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
  )
  chunks = llm._call_oci_stream(request)

  assert len(chunks) == 3  # text + finish + usage; [DONE] sentinel breaks the loop
  assert chunks[0]["message"]["content"][0]["text"] == "Hi"
  assert chunks[1]["finishReason"] == "stop"
  assert chunks[2]["usage"]["totalTokens"] == 5
  assert sse_client.closed, "SSEClient.close() must be called after iteration"


# ---------------------------------------------------------------------------
# OCIGenAILlm — concurrent async calls
# ---------------------------------------------------------------------------
+@pytest.mark.asyncio +async def test_concurrent_async_calls(oci_llm): + """Multiple concurrent generate_content_async calls complete independently.""" + responses_by_call = {} + + async def run_call(call_id: int): + request = LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text=f"Call {call_id}")])], + ) + with patch.object( + oci_llm, "_call_oci", + return_value=_make_oci_response(text=f"Response {call_id}"), + ): + results = [r async for r in oci_llm.generate_content_async(request)] + responses_by_call[call_id] = results + + await asyncio.gather(*[run_call(i) for i in range(5)]) + + assert len(responses_by_call) == 5 + for call_id, results in responses_by_call.items(): + assert results[0].content.parts[0].text == f"Response {call_id}" + + +@pytest.mark.asyncio +async def test_concurrent_streaming_calls(oci_llm): + """Multiple concurrent streaming calls complete independently.""" + + async def run_streaming(call_id: int): + request = LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text=f"Stream {call_id}")])], + ) + chunks = _make_sse_chunks([f"Stream{call_id}"]) + with patch.object(oci_llm, "_call_oci_stream", return_value=chunks): + return [r async for r in oci_llm.generate_content_async(request, stream=True)] + + all_results = await asyncio.gather(*[run_streaming(i) for i in range(3)]) + + for call_id, results in enumerate(all_results): + final = results[-1] + assert not final.partial + assert f"Stream{call_id}" in final.content.parts[0].text + + +# --------------------------------------------------------------------------- +# OCIGenAILlm — configuration & auth +# --------------------------------------------------------------------------- + + +def test_missing_compartment_id_raises(llm_request): + llm = OCIGenAILlm(model="google.gemini-2.5-flash") + with patch.dict(os.environ, {k: v for k, v in os.environ.items() if k != "OCI_COMPARTMENT_ID"}): + 
os.environ.pop("OCI_COMPARTMENT_ID", None) + with pytest.raises(ValueError, match="compartment_id"): + llm._resolve_compartment_id() + + +def test_compartment_id_from_env(llm_request): + llm = OCIGenAILlm(model="google.gemini-2.0-flash-001") + with patch.dict(os.environ, {"OCI_COMPARTMENT_ID": "ocid1.compartment.example"}): + assert llm._resolve_compartment_id() == "ocid1.compartment.example" + + +def test_service_endpoint_default(): + llm = OCIGenAILlm(model="google.gemini-2.0-flash-001") + endpoint = llm._resolve_service_endpoint() + assert "us-chicago-1" in endpoint + + +def test_service_endpoint_from_env(): + llm = OCIGenAILlm(model="google.gemini-2.0-flash-001") + custom = "https://inference.generativeai.eu-frankfurt-1.oci.oraclecloud.com" + with patch.dict(os.environ, {"OCI_SERVICE_ENDPOINT": custom}): + assert llm._resolve_service_endpoint() == custom + + +def test_service_endpoint_explicit_overrides_env(): + llm = OCIGenAILlm( + model="google.gemini-2.0-flash-001", + service_endpoint="https://custom.endpoint.example.com", + ) + with patch.dict(os.environ, {"OCI_SERVICE_ENDPOINT": "https://ignored.example.com"}): + assert llm._resolve_service_endpoint() == "https://custom.endpoint.example.com" + + +@patch("oci.config.from_file", return_value={"region": "us-chicago-1"}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_build_client_api_key(mock_client_cls, mock_from_file): + llm = OCIGenAILlm( + model="google.gemini-2.0-flash-001", + auth_type="API_KEY", + auth_profile="DEFAULT", + auth_file_location="~/.oci/config", + ) + llm._build_client("https://inference.generativeai.us-chicago-1.oci.oraclecloud.com") + mock_from_file.assert_called_once_with( + file_location="~/.oci/config", profile_name="DEFAULT" + ) + mock_client_cls.assert_called_once() + + +@patch("oci.auth.signers.InstancePrincipalsSecurityTokenSigner") +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def 
test_build_client_instance_principal(mock_client_cls, mock_signer_cls): + llm = OCIGenAILlm( + model="google.gemini-2.0-flash-001", + auth_type="INSTANCE_PRINCIPAL", + ) + llm._build_client("https://inference.generativeai.us-chicago-1.oci.oraclecloud.com") + mock_signer_cls.assert_called_once() + mock_client_cls.assert_called_once() + _, kwargs = mock_client_cls.call_args + assert kwargs["config"] == {} + + +@patch("oci.auth.signers.get_resource_principals_signer") +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_build_client_resource_principal(mock_client_cls, mock_signer_fn): + llm = OCIGenAILlm( + model="google.gemini-2.0-flash-001", + auth_type="RESOURCE_PRINCIPAL", + ) + llm._build_client("https://inference.generativeai.us-chicago-1.oci.oraclecloud.com") + mock_signer_fn.assert_called_once() + mock_client_cls.assert_called_once() + + +# --------------------------------------------------------------------------- +# OCIGenAILlm._call_oci — verify OCI SDK is called with correct parameters +# --------------------------------------------------------------------------- + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_call_oci_passes_model_and_compartment(mock_client_cls, _mock_cfg): + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + import oci.generative_ai_inference.models as oci_models # noqa: F401 + + llm = OCIGenAILlm( + model="google.gemini-2.0-flash-001", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + request = LlmRequest( + model="google.gemini-2.0-flash-001", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + ) + llm._call_oci(request) + + mock_client_instance.chat.assert_called_once() + chat_details = 
mock_client_instance.chat.call_args[0][0] + assert chat_details.compartment_id == "ocid1.compartment.oc1..example" + assert chat_details.serving_mode.model_id == "google.gemini-2.0-flash-001" + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_call_oci_passes_system_instruction(mock_client_cls, _mock_cfg): + import oci.generative_ai_inference.models as oci_models + + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.0-flash-001", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + request = LlmRequest( + model="google.gemini-2.0-flash-001", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + config=types.GenerateContentConfig( + system_instruction="Be concise.", + ), + ) + llm._call_oci(request) + + chat_details = mock_client_instance.chat.call_args[0][0] + messages = chat_details.chat_request.messages + # System instruction is prepended as a SystemMessage + assert isinstance(messages[0], oci_models.SystemMessage) + assert messages[0].content[0].text == "Be concise." 
+ + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_call_oci_passes_tools(mock_client_cls, _mock_cfg): + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.0-flash-001", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + request = LlmRequest( + model="google.gemini-2.0-flash-001", + contents=[Content(role="user", parts=[Part.from_text(text="Weather?")])], + config=types.GenerateContentConfig( + tools=[ + types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="get_weather", + description="Get weather.", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "city": types.Schema(type=types.Type.STRING) + }, + ), + ) + ] + ) + ] + ), + ) + llm._call_oci(request) + + chat_details = mock_client_instance.chat.call_args[0][0] + assert chat_details.chat_request.tools is not None + assert len(chat_details.chat_request.tools) == 1 + assert chat_details.chat_request.tools[0].name == "get_weather" + + +# --------------------------------------------------------------------------- +# Serving mode: on-demand (default) vs dedicated (endpoint_id) +# --------------------------------------------------------------------------- + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_call_oci_uses_on_demand_serving_mode_by_default( + mock_client_cls, _mock_cfg +): + import oci.generative_ai_inference.models as oci_models + + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + 
compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + llm._call_oci( + LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + ) + ) + + chat_details = mock_client_instance.chat.call_args[0][0] + assert isinstance(chat_details.serving_mode, oci_models.OnDemandServingMode) + assert chat_details.serving_mode.model_id == "google.gemini-2.5-flash" + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_call_oci_uses_dedicated_serving_mode_when_endpoint_id_set( + mock_client_cls, _mock_cfg +): + import oci.generative_ai_inference.models as oci_models + + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + endpoint_ocid = "ocid1.generativeaiendpoint.oc1.us-chicago-1.example" + llm = OCIGenAILlm( + model="meta.llama-3.1-70b-instruct", + endpoint_id=endpoint_ocid, + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + llm._call_oci( + LlmRequest( + model="meta.llama-3.1-70b-instruct", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + ) + ) + + chat_details = mock_client_instance.chat.call_args[0][0] + assert isinstance(chat_details.serving_mode, oci_models.DedicatedServingMode) + assert chat_details.serving_mode.endpoint_id == endpoint_ocid + + +@patch.dict(os.environ, {"OCI_ENDPOINT_ID": "ocid1.generativeaiendpoint.oc1..env"}) +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_call_oci_uses_dedicated_serving_mode_from_env_var( + mock_client_cls, _mock_cfg +): + import oci.generative_ai_inference.models as oci_models + + mock_client_instance = MagicMock() + 
mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="meta.llama-3.1-70b-instruct", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + llm._call_oci( + LlmRequest( + model="meta.llama-3.1-70b-instruct", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + ) + ) + + chat_details = mock_client_instance.chat.call_args[0][0] + assert isinstance(chat_details.serving_mode, oci_models.DedicatedServingMode) + assert chat_details.serving_mode.endpoint_id == "ocid1.generativeaiendpoint.oc1..env" + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_explicit_endpoint_id_overrides_env_var(mock_client_cls, _mock_cfg): + import oci.generative_ai_inference.models as oci_models + + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + with patch.dict(os.environ, {"OCI_ENDPOINT_ID": "ocid1.generativeaiendpoint.oc1..env"}): + llm = OCIGenAILlm( + model="meta.llama-3.1-70b-instruct", + endpoint_id="ocid1.generativeaiendpoint.oc1..explicit", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + llm._call_oci( + LlmRequest( + model="meta.llama-3.1-70b-instruct", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + ) + ) + + chat_details = mock_client_instance.chat.call_args[0][0] + assert chat_details.serving_mode.endpoint_id == "ocid1.generativeaiendpoint.oc1..explicit" + + +# --------------------------------------------------------------------------- +# Sampling parameters and max_output_tokens passthrough +# --------------------------------------------------------------------------- + + 
+@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_call_oci_passes_sampling_params(mock_client_cls, _mock_cfg): + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + request = LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + config=types.GenerateContentConfig( + max_output_tokens=128, + temperature=0.7, + top_p=0.9, + top_k=40, + frequency_penalty=0.1, + presence_penalty=0.2, + seed=42, + stop_sequences=["END", "STOP"], + ), + ) + llm._call_oci(request) + + cr = mock_client_instance.chat.call_args[0][0].chat_request + assert cr.max_tokens == 128 + assert cr.temperature == 0.7 + assert cr.top_p == 0.9 + assert cr.top_k == 40 + assert cr.frequency_penalty == 0.1 + assert cr.presence_penalty == 0.2 + assert cr.seed == 42 + assert cr.stop == ["END", "STOP"] + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_call_oci_omits_unset_sampling_params(mock_client_cls, _mock_cfg): + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + llm._call_oci(LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hi")])], + )) + cr = mock_client_instance.chat.call_args[0][0].chat_request + assert cr.temperature is None + 
assert cr.top_p is None + assert cr.top_k is None + assert cr.stop is None + + +# --------------------------------------------------------------------------- +# Multimodal content +# --------------------------------------------------------------------------- + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_inline_image_becomes_image_content_with_data_url( + mock_client_cls, _mock_cfg +): + import oci.generative_ai_inference.models as oci_models + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + png_bytes = b"\x89PNG\r\n\x1a\n_fake" + request = LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[ + Part.from_text(text="What is this?"), + Part(inline_data=types.Blob(mime_type="image/png", data=png_bytes)), + ])], + ) + llm._call_oci(request) + + msg = mock_client_instance.chat.call_args[0][0].chat_request.messages[0] + assert isinstance(msg, oci_models.UserMessage) + blocks = msg.content + assert len(blocks) == 2 + assert isinstance(blocks[0], oci_models.TextContent) + assert blocks[0].text == "What is this?" 
+ assert isinstance(blocks[1], oci_models.ImageContent) + assert blocks[1].image_url.url.startswith("data:image/png;base64,") + import base64 as _b64 + encoded = blocks[1].image_url.url.split(",", 1)[1] + assert _b64.b64decode(encoded) == png_bytes + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_file_data_audio_becomes_audio_content(mock_client_cls, _mock_cfg): + import oci.generative_ai_inference.models as oci_models + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + request = LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[ + Part(file_data=types.FileData( + file_uri="https://example.com/clip.mp3", + mime_type="audio/mpeg", + )), + ])], + ) + llm._call_oci(request) + + msg = mock_client_instance.chat.call_args[0][0].chat_request.messages[0] + blocks = [b for b in msg.content if isinstance(b, oci_models.AudioContent)] + assert len(blocks) == 1 + assert blocks[0].audio_url.url == "https://example.com/clip.mp3" + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_inline_pdf_becomes_document_content(mock_client_cls, _mock_cfg): + import oci.generative_ai_inference.models as oci_models + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + request = LlmRequest( + 
model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[ + Part(inline_data=types.Blob(mime_type="application/pdf", data=b"%PDF-1.4")), + ])], + ) + llm._call_oci(request) + + msg = mock_client_instance.chat.call_args[0][0].chat_request.messages[0] + blocks = [b for b in msg.content if isinstance(b, oci_models.DocumentContent)] + assert len(blocks) == 1 + assert blocks[0].document_url.url.startswith("data:application/pdf;base64,") + + +# --------------------------------------------------------------------------- +# Response format / structured output +# --------------------------------------------------------------------------- + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_response_schema_emits_json_schema_response_format( + mock_client_cls, _mock_cfg +): + import oci.generative_ai_inference.models as oci_models + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + schema = { + "title": "Weather", + "type": "object", + "properties": {"city": {"type": "string"}, "temp_c": {"type": "number"}}, + "required": ["city", "temp_c"], + } + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + llm._call_oci(LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Chicago weather?")])], + config=types.GenerateContentConfig( + response_mime_type="application/json", + response_schema=schema, + ), + )) + + rf = mock_client_instance.chat.call_args[0][0].chat_request.response_format + assert isinstance(rf, oci_models.JsonSchemaResponseFormat) + assert rf.json_schema.name == "Weather" + assert rf.json_schema.schema == schema + assert rf.json_schema.is_strict is True + + 
+@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_response_mime_type_only_emits_json_object_format( + mock_client_cls, _mock_cfg +): + import oci.generative_ai_inference.models as oci_models + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + mock_client_instance.chat.return_value = _make_oci_response() + + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + llm._call_oci(LlmRequest( + model="google.gemini-2.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="JSON please")])], + config=types.GenerateContentConfig(response_mime_type="application/json"), + )) + + rf = mock_client_instance.chat.call_args[0][0].chat_request.response_format + assert isinstance(rf, oci_models.JsonObjectResponseFormat) + + +# --------------------------------------------------------------------------- +# Reasoning-token surfacing +# --------------------------------------------------------------------------- + + +@patch("oci.config.from_file", return_value={}) +@patch("oci.generative_ai_inference.GenerativeAiInferenceClient") +def test_nonstreaming_surfaces_reasoning_tokens(mock_client_cls, _mock_cfg): + mock_client_instance = MagicMock() + mock_client_cls.return_value = mock_client_instance + resp = _make_oci_response(prompt_tokens=10, completion_tokens=5) + resp.data.chat_response.usage.completion_tokens_details = MagicMock( + reasoning_tokens=42 + ) + mock_client_instance.chat.return_value = resp + + llm = OCIGenAILlm( + model="google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..example", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + ) + out = _oci_response_to_llm_response(resp) + assert out.usage_metadata.thoughts_token_count == 42 + + 
+@pytest.mark.asyncio +async def test_streaming_surfaces_reasoning_tokens(oci_llm, llm_request): + chunks = _make_sse_chunks(["Hi"], prompt_tokens=8, completion_tokens=3) + # Inject reasoning tokens into the usage chunk + chunks[-1]["usage"]["completionTokensDetails"] = {"reasoningTokens": 17} + + with patch.object(oci_llm, "_call_oci_stream", return_value=chunks): + responses = [ + r async for r in oci_llm.generate_content_async(llm_request, stream=True) + ] + final = responses[-1] + assert not final.partial + assert final.usage_metadata.thoughts_token_count == 17 From 1d296112d7ae92c82da88c23187f295e893fa4d7 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Wed, 6 May 2026 14:07:27 -0400 Subject: [PATCH 2/2] feat(oci): expose reasoning_effort on OCIGenAILlm constructor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add reasoning_effort: Optional[str] = None as a constructor field. When set, it's passed as `reasoning_effort` on the OCI GenericChatRequest — honoured by GPT-5 family, Gemini 2.5, Grok reasoning variants, and Cohere Command-A-Reasoning; ignored by non-reasoning models. This is the single biggest cost knob for reasoning-capable models on OCI: setting "LOW" typically cuts reasoning-token spend 5-10× vs the default. The other GenericChatRequest fields (verbosity, parallel_tool_calls, logit_bias, n, metadata, etc.) are not exposed — they either duplicate existing controls or are too niche to justify maintenance surface. We only ship what's a missing primitive. Verified live against OCI us-chicago-1: openai.gpt-5.5 + reasoning models all accept the field. 
--- src/google/adk/models/oci_genai_llm.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/google/adk/models/oci_genai_llm.py b/src/google/adk/models/oci_genai_llm.py index 779b7ed5f1..824e993fe2 100644 --- a/src/google/adk/models/oci_genai_llm.py +++ b/src/google/adk/models/oci_genai_llm.py @@ -337,6 +337,13 @@ class OCIGenAILlm(BaseLlm): auth_file_location: Path to the OCI config file used for ``API_KEY`` auth (default: ``~/.oci/config``). max_tokens: Maximum number of tokens to generate (default: 2048). + reasoning_effort: Reasoning-token budget for reasoning-capable models. + One of ``"NONE"``, ``"MINIMAL"``, ``"LOW"``, ``"MEDIUM"``, ``"HIGH"``, + or ``None`` (default — let OCI pick). Honoured by GPT-5 family, + Gemini 2.5, Grok reasoning variants, and Cohere Command-A-Reasoning; + ignored by non-reasoning models. The single most impactful cost knob + for reasoning models — ``"LOW"`` typically cuts reasoning-token spend + 5-10× vs the default. """ model: str = "google.gemini-2.5-flash" @@ -347,6 +354,7 @@ class OCIGenAILlm(BaseLlm): auth_profile: str = "DEFAULT" auth_file_location: str = "~/.oci/config" max_tokens: int = 2048 + reasoning_effort: Optional[str] = None @classmethod @override @@ -490,6 +498,10 @@ def _build_chat_details( if response_format is not None: chat_request_kwargs["response_format"] = response_format + # Constructor-level reasoning_effort applies regardless of per-request cfg. + if self.reasoning_effort is not None: + chat_request_kwargs["reasoning_effort"] = self.reasoning_effort + if oci_tools: chat_request_kwargs["tools"] = oci_tools if is_stream: