Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/model_manager/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,31 @@ void Model::LoadModelDetails(const nlohmann::json& model_json) {
}

bool has_resolved_details = model_json.contains("model") &&
model_json.contains("provider") &&
model_json.contains("secret") &&
model_json.contains("tuple_format") &&
model_json.contains("batch_size");
model_json.contains("provider");

nlohmann::json db_model_args;

if (has_resolved_details) {
model_details_.model = model_json.at("model").get<std::string>();
model_details_.provider_name = model_json.at("provider").get<std::string>();
model_details_.secret = model_json["secret"].get<std::unordered_map<std::string, std::string>>();
model_details_.tuple_format = model_json.at("tuple_format").get<std::string>();
model_details_.batch_size = model_json.at("batch_size").get<int>();
if (model_json.contains("secret")) {
model_details_.secret = model_json["secret"].get<std::unordered_map<std::string, std::string>>();
} else {
auto secret_name = "__default_" + model_details_.provider_name;
if (model_details_.provider_name == AZURE) {
secret_name += "_llm";
}
if (model_json.contains("secret_name")) {
secret_name = model_json["secret_name"].get<std::string>();
}
model_details_.secret = SecretManager::GetSecret(secret_name);
}
model_details_.tuple_format = model_json.contains("tuple_format")
? model_json.at("tuple_format").get<std::string>()
: "XML";
model_details_.batch_size = model_json.contains("batch_size")
? model_json.at("batch_size").get<int>()
: 2048;

if (model_json.contains("model_parameters")) {
auto& mp = model_json.at("model_parameters");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "../mock_provider.hpp"
#include "../../ollama_test_utils.hpp"
#include "flock/core/config.hpp"
#include "flock/functions/aggregate/aggregate.hpp"
#include "flock/model_manager/model.hpp"
Expand All @@ -24,12 +25,14 @@ class LLMAggregateTestBase : public ::testing::Test {

void SetUp() override {
auto con = Config::GetConnection();
Config::ConfigureTables(con, ConfigType::LOCAL);
con.Query(" CREATE SECRET ("
" TYPE OPENAI,"
" API_KEY 'your-api-key');");
con.Query(" CREATE SECRET ("
" TYPE OLLAMA,"
" API_URL '127.0.0.1:11434');");
SeedOllamaTestModel(con);

// Create a shared mock provider for expectations
mock_provider = std::make_shared<MockProvider>(ModelDetails{});
Expand Down
8 changes: 6 additions & 2 deletions test/unit/functions/aggregate/llm_first.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "flock/functions/aggregate/llm_first_or_last.hpp"
#include "../../ollama_test_utils.hpp"
#include "llm_aggregate_function_test_base.hpp"

namespace flock {
Expand Down Expand Up @@ -185,17 +186,20 @@ TEST_F(LLMFirstTest, AudioTranscription) {
// Test audio transcription error handling for Ollama
TEST_F(LLMFirstTest, AudioTranscriptionOllamaError) {
auto con = Config::GetConnection();
const auto ollama_model = GetOllamaTestModelName();
EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
.WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama.")));

const auto results = con.Query(
"SELECT llm_first("
"{'model_name': 'gemma3:4b'}, "
"{'model_name': '" +
ollama_model + "'}, "
"{'prompt': 'Select the best audio', "
"'context_columns': ["
"{'data': audio_url, "
"'type': 'audio', "
"'transcription_model': 'gemma3:4b'}"
"'transcription_model': '" +
ollama_model + "'}"
"]}) AS result FROM VALUES "
"('https://example.com/audio1.mp3'), "
"('https://example.com/audio2.mp3') AS tbl(audio_url);");
Expand Down
8 changes: 6 additions & 2 deletions test/unit/functions/aggregate/llm_last.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "flock/functions/aggregate/llm_first_or_last.hpp"
#include "../../ollama_test_utils.hpp"
#include "llm_aggregate_function_test_base.hpp"

namespace flock {
Expand Down Expand Up @@ -185,17 +186,20 @@ TEST_F(LLMLastTest, AudioTranscription) {
// Test audio transcription error handling for Ollama
TEST_F(LLMLastTest, AudioTranscriptionOllamaError) {
auto con = Config::GetConnection();
const auto ollama_model = GetOllamaTestModelName();
EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
.WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama.")));

const auto results = con.Query(
"SELECT llm_last("
"{'model_name': 'gemma3:4b'}, "
"{'model_name': '" +
ollama_model + "'}, "
"{'prompt': 'Select the worst audio', "
"'context_columns': ["
"{'data': audio_url, "
"'type': 'audio', "
"'transcription_model': 'gemma3:4b'}"
"'transcription_model': '" +
ollama_model + "'}"
"]}) AS result FROM VALUES "
"('https://example.com/audio1.mp3'), "
"('https://example.com/audio2.mp3') AS tbl(audio_url);");
Expand Down
8 changes: 6 additions & 2 deletions test/unit/functions/aggregate/llm_reduce.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "flock/functions/aggregate/llm_reduce.hpp"
#include "../../ollama_test_utils.hpp"
#include "llm_aggregate_function_test_base.hpp"

namespace flock {
Expand Down Expand Up @@ -209,17 +210,20 @@ TEST_F(LLMReduceTest, AudioAndTextColumns) {
// Test audio transcription error handling for Ollama
TEST_F(LLMReduceTest, AudioTranscriptionOllamaError) {
auto con = Config::GetConnection();
const auto ollama_model = GetOllamaTestModelName();
EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
.WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama.")));

const auto results = con.Query(
"SELECT llm_reduce("
"{'model_name': 'gemma3:4b'}, "
"{'model_name': '" +
ollama_model + "'}, "
"{'prompt': 'Summarize this audio', "
"'context_columns': ["
"{'data': audio_url, "
"'type': 'audio', "
"'transcription_model': 'gemma3:4b'}"
"'transcription_model': '" +
ollama_model + "'}"
"]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);");

ASSERT_TRUE(results->HasError());
Expand Down
8 changes: 6 additions & 2 deletions test/unit/functions/aggregate/llm_rerank.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "flock/functions/aggregate/llm_rerank.hpp"
#include "../../ollama_test_utils.hpp"
#include "llm_aggregate_function_test_base.hpp"
#include <numeric>

Expand Down Expand Up @@ -196,17 +197,20 @@ TEST_F(LLMRerankTest, AudioTranscription) {
// Test audio transcription error handling for Ollama
TEST_F(LLMRerankTest, AudioTranscriptionOllamaError) {
auto con = Config::GetConnection();
const auto ollama_model = GetOllamaTestModelName();
EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
.WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama.")));

const auto results = con.Query(
"SELECT llm_rerank("
"{'model_name': 'gemma3:4b'}, "
"{'model_name': '" +
ollama_model + "'}, "
"{'prompt': 'Rank these audio files', "
"'context_columns': ["
"{'data': audio_url, "
"'type': 'audio', "
"'transcription_model': 'gemma3:4b'}"
"'transcription_model': '" +
ollama_model + "'}"
"]}) AS result FROM VALUES "
"('https://example.com/audio1.mp3'), "
"('https://example.com/audio2.mp3') AS tbl(audio_url);");
Expand Down
10 changes: 7 additions & 3 deletions test/unit/functions/scalar/llm_complete.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "flock/functions/scalar/llm_complete.hpp"
#include "../../ollama_test_utils.hpp"
#include "llm_function_test_base.hpp"

namespace flock {
Expand Down Expand Up @@ -227,19 +228,22 @@ TEST_F(LLMCompleteTest, LLMCompleteWithAudioAndText) {
// Test audio transcription error handling
TEST_F(LLMCompleteTest, LLMCompleteAudioTranscriptionError) {
auto con = Config::GetConnection();
const auto ollama_model = GetOllamaTestModelName();
// Mock transcription model to throw error (simulating Ollama behavior)
EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
.WillOnce(::testing::Throw(std::runtime_error("Audio transcription is not currently supported by Ollama.")));

// Test with Ollama which doesn't support transcription
const auto results = con.Query(
"SELECT llm_complete("
"{'model_name': 'gemma3:4b'}, "
"{'model_name': '" +
ollama_model + "'}, "
"{'prompt': 'Summarize this audio', "
"'context_columns': ["
"{'data': audio_url, "
"'type': 'audio', "
"'transcription_model': 'gemma3:4b'}"
"'transcription_model': '" +
ollama_model + "'}"
"]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);");

// Should fail because Ollama doesn't support transcription
Expand All @@ -262,4 +266,4 @@ TEST_F(LLMCompleteTest, LLMCompleteAudioMissingTranscriptionModel) {
ASSERT_TRUE(results->HasError());
}

}// namespace flock
}// namespace flock
8 changes: 6 additions & 2 deletions test/unit/functions/scalar/llm_filter.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "flock/functions/scalar/llm_filter.hpp"
#include "../../ollama_test_utils.hpp"
#include "llm_function_test_base.hpp"

namespace flock {
Expand Down Expand Up @@ -196,6 +197,7 @@ TEST_F(LLMFilterTest, LLMFilterWithAudioAndText) {
// Test audio transcription error handling for Ollama
TEST_F(LLMFilterTest, LLMFilterAudioTranscriptionOllamaError) {
auto con = Config::GetConnection();
const auto ollama_model = GetOllamaTestModelName();

// Mock transcription model to throw error (simulating Ollama behavior)
EXPECT_CALL(*mock_provider, AddTranscriptionRequest(::testing::_))
Expand All @@ -204,12 +206,14 @@ TEST_F(LLMFilterTest, LLMFilterAudioTranscriptionOllamaError) {
// Test with Ollama which doesn't support transcription
const auto results = con.Query(
"SELECT llm_filter("
"{'model_name': 'gemma3:4b'}, "
"{'model_name': '" +
ollama_model + "'}, "
"{'prompt': 'Is the sentiment positive?', "
"'context_columns': ["
"{'data': audio_url, "
"'type': 'audio', "
"'transcription_model': 'gemma3:4b'}"
"'transcription_model': '" +
ollama_model + "'}"
"]}) AS result FROM VALUES ('https://example.com/audio.mp3') AS tbl(audio_url);");

// Should fail because Ollama doesn't support transcription
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "flock/functions/scalar/llm_complete.hpp"
#include "flock/functions/scalar/llm_embedding.hpp"
#include "flock/functions/scalar/llm_filter.hpp"
#include "../../ollama_test_utils.hpp"
#include "llm_function_test_base.hpp"

namespace flock {
Expand All @@ -9,12 +10,14 @@ namespace flock {
template<typename FunctionClass>
void LLMFunctionTestBase<FunctionClass>::SetUp() {
auto con = Config::GetConnection();
Config::ConfigureTables(con, ConfigType::LOCAL);
con.Query(" CREATE SECRET ("
" TYPE OPENAI,"
" API_KEY 'your-api-key');");
con.Query(" CREATE SECRET ("
" TYPE OLLAMA,"
" API_URL '127.0.0.1:11434');");
SeedOllamaTestModel(con);

mock_provider = std::make_shared<MockProvider>(ModelDetails{});
Model::SetMockProvider(mock_provider);
Expand Down
13 changes: 11 additions & 2 deletions test/unit/model_manager/model_manager_test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "flock/model_manager/model.hpp"
#include "../ollama_test_utils.hpp"
#include "nlohmann/json.hpp"
#include <gtest/gtest.h>
#include <memory>
Expand All @@ -11,6 +12,7 @@ class ModelManagerTest : public ::testing::Test {
protected:
void SetUp() override {
auto con = Config::GetConnection();
Config::ConfigureTables(con, ConfigType::LOCAL);
con.Query(" CREATE SECRET ("
" TYPE OPENAI,"
" API_KEY 'your-api-key');");
Expand All @@ -22,6 +24,13 @@ class ModelManagerTest : public ::testing::Test {
con.Query(" CREATE SECRET ("
" TYPE OLLAMA,"
" API_URL '127.0.0.1:11434');");
con.Query(" DELETE FROM flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE "
" WHERE model_name IN ('gpt-4o-test', 'azure-gpt-4o-mini');");
con.Query(" INSERT INTO flock_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE "
" (model_name, model, provider_name, model_args) "
" VALUES ('gpt-4o-test', 'gpt-4o', 'openai', '{}'), "
" ('azure-gpt-4o-mini', 'gpt-4o-mini', 'azure', '{}');");
SeedOllamaTestModel(con);
}

void TearDown() override {
Expand Down Expand Up @@ -99,7 +108,7 @@ TEST_F(ModelManagerTest, ProviderSelection) {
});
// Test Ollama provider
json ollama_config = {
{"model_name", "gemma3:4b"}};
{"model_name", GetOllamaTestModelName()}};
EXPECT_NO_THROW({
Model ollama_model(ollama_config);
EXPECT_EQ(ollama_model.GetModelDetails().provider_name, "ollama");
Expand Down Expand Up @@ -127,4 +136,4 @@ TEST_F(ModelManagerTest, GetModelDetails) {
EXPECT_EQ(details.batch_size, 10);
}

}// namespace flock
}// namespace flock
7 changes: 4 additions & 3 deletions test/unit/model_manager/model_providers_test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "../functions/mock_provider.hpp"
#include "../ollama_test_utils.hpp"
#include "flock/model_manager/providers/adapters/anthropic.hpp"
#include "flock/model_manager/providers/adapters/azure.hpp"
#include "flock/model_manager/providers/adapters/ollama.hpp"
Expand Down Expand Up @@ -102,7 +103,7 @@ TEST(ModelProvidersTest, AzureProviderTest) {
TEST(ModelProvidersTest, OllamaProviderTest) {
ModelDetails model_details;
model_details.model_name = "test_model";
model_details.model = "gemma3:4b";
model_details.model = GetOllamaTestModelName();
model_details.provider_name = "ollama";
model_details.model_parameters = {{"temperature", 0.7}};
model_details.secret = {{"api_url", "http://localhost:11434"}};
Expand Down Expand Up @@ -159,7 +160,7 @@ TEST(ModelProvidersTest, OllamaProviderTest) {
TEST(ModelProvidersTest, OllamaProviderTranscriptionError) {
ModelDetails model_details;
model_details.model_name = "test_model";
model_details.model = "gemma3:4b";
model_details.model = GetOllamaTestModelName();
model_details.provider_name = "ollama";
model_details.model_parameters = {{"temperature", 0.7}};
model_details.secret = {{"api_url", "http://localhost:11434"}};
Expand Down Expand Up @@ -260,4 +261,4 @@ TEST(ModelProvidersTest, AnthropicProviderTypeTest) {
EXPECT_EQ(GetProviderName(FLOCKMTL_ANTHROPIC), "anthropic");
}

}// namespace flock
}// namespace flock
Loading