From 0014d954b7b0aa3a4dd6ef68aab8c00c6de5e94e Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Wed, 1 Apr 2026 16:12:11 +0200 Subject: [PATCH] chore(mypy): add mypy type hints to distributions and testing modules Add type hints to make the following files pass strict mypy checking: - src/llama_stack/distributions/template.py: Add explicit type annotations to filtered dict and list variables, cast jinja2 render return value to str, add cast to typing imports - src/llama_stack/testing/api_recorder.py: Add string annotation for ResponseStorage forward reference, add type annotation for headers dict, add assertion for non-None storage, add type: ignore[method-assign] comments for all method monkey-patching operations Signed-off-by: Mustafa Elbehery --- pyproject.toml | 8 +-- src/llama_stack/distributions/template.py | 45 ++++++++------- src/llama_stack/testing/api_recorder.py | 67 +++++++++++------------ 3 files changed, 58 insertions(+), 62 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 31e7a8d2dc..1ac0421eae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -315,7 +315,7 @@ exclude = [ # All files now have type annotations! 🎉 # # ============================================================================ - # Section 2: Files that need strict typing issues fixed (131 files) + # Section 2: Files that need strict typing issues fixed (127 files) # ============================================================================ # These files have some type hints but fail strict type checking due to # incomplete annotations, Any usage, or other strict mode violations. @@ -443,12 +443,6 @@ exclude = [ "^src/llama_stack/providers/utils/tools/mcp\\.py$", "^src/llama_stack/providers/utils/tools/ttl_dict\\.py$", "^src/llama_stack/providers/utils/vector_io/vector_utils\\.py$", - # Distributions (1 files) - "^src/llama_stack/distributions/template\\.py$", - # Other (3 files) - "^src/llama_stack/log\\.py$", - "^src/llama_stack/models/llama/sku_types\\.py$", - "^src/llama_stack/testing/api_recorder\\.py$", # # ============================================================================ # Directory Excludes (35 directories) diff --git a/src/llama_stack/distributions/template.py b/src/llama_stack/distributions/template.py index dfae2688a7..b49385b94a 100644 --- a/src/llama_stack/distributions/template.py +++ b/src/llama_stack/distributions/template.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, cast import jinja2 import rich @@ -50,7 +50,7 @@ def filter_empty_values(obj: Any) -> Any: return None if isinstance(obj, dict): - filtered = {} + filtered: dict[Any, Any] = {} for key, value in obj.items(): # Special handling for specific fields if key == "module" and isinstance(value, str) and value == "": @@ -70,12 +70,12 @@ def filter_empty_values(obj: Any) -> Any: return filtered elif isinstance(obj, list): - filtered = [] + filtered_list: list[Any] = [] for item in obj: filtered_item = filter_empty_values(item) if filtered_item is not None: - filtered.append(filtered_item) - return filtered + filtered_list.append(filtered_item) + return filtered_list else: # For all other types (including empty strings and dicts that aren't module/config), @@ -232,9 +232,9 @@ def run_config( f"No config class for provider type: {provider.provider_type} for API: {api_str}" ) - config_class = instantiate_class_type(config_class) - if hasattr(config_class, "sample_run_config"): - config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}") + config_cls = instantiate_class_type(config_class) + if hasattr(config_cls, "sample_run_config"): + config = config_cls.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}") else: config = {} # BuildProvider does not have a config attribute; skip assignment @@ -353,14 +353,14 @@ def generate_markdown_docs(self) -> str: providers_table += f"| {api} | {providers_str} |\n" if self.template_path is not None: - template = self.template_path.read_text() + template_str = self.template_path.read_text() comment = "\n" orphantext = "---\norphan: true\n---\n" - if template.startswith(orphantext): - template = template.replace(orphantext, orphantext + comment) + if template_str.startswith(orphantext): + template_str = template_str.replace(orphantext, orphantext + comment) else: - template = comment + template + template_str = comment + template_str # Render template with rich-generated table env = jinja2.Environment( @@ -369,7 +369,7 @@ def generate_markdown_docs(self) -> str: # NOTE: autoescape is required to prevent XSS attacks autoescape=True, ) - template = env.from_string(template) + template = env.from_string(template_str) default_models = [] if self.available_models_by_provider: @@ -389,14 +389,17 @@ def generate_markdown_docs(self) -> str: ) ) - return template.render( - name=self.name, - description=self.description, - providers=self.providers, - providers_table=providers_table, - run_config_env_vars=self.run_config_env_vars, - default_models=default_models, - run_configs=list(self.run_configs.keys()), + return cast( + str, + template.render( + name=self.name, + description=self.description, + providers=self.providers, + providers_table=providers_table, + run_config_env_vars=self.run_config_env_vars, + default_models=default_models, + run_configs=list(self.run_configs.keys()), + ), ) return "" diff --git a/src/llama_stack/testing/api_recorder.py b/src/llama_stack/testing/api_recorder.py index 5233931c6f..9097394a70 100644 --- a/src/llama_stack/testing/api_recorder.py +++ b/src/llama_stack/testing/api_recorder.py @@ -4,8 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from __future__ import annotations # for forward references - import hashlib import json import os @@ -29,7 +27,7 @@ # client initialization happens in one async context, but tests run in different # contexts, and we need the mode/storage to persist across all contexts. _current_mode: str | None = None -_current_storage: ResponseStorage | None = None +_current_storage: "ResponseStorage | None" = None _original_methods: dict[str, Any] = {} # Per-test deterministic ID counters (test_id -> id_kind -> counter) @@ -290,8 +288,8 @@ def patched_prepare_request(self, request): return None - LlamaStackClient._prepare_request = patched_prepare_request - OpenAI._prepare_request = patched_prepare_request + LlamaStackClient._prepare_request = patched_prepare_request # type: ignore[method-assign] + OpenAI._prepare_request = patched_prepare_request # type: ignore[method-assign] # currently, unpatch is never called @@ -302,9 +300,9 @@ def unpatch_httpx_for_test_id(): from llama_stack_client import LlamaStackClient - LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"] + LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"] # type: ignore[method-assign] del _original_methods["llama_stack_client_prepare_request"] - OpenAI._prepare_request = _original_methods["openai_prepare_request"] + OpenAI._prepare_request = _original_methods["openai_prepare_request"] # type: ignore[method-assign] del _original_methods["openai_prepare_request"] @@ -792,6 +790,7 @@ async def __aenter__(self): "body": self._body, "is_streaming": False, } + assert _current_storage is not None, "Storage must be initialized" _current_storage.store_recording(self._request_hash, request_data, response_data) # Create a mock response that returns the captured body @@ -917,7 +916,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint if "cloud.databricks.com" in url: url = "__databricks__" + url.split("cloud.databricks.com")[-1] method = "POST" - headers = {} + headers: dict[str, Any] = {} body = kwargs request_hash = normalize_inference_request(method, url, headers, body) @@ -1112,11 +1111,11 @@ async def patched_responses_create(self, *args, **kwargs): ) # Apply OpenAI patches - AsyncChatCompletions.create = patched_chat_completions_create - AsyncCompletions.create = patched_completions_create - AsyncEmbeddings.create = patched_embeddings_create - AsyncModels.list = patched_models_list - AsyncResponses.create = patched_responses_create + AsyncChatCompletions.create = patched_chat_completions_create # type: ignore[method-assign] + AsyncCompletions.create = patched_completions_create # type: ignore[method-assign] + AsyncEmbeddings.create = patched_embeddings_create # type: ignore[method-assign] + AsyncModels.list = patched_models_list # type: ignore[method-assign] + AsyncResponses.create = patched_responses_create # type: ignore[method-assign] # Create patched methods for Ollama client async def patched_ollama_generate(self, *args, **kwargs): @@ -1150,12 +1149,12 @@ async def patched_ollama_list(self, *args, **kwargs): ) # Apply Ollama patches - OllamaAsyncClient.generate = patched_ollama_generate - OllamaAsyncClient.chat = patched_ollama_chat - OllamaAsyncClient.embed = patched_ollama_embed - OllamaAsyncClient.ps = patched_ollama_ps - OllamaAsyncClient.pull = patched_ollama_pull - OllamaAsyncClient.list = patched_ollama_list + OllamaAsyncClient.generate = patched_ollama_generate # type: ignore[method-assign] + OllamaAsyncClient.chat = patched_ollama_chat # type: ignore[method-assign] + OllamaAsyncClient.embed = patched_ollama_embed # type: ignore[method-assign] + OllamaAsyncClient.ps = patched_ollama_ps # type: ignore[method-assign] + OllamaAsyncClient.pull = patched_ollama_pull # type: ignore[method-assign] + OllamaAsyncClient.list = patched_ollama_list # type: ignore[method-assign] # Create patched methods for tool runtimes async def patched_tavily_invoke_tool( @@ -1166,14 +1165,14 @@ async def patched_tavily_invoke_tool( ) # Apply tool runtime patches - TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool + TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool # type: ignore[method-assign] # Create patched method for aiohttp rerank requests def patched_aiohttp_session_post(self, url, **kwargs): return _patched_aiohttp_post(_original_methods["aiohttp_post"], self, url, **kwargs) # Apply aiohttp patch - aiohttp.ClientSession.post = patched_aiohttp_session_post + aiohttp.ClientSession.post = patched_aiohttp_session_post # type: ignore[method-assign] def unpatch_inference_clients(): @@ -1195,25 +1194,25 @@ def unpatch_inference_clients(): from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl # Restore OpenAI client methods - AsyncChatCompletions.create = _original_methods["chat_completions_create"] - AsyncCompletions.create = _original_methods["completions_create"] - AsyncEmbeddings.create = _original_methods["embeddings_create"] - AsyncModels.list = _original_methods["models_list"] - AsyncResponses.create = _original_methods["responses_create"] + AsyncChatCompletions.create = _original_methods["chat_completions_create"] # type: ignore[method-assign] + AsyncCompletions.create = _original_methods["completions_create"] # type: ignore[method-assign] + AsyncEmbeddings.create = _original_methods["embeddings_create"] # type: ignore[method-assign] + AsyncModels.list = _original_methods["models_list"] # type: ignore[method-assign] + AsyncResponses.create = _original_methods["responses_create"] # type: ignore[method-assign] # Restore Ollama client methods if they were patched - OllamaAsyncClient.generate = _original_methods["ollama_generate"] - OllamaAsyncClient.chat = _original_methods["ollama_chat"] - OllamaAsyncClient.embed = _original_methods["ollama_embed"] - OllamaAsyncClient.ps = _original_methods["ollama_ps"] - OllamaAsyncClient.pull = _original_methods["ollama_pull"] - OllamaAsyncClient.list = _original_methods["ollama_list"] + OllamaAsyncClient.generate = _original_methods["ollama_generate"] # type: ignore[method-assign] + OllamaAsyncClient.chat = _original_methods["ollama_chat"] # type: ignore[method-assign] + OllamaAsyncClient.embed = _original_methods["ollama_embed"] # type: ignore[method-assign] + OllamaAsyncClient.ps = _original_methods["ollama_ps"] # type: ignore[method-assign] + OllamaAsyncClient.pull = _original_methods["ollama_pull"] # type: ignore[method-assign] + OllamaAsyncClient.list = _original_methods["ollama_list"] # type: ignore[method-assign] # Restore tool runtime methods - TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"] + TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"] # type: ignore[method-assign] # Restore aiohttp method - aiohttp.ClientSession.post = _original_methods["aiohttp_post"] + aiohttp.ClientSession.post = _original_methods["aiohttp_post"] # type: ignore[method-assign] _original_methods.clear()