Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/skillspector/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage

from skillspector.constants import MODEL_CONFIG
from skillspector.model_info import get_max_input_tokens, get_max_output_tokens
from skillspector.providers import (
create_chat_model,
get_metadata_provider,
raise_no_llm_api_key_configured,
resolve_chat_model_credentials,
resolve_provider_credentials,
)
from skillspector.providers.openai import OpenAIProvider


def _resolve_llm_credentials() -> tuple[str, str | None]:
Expand All @@ -57,6 +59,18 @@ def _resolve_llm_credentials() -> tuple[str, str | None]:
return creds


def _resolve_default_chat_model() -> str:
    """Choose the default chat model for whichever endpoint has credentials.

    The configured metadata provider is consulted first: when
    ``resolve_provider_credentials()`` yields credentials, that provider's
    ``resolve_model()`` result is returned. Otherwise a fresh
    ``OpenAIProvider`` is tried, and its ``resolve_model()`` result is
    returned when it can resolve credentials.

    When neither source has credentials, control is handed to
    ``raise_no_llm_api_key_configured()``.
    """
    provider_has_credentials = resolve_provider_credentials() is not None
    if provider_has_credentials:
        active_provider = get_metadata_provider()
        return active_provider.resolve_model()

    # No credentials for the configured provider: try the OpenAI endpoint.
    fallback_provider = OpenAIProvider()
    fallback_credentials = fallback_provider.resolve_credentials()
    if fallback_credentials is not None:
        return fallback_provider.resolve_model()

    raise_no_llm_api_key_configured()


def is_llm_available() -> tuple[bool, str | None]:
"""Return ``(available, error_message)`` describing LLM credential status."""
try:
Expand All @@ -77,7 +91,7 @@ def get_chat_model(model: str | None = None) -> BaseChatModel:
Raises:
ValueError: when no API key is configured (see ``is_llm_available``).
"""
model = model or MODEL_CONFIG["default"]
model = model or _resolve_default_chat_model()
return create_chat_model(
model=model,
max_tokens=get_max_output_tokens(model),
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/test_llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@
is_llm_available,
)
from skillspector.providers import NO_LLM_API_KEY_MESSAGE, resolve_provider_credentials
from skillspector.providers.nv_build import NvBuildProvider
from skillspector.providers.openai import OpenAIProvider

# Names of the LLM-related environment variables these tests care about
# (API keys, base URL, and model/provider selection).
# NOTE(review): presumably cleared before each test by a fixture that is
# outside this view — confirm against the rest of the module.
_LLM_ENV_VARS = (
"ANTHROPIC_API_KEY",
"OPENAI_API_KEY",
"OPENAI_BASE_URL",
"NVIDIA_INFERENCE_KEY",
"SKILLSPECTOR_MODEL",
"SKILLSPECTOR_PROVIDER",
)

Expand Down Expand Up @@ -178,3 +181,37 @@ def test_returns_false_with_message_when_no_credentials(self) -> None:
ok, msg = is_llm_available()
assert ok is False
assert msg == NO_LLM_API_KEY_MESSAGE


class TestGetChatModel:
    """Checks which model name ``get_chat_model`` ends up using."""

    def test_openai_fallback_uses_openai_default_model(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """With ``OPENAI_API_KEY`` set here, the OpenAI default model is chosen."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai-only")

        resolved_name = _chat_model_name(get_chat_model())

        assert resolved_name == OpenAIProvider.DEFAULT_MODEL

    def test_explicit_model_still_overrides_openai_fallback(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """An explicitly passed ``model`` wins over any default."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai-only")

        resolved_name = _chat_model_name(get_chat_model(model="custom/model"))

        assert resolved_name == "custom/model"

    def test_provider_credentials_use_provider_default_model(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """With both NVIDIA and OpenAI keys set, the NvBuild default is chosen."""
        for env_name, env_value in (
            ("NVIDIA_INFERENCE_KEY", "nvapi-test"),
            ("OPENAI_API_KEY", "sk-test-openai"),
        ):
            monkeypatch.setenv(env_name, env_value)

        resolved_name = _chat_model_name(get_chat_model())

        assert resolved_name == NvBuildProvider.DEFAULT_MODEL


def _chat_model_name(llm: object) -> str:
    """Return the model identifier carried by *llm*, as a string.

    ``llm.model_name`` is preferred; when it is missing or falsy the value
    of ``llm.model`` is used instead (``None`` if that is missing too).
    The chosen value is passed through ``str``.
    """
    identifier = getattr(llm, "model_name", None)
    if not identifier:
        # Fall back to the alternative attribute name.
        identifier = getattr(llm, "model", None)
    return str(identifier)