diff --git a/src/skillspector/llm_utils.py b/src/skillspector/llm_utils.py
index ab6e551..d1c5104 100644
--- a/src/skillspector/llm_utils.py
+++ b/src/skillspector/llm_utils.py
@@ -33,13 +33,15 @@
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import BaseMessage
 
-from skillspector.constants import MODEL_CONFIG
 from skillspector.model_info import get_max_input_tokens, get_max_output_tokens
 from skillspector.providers import (
     create_chat_model,
+    get_metadata_provider,
     raise_no_llm_api_key_configured,
     resolve_chat_model_credentials,
+    resolve_provider_credentials,
 )
+from skillspector.providers.openai import OpenAIProvider
 
 
 def _resolve_llm_credentials() -> tuple[str, str | None]:
@@ -57,6 +59,18 @@ def _resolve_llm_credentials() -> tuple[str, str | None]:
     return creds
 
 
+def _resolve_default_chat_model() -> str:
+    """Return the default chat model for the endpoint that will be used."""
+    if resolve_provider_credentials() is not None:  # provider credentials resolved
+        return get_metadata_provider().resolve_model()
+
+    openai_provider = OpenAIProvider()  # otherwise try the OpenAI provider
+    if openai_provider.resolve_credentials() is not None:
+        return openai_provider.resolve_model()
+
+    raise_no_llm_api_key_configured()  # NOTE(review): assumed to always raise - confirm
+
+
 def is_llm_available() -> tuple[bool, str | None]:
     """Return ``(available, error_message)`` describing LLM credential status."""
     try:
@@ -77,7 +91,7 @@ def get_chat_model(model: str | None = None) -> BaseChatModel:
     Raises:
         ValueError: when no API key is configured (see ``is_llm_available``).
     """
-    model = model or MODEL_CONFIG["default"]
+    model = model or _resolve_default_chat_model()  # explicit ``model`` still wins
     return create_chat_model(
         model=model,
         max_tokens=get_max_output_tokens(model),
diff --git a/tests/unit/test_llm_utils.py b/tests/unit/test_llm_utils.py
index 5e89ead..6c013c8 100644
--- a/tests/unit/test_llm_utils.py
+++ b/tests/unit/test_llm_utils.py
@@ -35,12 +35,15 @@
     is_llm_available,
 )
 from skillspector.providers import NO_LLM_API_KEY_MESSAGE, resolve_provider_credentials
+from skillspector.providers.nv_build import NvBuildProvider
+from skillspector.providers.openai import OpenAIProvider
 
 _LLM_ENV_VARS = (
     "ANTHROPIC_API_KEY",
     "OPENAI_API_KEY",
     "OPENAI_BASE_URL",
     "NVIDIA_INFERENCE_KEY",
+    "SKILLSPECTOR_MODEL",
     "SKILLSPECTOR_PROVIDER",
 )
 
@@ -178,3 +181,37 @@ def test_returns_false_with_message_when_no_credentials(self) -> None:
         ok, msg = is_llm_available()
         assert ok is False
         assert msg == NO_LLM_API_KEY_MESSAGE
+
+
+class TestGetChatModel:  # NOTE(review): assumes a fixture clears _LLM_ENV_VARS per test - confirm
+    def test_openai_fallback_uses_openai_default_model(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai-only")
+
+        llm = get_chat_model()  # no explicit model: the default is resolved
+
+        assert _chat_model_name(llm) == OpenAIProvider.DEFAULT_MODEL
+
+    def test_explicit_model_still_overrides_openai_fallback(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai-only")
+
+        llm = get_chat_model(model="custom/model")
+
+        assert _chat_model_name(llm) == "custom/model"
+
+    def test_provider_credentials_use_provider_default_model(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        monkeypatch.setenv("NVIDIA_INFERENCE_KEY", "nvapi-test")
+        monkeypatch.setenv("OPENAI_API_KEY", "sk-test-openai")  # both keys are set
+
+        llm = get_chat_model()
+
+        assert _chat_model_name(llm) == NvBuildProvider.DEFAULT_MODEL
+
+
+def _chat_model_name(llm: object) -> str:  # model_name, else model; "None" if both falsy
+    return str(getattr(llm, "model_name", None) or getattr(llm, "model", None))