Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ exclude = [
# All files now have type annotations! 🎉
#
# ============================================================================
# Section 2: Files that need strict typing issues fixed (131 files)
# Section 2: Files that need strict typing issues fixed (127 files)
# ============================================================================
# These files have some type hints but fail strict type checking due to
# incomplete annotations, Any usage, or other strict mode violations.
Expand Down Expand Up @@ -443,12 +443,6 @@ exclude = [
"^src/llama_stack/providers/utils/tools/mcp\\.py$",
"^src/llama_stack/providers/utils/tools/ttl_dict\\.py$",
"^src/llama_stack/providers/utils/vector_io/vector_utils\\.py$",
# Distributions (1 files)
"^src/llama_stack/distributions/template\\.py$",
# Other (3 files)
"^src/llama_stack/log\\.py$",
"^src/llama_stack/models/llama/sku_types\\.py$",
"^src/llama_stack/testing/api_recorder\\.py$",
#
# ============================================================================
# Directory Excludes (35 directories)
Expand Down
45 changes: 24 additions & 21 deletions src/llama_stack/distributions/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# the root directory of this source tree.

from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, cast

import jinja2
import rich
Expand Down Expand Up @@ -50,7 +50,7 @@ def filter_empty_values(obj: Any) -> Any:
return None

if isinstance(obj, dict):
filtered = {}
filtered: dict[Any, Any] = {}
for key, value in obj.items():
# Special handling for specific fields
if key == "module" and isinstance(value, str) and value == "":
Expand All @@ -70,12 +70,12 @@ def filter_empty_values(obj: Any) -> Any:
return filtered

elif isinstance(obj, list):
filtered = []
filtered_list: list[Any] = []
for item in obj:
filtered_item = filter_empty_values(item)
if filtered_item is not None:
filtered.append(filtered_item)
return filtered
filtered_list.append(filtered_item)
return filtered_list

else:
# For all other types (including empty strings and dicts that aren't module/config),
Expand Down Expand Up @@ -232,9 +232,9 @@ def run_config(
f"No config class for provider type: {provider.provider_type} for API: {api_str}"
)

config_class = instantiate_class_type(config_class)
if hasattr(config_class, "sample_run_config"):
config = config_class.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}")
config_cls = instantiate_class_type(config_class)
if hasattr(config_cls, "sample_run_config"):
config = config_cls.sample_run_config(__distro_dir__=f"~/.llama/distributions/{name}")
else:
config = {}
# BuildProvider does not have a config attribute; skip assignment
Expand Down Expand Up @@ -353,14 +353,14 @@ def generate_markdown_docs(self) -> str:
providers_table += f"| {api} | {providers_str} |\n"

if self.template_path is not None:
template = self.template_path.read_text()
template_str = self.template_path.read_text()
comment = "<!-- This file was auto-generated by distro_codegen.py, please edit source -->\n"
orphantext = "---\norphan: true\n---\n"

if template.startswith(orphantext):
template = template.replace(orphantext, orphantext + comment)
if template_str.startswith(orphantext):
template_str = template_str.replace(orphantext, orphantext + comment)
else:
template = comment + template
template_str = comment + template_str

# Render template with rich-generated table
env = jinja2.Environment(
Expand All @@ -369,7 +369,7 @@ def generate_markdown_docs(self) -> str:
# NOTE: autoescape is required to prevent XSS attacks
autoescape=True,
)
template = env.from_string(template)
template = env.from_string(template_str)

default_models = []
if self.available_models_by_provider:
Expand All @@ -389,14 +389,17 @@ def generate_markdown_docs(self) -> str:
)
)

return template.render(
name=self.name,
description=self.description,
providers=self.providers,
providers_table=providers_table,
run_config_env_vars=self.run_config_env_vars,
default_models=default_models,
run_configs=list(self.run_configs.keys()),
return cast(
str,
template.render(
name=self.name,
description=self.description,
providers=self.providers,
providers_table=providers_table,
run_config_env_vars=self.run_config_env_vars,
default_models=default_models,
run_configs=list(self.run_configs.keys()),
),
)
return ""

Expand Down
67 changes: 33 additions & 34 deletions src/llama_stack/testing/api_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from __future__ import annotations # for forward references

import hashlib
import json
import os
Expand All @@ -29,7 +27,7 @@
# client initialization happens in one async context, but tests run in different
# contexts, and we need the mode/storage to persist across all contexts.
_current_mode: str | None = None
_current_storage: ResponseStorage | None = None
_current_storage: "ResponseStorage | None" = None
_original_methods: dict[str, Any] = {}

# Per-test deterministic ID counters (test_id -> id_kind -> counter)
Expand Down Expand Up @@ -290,8 +288,8 @@ def patched_prepare_request(self, request):

return None

LlamaStackClient._prepare_request = patched_prepare_request
OpenAI._prepare_request = patched_prepare_request
LlamaStackClient._prepare_request = patched_prepare_request # type: ignore[method-assign]
OpenAI._prepare_request = patched_prepare_request # type: ignore[method-assign]


# currently, unpatch is never called
Expand All @@ -302,9 +300,9 @@ def unpatch_httpx_for_test_id():

from llama_stack_client import LlamaStackClient

LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"]
LlamaStackClient._prepare_request = _original_methods["llama_stack_client_prepare_request"] # type: ignore[method-assign]
del _original_methods["llama_stack_client_prepare_request"]
OpenAI._prepare_request = _original_methods["openai_prepare_request"]
OpenAI._prepare_request = _original_methods["openai_prepare_request"] # type: ignore[method-assign]
del _original_methods["openai_prepare_request"]


Expand Down Expand Up @@ -792,6 +790,7 @@ async def __aenter__(self):
"body": self._body,
"is_streaming": False,
}
assert _current_storage is not None, "Storage must be initialized"
_current_storage.store_recording(self._request_hash, request_data, response_data)

# Create a mock response that returns the captured body
Expand Down Expand Up @@ -917,7 +916,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
if "cloud.databricks.com" in url:
url = "__databricks__" + url.split("cloud.databricks.com")[-1]
method = "POST"
headers = {}
headers: dict[str, Any] = {}
body = kwargs

request_hash = normalize_inference_request(method, url, headers, body)
Expand Down Expand Up @@ -1112,11 +1111,11 @@ async def patched_responses_create(self, *args, **kwargs):
)

# Apply OpenAI patches
AsyncChatCompletions.create = patched_chat_completions_create
AsyncCompletions.create = patched_completions_create
AsyncEmbeddings.create = patched_embeddings_create
AsyncModels.list = patched_models_list
AsyncResponses.create = patched_responses_create
AsyncChatCompletions.create = patched_chat_completions_create # type: ignore[method-assign]
AsyncCompletions.create = patched_completions_create # type: ignore[method-assign]
AsyncEmbeddings.create = patched_embeddings_create # type: ignore[method-assign]
AsyncModels.list = patched_models_list # type: ignore[method-assign]
AsyncResponses.create = patched_responses_create # type: ignore[method-assign]

# Create patched methods for Ollama client
async def patched_ollama_generate(self, *args, **kwargs):
Expand Down Expand Up @@ -1150,12 +1149,12 @@ async def patched_ollama_list(self, *args, **kwargs):
)

# Apply Ollama patches
OllamaAsyncClient.generate = patched_ollama_generate
OllamaAsyncClient.chat = patched_ollama_chat
OllamaAsyncClient.embed = patched_ollama_embed
OllamaAsyncClient.ps = patched_ollama_ps
OllamaAsyncClient.pull = patched_ollama_pull
OllamaAsyncClient.list = patched_ollama_list
OllamaAsyncClient.generate = patched_ollama_generate # type: ignore[method-assign]
OllamaAsyncClient.chat = patched_ollama_chat # type: ignore[method-assign]
OllamaAsyncClient.embed = patched_ollama_embed # type: ignore[method-assign]
OllamaAsyncClient.ps = patched_ollama_ps # type: ignore[method-assign]
OllamaAsyncClient.pull = patched_ollama_pull # type: ignore[method-assign]
OllamaAsyncClient.list = patched_ollama_list # type: ignore[method-assign]

# Create patched methods for tool runtimes
async def patched_tavily_invoke_tool(
Expand All @@ -1166,14 +1165,14 @@ async def patched_tavily_invoke_tool(
)

# Apply tool runtime patches
TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool
TavilySearchToolRuntimeImpl.invoke_tool = patched_tavily_invoke_tool # type: ignore[method-assign]

# Create patched method for aiohttp rerank requests
def patched_aiohttp_session_post(self, url, **kwargs):
return _patched_aiohttp_post(_original_methods["aiohttp_post"], self, url, **kwargs)

# Apply aiohttp patch
aiohttp.ClientSession.post = patched_aiohttp_session_post
aiohttp.ClientSession.post = patched_aiohttp_session_post # type: ignore[method-assign]


def unpatch_inference_clients():
Expand All @@ -1195,25 +1194,25 @@ def unpatch_inference_clients():
from llama_stack.providers.remote.tool_runtime.tavily_search.tavily_search import TavilySearchToolRuntimeImpl

# Restore OpenAI client methods
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
AsyncCompletions.create = _original_methods["completions_create"]
AsyncEmbeddings.create = _original_methods["embeddings_create"]
AsyncModels.list = _original_methods["models_list"]
AsyncResponses.create = _original_methods["responses_create"]
AsyncChatCompletions.create = _original_methods["chat_completions_create"] # type: ignore[method-assign]
AsyncCompletions.create = _original_methods["completions_create"] # type: ignore[method-assign]
AsyncEmbeddings.create = _original_methods["embeddings_create"] # type: ignore[method-assign]
AsyncModels.list = _original_methods["models_list"] # type: ignore[method-assign]
AsyncResponses.create = _original_methods["responses_create"] # type: ignore[method-assign]

# Restore Ollama client methods if they were patched
OllamaAsyncClient.generate = _original_methods["ollama_generate"]
OllamaAsyncClient.chat = _original_methods["ollama_chat"]
OllamaAsyncClient.embed = _original_methods["ollama_embed"]
OllamaAsyncClient.ps = _original_methods["ollama_ps"]
OllamaAsyncClient.pull = _original_methods["ollama_pull"]
OllamaAsyncClient.list = _original_methods["ollama_list"]
OllamaAsyncClient.generate = _original_methods["ollama_generate"] # type: ignore[method-assign]
OllamaAsyncClient.chat = _original_methods["ollama_chat"] # type: ignore[method-assign]
OllamaAsyncClient.embed = _original_methods["ollama_embed"] # type: ignore[method-assign]
OllamaAsyncClient.ps = _original_methods["ollama_ps"] # type: ignore[method-assign]
OllamaAsyncClient.pull = _original_methods["ollama_pull"] # type: ignore[method-assign]
OllamaAsyncClient.list = _original_methods["ollama_list"] # type: ignore[method-assign]

# Restore tool runtime methods
TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"]
TavilySearchToolRuntimeImpl.invoke_tool = _original_methods["tavily_invoke_tool"] # type: ignore[method-assign]

# Restore aiohttp method
aiohttp.ClientSession.post = _original_methods["aiohttp_post"]
aiohttp.ClientSession.post = _original_methods["aiohttp_post"] # type: ignore[method-assign]

_original_methods.clear()

Expand Down
Loading