Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion aider/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,10 @@ def validate_environment(self):
if res:
return res

provider = self.info.get("litellm_provider", "").lower()
if provider == "custom_openai":
return dict(keys_in_environment=True, missing_keys=[])

# https://github.com/BerriAI/litellm/issues/3190

model = self.name
Expand All @@ -769,7 +773,6 @@ def validate_environment(self):
if res["missing_keys"]:
return res

provider = self.info.get("litellm_provider", "").lower()
if provider == "cohere_chat":
return validate_variables(["COHERE_API_KEY"])
if provider == "gemini":
Expand Down
39 changes: 39 additions & 0 deletions tests/basic/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,45 @@ def test_check_for_dependencies_other_model(self, mock_check_pip):
# Verify check_pip_install_extra was not called
mock_check_pip.assert_not_called()

@patch("aider.models.Model.get_model_info")
@patch("aider.models.litellm.validate_environment")
def test_custom_openai_metadata_bypasses_litellm_validation(
self, mock_validate_environment, mock_get_model_info
):
"""Test that metadata-backed custom_openai models bypass litellm validation."""
mock_get_model_info.return_value = {"litellm_provider": "custom_openai", "mode": "chat"}
mock_validate_environment.return_value = {
"keys_in_environment": False,
"missing_keys": ["OPENAI_API_KEY"],
}

model = Model("custom_openai/my-openai-model")

self.assertTrue(model.keys_in_environment)
self.assertEqual(model.missing_keys, [])
mock_validate_environment.assert_not_called()

@patch("aider.models.Model.get_model_info")
@patch("aider.models.litellm.validate_environment")
def test_non_custom_metadata_still_uses_litellm_validation(
self, mock_validate_environment, mock_get_model_info
):
"""Test that only the exact custom_openai provider bypasses litellm validation."""
mock_get_model_info.return_value = {
"litellm_provider": "custom_openai_plus",
"mode": "chat",
}
mock_validate_environment.return_value = {
"keys_in_environment": False,
"missing_keys": ["SOME_KEY"],
}

model = Model("custom_openai/my-openai-model")

self.assertFalse(model.keys_in_environment)
self.assertEqual(model.missing_keys, ["SOME_KEY"])
mock_validate_environment.assert_called_once_with("custom_openai/my-openai-model")

def test_get_repo_map_tokens(self):
# Test default case (no max_input_tokens in info)
model = Model("gpt-4")
Expand Down