diff --git a/src/model_manager/model.cpp b/src/model_manager/model.cpp index 2f6d5f18..8ac60c50 100644 --- a/src/model_manager/model.cpp +++ b/src/model_manager/model.cpp @@ -26,19 +26,31 @@ void Model::LoadModelDetails(const nlohmann::json& model_json) { } bool has_resolved_details = model_json.contains("model") && - model_json.contains("provider") && - model_json.contains("secret") && - model_json.contains("tuple_format") && - model_json.contains("batch_size"); + model_json.contains("provider"); nlohmann::json db_model_args; if (has_resolved_details) { model_details_.model = model_json.at("model").get(); model_details_.provider_name = model_json.at("provider").get(); - model_details_.secret = model_json["secret"].get>(); - model_details_.tuple_format = model_json.at("tuple_format").get(); - model_details_.batch_size = model_json.at("batch_size").get(); + if (model_json.contains("secret")) { + model_details_.secret = model_json["secret"].get>(); + } else { + auto secret_name = "__default_" + model_details_.provider_name; + if (model_details_.provider_name == AZURE) { + secret_name += "_llm"; + } + if (model_json.contains("secret_name")) { + secret_name = model_json["secret_name"].get(); + } + model_details_.secret = SecretManager::GetSecret(secret_name); + } + model_details_.tuple_format = model_json.contains("tuple_format") + ? model_json.at("tuple_format").get() + : "XML"; + model_details_.batch_size = model_json.contains("batch_size") + ? model_json.at("batch_size").get() + : 2048; if (model_json.contains("model_parameters")) { auto& mp = model_json.at("model_parameters"); diff --git a/test/unit/functions/aggregate/llm_aggregate_function_test_base.hpp b/test/unit/functions/aggregate/llm_aggregate_function_test_base.hpp index c84cac5c..79a6cc75 100644 --- a/test/unit/functions/aggregate/llm_aggregate_function_test_base.hpp +++ b/test/unit/functions/aggregate/llm_aggregate_function_test_base.hpp @@ -1,6 +1,7 @@ #pragma once #include "../mock_provider.hpp" +#include "../../ollama_test_utils.hpp" #include "flock/core/config.hpp" #include "flock/functions/aggregate/aggregate.hpp" #include "flock/model_manager/model.hpp" @@ -24,12 +25,14 @@ class LLMAggregateTestBase : public ::testing::Test { void SetUp() override { auto con = Config::GetConnection(); + Config::ConfigureTables(con, ConfigType::LOCAL); con.Query(" CREATE SECRET (" " TYPE OPENAI," " API_KEY 'your-api-key');"); con.Query(" CREATE SECRET (" " TYPE OLLAMA," " API_URL '127.0.0.1:11434');"); + SeedOllamaTestModel(con); // Create a shared mock provider for expectations mock_provider = std::make_shared(ModelDetails{}); diff --git a/test/unit/functions/aggregate/llm_first.cpp b/test/unit/functions/aggregate/llm_first.cpp index e9f8132d..70c7f601 100644 --- a/test/unit/functions/aggregate/llm_first.cpp +++ b/test/unit/functions/aggregate/llm_first.cpp @@ -1,4 +1,5 @@ #include "flock/functions/aggregate/llm_first_or_last.hpp" +#include "../../ollama_test_utils.hpp" #include "llm_aggregate_function_test_base.hpp" namespace flock { @@ -185,17 +186,20 @@ TEST_F(LLMFirstTest, AudioTranscription) { // Test audio transcription error handling for Ollama TEST_F(LLMFirstTest, AudioTranscriptionOllamaError) { auto con = Config::GetConnection(); + const auto ollama_model = GetOllamaTestModelName(); EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); const auto results = con.Query( "SELECT llm_first(" - "{'model_name': 'gemma3:4b'}, " + "{'model_name': '" + + ollama_model + "'}, " "{'prompt': 'Select the best audio', " "'context_columns': [" "{'data': audio_url, " "'type': 'audio', " - "'transcription_model': 'gemma3:4b'}" + "'transcription_model': '" + + ollama_model + "'}" "]}) AS result FROM VALUES " "('https://example.com/audio1.mp3'), " "('https://example.com/audio2.mp3') AS tbl(audio_url);"); diff --git a/test/unit/functions/aggregate/llm_last.cpp b/test/unit/functions/aggregate/llm_last.cpp index f3f515d8..05129d4f 100644 --- a/test/unit/functions/aggregate/llm_last.cpp +++ b/test/unit/functions/aggregate/llm_last.cpp @@ -1,4 +1,5 @@ #include "flock/functions/aggregate/llm_first_or_last.hpp" +#include "../../ollama_test_utils.hpp" #include "llm_aggregate_function_test_base.hpp" namespace flock { @@ -185,17 +186,20 @@ TEST_F(LLMLastTest, AudioTranscription) { // Test audio transcription error handling for Ollama TEST_F(LLMLastTest, AudioTranscriptionOllamaError) { auto con = Config::GetConnection(); + const auto ollama_model = GetOllamaTestModelName(); EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); const auto results = con.Query( "SELECT llm_last(" - "{'model_name': 'gemma3:4b'}, " + "{'model_name': '" + + ollama_model + "'}, " "{'prompt': 'Select the worst audio', " "'context_columns': [" "{'data': audio_url, " "'type': 'audio', " - "'transcription_model': 'gemma3:4b'}" + "'transcription_model': '" + + ollama_model + "'}" "]}) AS result FROM VALUES " "('https://example.com/audio1.mp3'), " "('https://example.com/audio2.mp3') AS tbl(audio_url);"); diff --git a/test/unit/functions/aggregate/llm_reduce.cpp b/test/unit/functions/aggregate/llm_reduce.cpp index 50e27a4c..f77c583d 100644 --- a/test/unit/functions/aggregate/llm_reduce.cpp +++ b/test/unit/functions/aggregate/llm_reduce.cpp @@ -1,4 +1,5 @@ #include "flock/functions/aggregate/llm_reduce.hpp" +#include "../../ollama_test_utils.hpp" #include "llm_aggregate_function_test_base.hpp" namespace flock { @@ -209,17 +210,20 @@ TEST_F(LLMReduceTest, AudioAndTextColumns) { // Test audio transcription error handling for Ollama TEST_F(LLMReduceTest, AudioTranscriptionOllamaError) { auto con = Config::GetConnection(); + const auto ollama_model = GetOllamaTestModelName(); EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); const auto results = con.Query( "SELECT llm_reduce(" - "{'model_name': 'gemma3:4b'}, " + "{'model_name': '" + + ollama_model + "'}, " "{'prompt': 'Summarize this audio', " "'context_columns': [" "{'data': audio_url, " "'type': 'audio', " - "'transcription_model': 'gemma3:4b'}" + "'transcription_model': '" + + ollama_model + "'}" "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); ASSERT_TRUE(results->HasError()); diff --git a/test/unit/functions/aggregate/llm_rerank.cpp b/test/unit/functions/aggregate/llm_rerank.cpp index 7edbd556..f36d9c43 100644 --- a/test/unit/functions/aggregate/llm_rerank.cpp +++ b/test/unit/functions/aggregate/llm_rerank.cpp @@ -1,4 +1,5 @@ #include "flock/functions/aggregate/llm_rerank.hpp" +#include "../../ollama_test_utils.hpp" #include "llm_aggregate_function_test_base.hpp" #include @@ -196,17 +197,20 @@ TEST_F(LLMRerankTest, AudioTranscription) { // Test audio transcription error handling for Ollama TEST_F(LLMRerankTest, AudioTranscriptionOllamaError) { auto con = Config::GetConnection(); + const auto ollama_model = GetOllamaTestModelName(); EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); const auto results = con.Query( "SELECT llm_rerank(" - "{'model_name': 'gemma3:4b'}, " + "{'model_name': '" + + ollama_model + "'}, " "{'prompt': 'Rank these audio files', " "'context_columns': [" "{'data': audio_url, " "'type': 'audio', " - "'transcription_model': 'gemma3:4b'}" + "'transcription_model': '" + + ollama_model + "'}" "]}) AS result FROM VALUES " "('https://example.com/audio1.mp3'), " "('https://example.com/audio2.mp3') AS tbl(audio_url);"); diff --git a/test/unit/functions/scalar/llm_complete.cpp b/test/unit/functions/scalar/llm_complete.cpp index c2936597..f8d50f6d 100644 --- a/test/unit/functions/scalar/llm_complete.cpp +++ b/test/unit/functions/scalar/llm_complete.cpp @@ -1,4 +1,5 @@ #include "flock/functions/scalar/llm_complete.hpp" +#include "../../ollama_test_utils.hpp" #include "llm_function_test_base.hpp" namespace flock { @@ -227,6 +228,7 @@ TEST_F(LLMCompleteTest, LLMCompleteWithAudioAndText) { // Test audio transcription error handling TEST_F(LLMCompleteTest, LLMCompleteAudioTranscriptionError) { auto con = Config::GetConnection(); + const auto ollama_model = GetOllamaTestModelName(); // Mock transcription model to throw error (simulating Ollama behavior) EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) .WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama."))); @@ -234,12 +236,14 @@ TEST_F(LLMCompleteTest, LLMCompleteAudioTranscriptionError) { // Test with Ollama which doesn't support transcription const auto results = con.Query( "SELECT llm_complete(" - "{'model_name': 'gemma3:4b'}, " + "{'model_name': '" + + ollama_model + "'}, " "{'prompt': 'Summarize this audio', " "'context_columns': [" "{'data': audio_url, " "'type': 'audio', " - "'transcription_model': 'gemma3:4b'}" + "'transcription_model': '" + + ollama_model + "'}" "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); // Should fail because Ollama doesn't support transcription @@ -262,4 +266,4 @@ TEST_F(LLMCompleteTest, LLMCompleteAudioMissingTranscriptionModel) { ASSERT_TRUE(results->HasError()); } -}// namespace flock \ No newline at end of file +}// namespace flock diff --git a/test/unit/functions/scalar/llm_filter.cpp b/test/unit/functions/scalar/llm_filter.cpp index 66d1fbb0..5538e804 100644 --- a/test/unit/functions/scalar/llm_filter.cpp +++ b/test/unit/functions/scalar/llm_filter.cpp @@ -1,4 +1,5 @@ #include "flock/functions/scalar/llm_filter.hpp" +#include "../../ollama_test_utils.hpp" #include "llm_function_test_base.hpp" namespace flock { @@ -196,6 +197,7 @@ TEST_F(LLMFilterTest, LLMFilterWithAudioAndText) { // Test audio transcription error handling for Ollama TEST_F(LLMFilterTest, LLMFilterAudioTranscriptionOllamaError) { auto con = Config::GetConnection(); + const auto ollama_model = GetOllamaTestModelName(); // Mock transcription model to throw error (simulating Ollama behavior) EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_)) @@ -204,12 +206,14 @@ TEST_F(LLMFilterTest, LLMFilterAudioTranscriptionOllamaError) { // Test with Ollama which doesn't support transcription const auto results = con.Query( "SELECT llm_filter(" - "{'model_name': 'gemma3:4b'}, " + "{'model_name': '" + + ollama_model + "'}, " "{'prompt': 'Is the sentiment positive?', " "'context_columns': [" "{'data': audio_url, " "'type': 'audio', " - "'transcription_model': 'gemma3:4b'}" + "'transcription_model': '" + + ollama_model + "'}" "]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);"); // Should fail because Ollama doesn't support transcription diff --git a/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp b/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp index 6aaba0fa..82fa984c 100644 --- a/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp +++ b/test/unit/functions/scalar/llm_function_test_base_instantiations.cpp @@ -1,6 +1,7 @@ #include "flock/functions/scalar/llm_complete.hpp" #include "flock/functions/scalar/llm_embedding.hpp" #include "flock/functions/scalar/llm_filter.hpp" +#include "../../ollama_test_utils.hpp" #include "llm_function_test_base.hpp" namespace flock { @@ -9,12 +10,14 @@ namespace flock { template void LLMFunctionTestBase::SetUp() { auto con = Config::GetConnection(); + Config::ConfigureTables(con, ConfigType::LOCAL); con.Query(" CREATE SECRET (" " TYPE OPENAI," " API_KEY 'your-api-key');"); con.Query(" CREATE SECRET (" " TYPE OLLAMA," " API_URL '127.0.0.1:11434');"); + SeedOllamaTestModel(con); mock_provider = std::make_shared(ModelDetails{}); Model::SetMockProvider(mock_provider); diff --git a/test/unit/model_manager/model_manager_test.cpp b/test/unit/model_manager/model_manager_test.cpp index e67cde06..4536594c 100644 --- a/test/unit/model_manager/model_manager_test.cpp +++ b/test/unit/model_manager/model_manager_test.cpp @@ -1,4 +1,5 @@ #include "flock/model_manager/model.hpp" +#include "../ollama_test_utils.hpp" #include "nlohmann/json.hpp" #include #include @@ -11,6 +12,7 @@ class ModelManagerTest : public ::testing::Test { protected: void SetUp() override { auto con = Config::GetConnection(); + Config::ConfigureTables(con, ConfigType::LOCAL); con.Query(" CREATE SECRET (" " TYPE OPENAI," " API_KEY 'your-api-key');"); @@ -22,6 +24,13 @@ class ModelManagerTest : public ::testing::Test { con.Query(" CREATE SECRET (" " TYPE OLLAMA," " API_URL '127.0.0.1:11434');"); + con.Query(" DELETE FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " WHERE model_name IN ('gpt-4o-test', 'azure-gpt-4o-mini');"); + con.Query(" INSERT INTO flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " (model_name, model, provider_name, model_args) " + " VALUES ('gpt-4o-test', 'gpt-4o', 'openai', '{}'), " + " ('azure-gpt-4o-mini', 'gpt-4o-mini', 'azure', '{}');"); + SeedOllamaTestModel(con); } void TearDown() override { @@ -99,7 +108,7 @@ TEST_F(ModelManagerTest, ProviderSelection) { }); // Test Ollama provider json ollama_config = { - {"model_name", "gemma3:4b"}}; + {"model_name", GetOllamaTestModelName()}}; EXPECT_NO_THROW({ Model ollama_model(ollama_config); EXPECT_EQ(ollama_model.GetModelDetails().provider_name, "ollama"); @@ -127,4 +136,4 @@ TEST_F(ModelManagerTest, GetModelDetails) { EXPECT_EQ(details.batch_size, 10); } -}// namespace flock \ No newline at end of file +}// namespace flock diff --git a/test/unit/model_manager/model_providers_test.cpp b/test/unit/model_manager/model_providers_test.cpp index 8d3410e4..7c5c1af8 100644 --- a/test/unit/model_manager/model_providers_test.cpp +++ b/test/unit/model_manager/model_providers_test.cpp @@ -1,4 +1,5 @@ #include "../functions/mock_provider.hpp" +#include "../ollama_test_utils.hpp" #include "flock/model_manager/providers/adapters/anthropic.hpp" #include "flock/model_manager/providers/adapters/azure.hpp" #include "flock/model_manager/providers/adapters/ollama.hpp" @@ -102,7 +103,7 @@ TEST(ModelProvidersTest, AzureProviderTest) { TEST(ModelProvidersTest, OllamaProviderTest) { ModelDetails model_details; model_details.model_name = "test_model"; - model_details.model = "gemma3:4b"; + model_details.model = GetOllamaTestModelName(); model_details.provider_name = "ollama"; model_details.model_parameters = {{"temperature", 0.7}}; model_details.secret = {{"api_url", "http://localhost:11434"}}; @@ -159,7 +160,7 @@ TEST(ModelProvidersTest, OllamaProviderTest) { TEST(ModelProvidersTest, OllamaProviderTranscriptionError) { ModelDetails model_details; model_details.model_name = "test_model"; - model_details.model = "gemma3:4b"; + model_details.model = GetOllamaTestModelName(); model_details.provider_name = "ollama"; model_details.model_parameters = {{"temperature", 0.7}}; model_details.secret = {{"api_url", "http://localhost:11434"}}; @@ -260,4 +261,4 @@ TEST(ModelProvidersTest, AnthropicProviderTypeTest) { EXPECT_EQ(GetProviderName(FLOCKMTL_ANTHROPIC), "anthropic"); } -}// namespace flock \ No newline at end of file +}// namespace flock diff --git a/test/unit/ollama_test_utils.hpp b/test/unit/ollama_test_utils.hpp new file mode 100644 index 00000000..04c2b626 --- /dev/null +++ b/test/unit/ollama_test_utils.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include "flock/core/config.hpp" +#include "nlohmann/json.hpp" +#include +#include +#include +#include +#include + +namespace flock { + +inline std::string ReadCommandOutput(const std::string& command) { + std::array buffer {}; + std::string output; + + auto pipe = popen(command.c_str(), "r"); + if (!pipe) { + throw std::runtime_error("Failed to query Ollama models"); + } + + while (fgets(buffer.data(), static_cast(buffer.size()), pipe) != nullptr) { + output += buffer.data(); + } + + const auto status = pclose(pipe); + if (status != 0) { + throw std::runtime_error("Failed to query Ollama models. Is Ollama running?"); + } + + return output; +} + +inline std::string GetOllamaTestModelName() { + const auto* model_name = std::getenv("FLOCK_OLLAMA_TEST_MODEL"); + if (model_name && model_name[0] != '\0') { + return model_name; + } + + const auto response = ReadCommandOutput("curl -fsS http://127.0.0.1:11434/api/tags"); + const auto tags = nlohmann::json::parse(response); + if (!tags.contains("models") || !tags["models"].is_array() || tags["models"].empty()) { + throw std::runtime_error("No Ollama models are downloaded. Pull a model or set FLOCK_OLLAMA_TEST_MODEL."); + } + + return tags["models"][0]["name"].get(); +} + +inline void SeedOllamaTestModel(duckdb::Connection& con) { + const auto model_name = GetOllamaTestModelName(); + con.Query(" DELETE FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " WHERE provider_name = 'ollama';"); + con.Query(" INSERT INTO flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE " + " (model_name, model, provider_name, model_args) " + " VALUES ('" + + model_name + "', '" + model_name + "', 'ollama', '{}');"); +} + +}// namespace flock diff --git a/test/unit/prompt_manager/prompt_manager_test.cpp b/test/unit/prompt_manager/prompt_manager_test.cpp index 9aff5918..be26f24d 100644 --- a/test/unit/prompt_manager/prompt_manager_test.cpp +++ b/test/unit/prompt_manager/prompt_manager_test.cpp @@ -11,6 +11,21 @@ namespace flock { using json = nlohmann::json; +void SeedProductSummaryPrompts() { + auto con = Config::GetConnection(); + Config::ConfigureTables(con, ConfigType::LOCAL); + con.Query(" DELETE FROM flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " WHERE prompt_name = 'product_summary';"); + con.Query(" INSERT INTO flock_config.FLOCKMTL_PROMPT_INTERNAL_TABLE " + " (prompt_name, prompt, version) " + " VALUES ('product_summary', " + " 'Summarize the product with a persuasive tone suitable for a sales page.', " + " 4), " + " ('product_summary', " + " 'Generate a summary with a focus on technical specifications.', " + " 6);"); +} + // Test cases for PromptManager::ToString TEST(PromptManager, ToString) { EXPECT_EQ(PromptManager::ToString(PromptSection::USER_PROMPT), "{{USER_PROMPT}}"); @@ -191,6 +206,7 @@ TEST(PromptManager, CreatePromptDetailsEmptyJson) { // Test with prompt_name and a specific version TEST(PromptManager, CreatePromptDetailsWithExplicitVersion) { + SeedProductSummaryPrompts(); const json prompt_json = { {"prompt_name", "product_summary"}, {"version", "4"}}; @@ -209,6 +225,7 @@ TEST(PromptManager, CreatePromptDetailsNonExistentPrompt) { // Test with a non-existent version of existing prompt TEST(PromptManager, CreatePromptDetailsNonExistentVersion) { + SeedProductSummaryPrompts(); const json prompt_json = { {"prompt_name", "product_summary"}, {"version", "999"}}; @@ -245,6 +262,7 @@ TEST(PromptManager, CreatePromptDetailsMultipleFieldsPromptOnly) { } TEST(PromptManager, CreatePromptDetailsOnlyPromptName) { + SeedProductSummaryPrompts(); const json prompt_json = {{"prompt_name", "product_summary"}}; const auto [prompt_name, prompt, version] = PromptManager::CreatePromptDetails(prompt_json); EXPECT_EQ(prompt_name, "product_summary"); @@ -425,4 +443,4 @@ TEST_F(TranscribeAudioColumnTest, TranscribeAudioColumnOutputFormat) { EXPECT_EQ(result["data"][0], expected_transcription); } -}// namespace flock \ No newline at end of file +}// namespace flock