diff --git a/.gitignore b/.gitignore index 6b94124ed..5a3a16d8f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ version.h ports/docwire/.disable_binary_cache .cache compile_commands.json +.zed +.clang-format diff --git a/CMakeLists.txt b/CMakeLists.txt index 98f67adbe..e466e68b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,9 @@ cmake_minimum_required(VERSION 3.17) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # for VSCode CLangd extension +option(DOCWIRE_LOCAL_CT2 "Enable local AI (Translation / Text Generation / Embedding models)" OFF) +option(DOCWIRE_LLAMA "Enable llama.cpp engine" OFF) + # Get version from ChangeLog.md and store it in DOCWIRE_VERSION and SIMPLE_DOCWIRE_VERSION function(extract_version) file(READ "${CMAKE_CURRENT_SOURCE_DIR}/doc/ChangeLog.md" changelog_text) diff --git a/README.md b/README.md index def8b4e8c..87e64f6aa 100644 --- a/README.md +++ b/README.md @@ -365,7 +365,7 @@ std::filesystem::path("test.zip") | content_type::detector{} | archives_parser{} Classify file in any format (Office, PDF, mail, etc) to any categories using build-in local AI model: ```cpp -std::filesystem::path("...") | ... | local_ai::model_chain_element("Classify to...: agreement, invoice, report...") | out_stream; +std::filesystem::path("...") | ... | ai::local::task("Classify to...: agreement, invoice, report...") | out_stream; ensure(out_stream.str()) == "report"; ``` [Full example](https://docwire.readthedocs.io/en/latest/local_ai_classify_8cpp-example.html) @@ -381,7 +381,7 @@ ensure(out_stream.str()) == "report\n"; Translate document in any format (Office, PDF, mail, etc) to other language using build-in local AI model: ```cpp -std::filesystem::path("...") | ... | local_ai::model_chain_element("Translate to spanish:\n\n") | out_stream; +std::filesystem::path("...") | ... 
| ai::local::translate("spanish") | out_stream; ensure(fuzzy_match::ratio(out_stream.str(), "La procesación de datos se refiere a las actividades...")) > 80; ``` [Full example](https://docwire.readthedocs.io/en/latest/local_ai_translate_8cpp-example.html) @@ -397,7 +397,7 @@ ensure(fuzzy_match::ratio(out_stream.str(), "El procesamiento de datos se refier Detect sentiment of document in any format (Office, PDF, mail, etc) using build-in local AI model: ```cpp -std::filesystem::path("...") | ... | local_ai::model_chain_element("Detect sentiment:\n\n") | out_stream; +std::filesystem::path("...") | ... | ai::local::task("Detect sentiment:\n\n") | out_stream; ensure(out_stream.str()) == "positive"; ``` [Full example](https://docwire.readthedocs.io/en/latest/local_ai_sentiment_8cpp-example.html) @@ -412,7 +412,7 @@ std::filesystem::path("1.doc") | ... | openai::DetectSentiment(...) | std::cout; Make a summary of document in any format (Office, PDF, mail, etc) using build-in local AI model: ```cpp -std::filesystem::path("...") | ... | local_ai::model_chain_element("Write a short summary...") | out_stream; +std::filesystem::path("...") | ... | ai::local::summarize() | out_stream; ensure(out_stream.str()).is_one_of({ "Data processing is the collection, organization, analysis, and interpretation of data to extract useful insights and support decision-making."... ``` [Full example](https://docwire.readthedocs.io/en/latest/local_ai_summary_8cpp-example.html) @@ -435,7 +435,7 @@ ensure(fuzzy_match::ratio(out_stream.str(), "Data processing involves converting Find phrases, objects and events with smart matching in documents in any format (Office, PDF, mail, etc) using build-in local AI model: ```cpp -std::filesystem::path("...") | ... | local_ai::model_chain_element("Find sentence about \"data conversion\"...") | out_stream; +std::filesystem::path("...") | ... 
| ai::local::task("Find sentence about \"data conversion\"...") | out_stream; ensure(out_stream.str()).is_one_of({ "Data processing refers to the activities performed on raw data to convert it into meaningful information."... ``` [Full example](https://docwire.readthedocs.io/en/latest/local_ai_find_8cpp-example.html) @@ -461,9 +461,9 @@ ensure(out_msgs[0]->get().values.size()) == 1536; Create embedding for document in any format (Office, PDF, mail, etc) using build-in local AI model, create embeddings for two queries and calculate similarity: ```cpp -std::filesystem::path("data_processing_definition.doc") | ... | local_ai::embed(local_ai::embed::e5_passage_prefix) | passage_msgs; +std::filesystem::path("data_processing_definition.doc") | ... | ai::local::passage::embedder{} | passage_msgs; ... -docwire::data_source{std::string{"What is data processing?"}, ...} | local_ai::embed(local_ai::embed::e5_query_prefix) | similar_query_msgs; +docwire::data_source{std::string{"What is data processing?"}, ...} | ai::local::query::embedder{} | similar_query_msgs; ... double sim = cosine_similarity(passage_embedding.values, similar_query_embedding.values); ... diff --git a/build_demo.sh b/build_demo.sh new file mode 100755 index 000000000..1247a5c56 --- /dev/null +++ b/build_demo.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +set -e # stop on first error + +# ========================== +# CONFIGURATION — each value below is a default; override it via an environment variable of the same name +# ========================== +VCPKG_TOOLCHAIN="${VCPKG_TOOLCHAIN:-/home/reeshabh/coderepos/docwire/vcpkg/scripts/buildsystems/vcpkg.cmake}" +VCPKG_TRIPLET="${VCPKG_TRIPLET:-x64-linux-dynamic}" +DOCWIRE_DIR="${DOCWIRE_DIR:-/home/reeshabh/coderepos/docwire/vcpkg/installed/x64-linux-dynamic/share/docwire}" +BUILD_DIR=./build +DEMO_EXEC=demo + + +# ========================== +# CLEAN BUILD +# ========================== +echo "Cleaning old build folder..." +rm -rf "$BUILD_DIR" +mkdir "$BUILD_DIR" +cd "$BUILD_DIR" + +# ========================== +# CONFIGURE CMAKE +# ========================== +echo "Configuring CMake..." +cmake .. 
\ + -DCMAKE_TOOLCHAIN_FILE="$VCPKG_TOOLCHAIN" \ + -DVCPKG_TARGET_TRIPLET="$VCPKG_TRIPLET" \ + -Ddocwire_DIR="$DOCWIRE_DIR" \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON + +# ========================== +# BUILD +# ========================== +echo "Building project..." +cmake --build . + +# ========================== +# SET LIBRARY PATH +# ========================== +export LD_LIBRARY_PATH=/home/reeshabh/coderepos/docwire/vcpkg/installed/x64-linux-dynamic/lib:$LD_LIBRARY_PATH + +# ========================== +# RUN DEMO +# ========================== +# echo "Running demo..." +# ./$DEMO_EXEC diff --git a/download_llama_model.sh b/download_llama_model.sh new file mode 100644 index 000000000..3314c92f9 --- /dev/null +++ b/download_llama_model.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +set -e + +# Configurable defaults + +MODEL_NAME="${MODEL_NAME:-qwen2-7b-instruct}" +MODEL_QUANT="${MODEL_QUANT:-q4_k_m}" +MODEL_REPO="${MODEL_REPO:-Qwen/Qwen2-7B-Instruct-GGUF}" + +# Derived values (file name is "<name>-<quant>.gguf", the same name the qwen2-7b-instruct-q4-k-m vcpkg port downloads) +MODEL_FILE="${MODEL_NAME}-${MODEL_QUANT}.gguf" +OUTPUT_DIR="${OUTPUT_DIR:-models}" +HF_URL="https://huggingface.co/${MODEL_REPO}/resolve/main/${MODEL_FILE}" + +# Checks +if ! command -v wget &> /dev/null && ! command -v curl &> /dev/null; then + echo "Error: Neither wget nor curl is installed." + exit 1 +fi + +mkdir -p "${OUTPUT_DIR}" +cd "${OUTPUT_DIR}" + +if [ -f "${MODEL_FILE}" ]; then + echo "Model already exists: ${OUTPUT_DIR}/${MODEL_FILE}" + exit 0 +fi + +echo "Downloading model:" +echo " Repository : ${MODEL_REPO}" +echo " File : ${MODEL_FILE}" +echo " Destination: ${OUTPUT_DIR}" +echo "" + +# Download (curl gets --fail so an HTTP error aborts the script instead of saving the error page as the model file) + +if command -v wget &> /dev/null; then + wget -c "${HF_URL}" +else + curl --fail -L -C - -o "${MODEL_FILE}" "${HF_URL}" +fi + +echo "" +echo "Download complete." 
+echo "Model saved to: ${OUTPUT_DIR}/${MODEL_FILE}" diff --git a/ports/docwire/portfile.cmake b/ports/docwire/portfile.cmake index a8fa6b041..f2bbfdbf8 100644 --- a/ports/docwire/portfile.cmake +++ b/ports/docwire/portfile.cmake @@ -13,6 +13,8 @@ vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS asan ADDRESS_SANITIZER tsan THREAD_SANITIZER helgrind HELGRIND_ENABLED + local-ai DOCWIRE_LOCAL_CT2 + llama-engine DOCWIRE_LLAMA ) if(DEFINED ENV{CMAKE_MESSAGE_LOG_LEVEL}) diff --git a/ports/docwire/vcpkg.json b/ports/docwire/vcpkg.json index 4a933ee66..78451b74a 100644 --- a/ports/docwire/vcpkg.json +++ b/ports/docwire/vcpkg.json @@ -28,6 +28,32 @@ "callgrind": { "description": "Enable valgrind callgrind in automatic tests" + }, + "local-ai": + { + "description": "Enable local AI runtime", + "dependencies": [ + "ctranslate2", + "sentencepiece", + "multilingual-e5-small-ct2-int8", + "flan-t5-large-ct2-int8" + ] + }, + "llama-engine": + { + "description": "Enable GGUF-based LLM inference (llama.cpp)", + "dependencies": [ + { "name": "docwire", "features": ["local-ai"] }, + "llama-cpp" + ] + }, + "llama-qwen": + { + "description": "Install Qwen2 7B GGUF model", + "dependencies": [ + { "name": "docwire", "features": ["llama-engine"] }, + "qwen2-7b-instruct-q4-k-m" + ] } }, "dependencies": [ @@ -99,18 +125,6 @@ { "name": "tessdata-fast" }, - { - "name": "ctranslate2" - }, - { - "name": "sentencepiece" - }, - { - "name": "flan-t5-large-ct2-int8" - }, - { - "name": "multilingual-e5-small-ct2-int8" - }, { "name": "rapidfuzz-cpp" }, diff --git a/ports/qwen2-7b-instruct-q4-k-m/portfile.cmake b/ports/qwen2-7b-instruct-q4-k-m/portfile.cmake new file mode 100644 index 000000000..e0685611f --- /dev/null +++ b/ports/qwen2-7b-instruct-q4-k-m/portfile.cmake @@ -0,0 +1,21 @@ +set(MODEL_NAME "qwen2-7b-instruct") +set(MODEL_QUANT "q4_k_m") + +set(MODEL_FILE "${MODEL_NAME}-${MODEL_QUANT}.gguf") + +vcpkg_download_distfile( + MODEL_ARCHIVE + URLS 
"https://huggingface.co/Qwen/Qwen2-7B-Instruct-GGUF/resolve/main/${MODEL_FILE}" + FILENAME "${MODEL_FILE}" + SHA512 39c1f9702856cf5faff13b672033c5c99246b5393550ed58ab9ba0eb2d5ce5d50cc2710b2d9f08d51ad6ce7f6b66826f5c916128fa06c3ac2f78865e167146b8 +) + +file(INSTALL + ${MODEL_ARCHIVE} + DESTINATION ${CURRENT_PACKAGES_DIR}/share/${PORT} +) + +file(WRITE + ${CURRENT_PACKAGES_DIR}/share/${PORT}/copyright + "Model weights from HuggingFace repository Qwen/Qwen2-7B-Instruct-GGUF." +) diff --git a/ports/qwen2-7b-instruct-q4-k-m/vcpkg.json b/ports/qwen2-7b-instruct-q4-k-m/vcpkg.json new file mode 100644 index 000000000..405b9cfc1 --- /dev/null +++ b/ports/qwen2-7b-instruct-q4-k-m/vcpkg.json @@ -0,0 +1,7 @@ +{ + "name": "qwen2-7b-instruct-q4-k-m", + "version": "1.0.0", + "description": "Qwen2 7B Instruct GGUF model", + "homepage": "https://huggingface.co/Qwen/Qwen2-7B-Instruct-GGUF", + "license": "Apache-2.0" +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b9e4c22fc..2a2bf322d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,7 +23,19 @@ include(ocr.cmake) include(mail.cmake) include(archives.cmake) include(ai.cmake) -include(local_ai.cmake) + +if(DOCWIRE_LOCAL_CT2) + include(ai_ct2.cmake) +endif() + +if(DOCWIRE_LLAMA) + include(ai_llama.cmake) +endif() + +if(DOCWIRE_LOCAL_CT2 OR DOCWIRE_LLAMA) + include(local_ai.cmake) +endif() + include(fuzzy_match.cmake) include(content_type.cmake) diff --git a/src/ai.cmake b/src/ai.cmake index d345356fd..ff188430c 100644 --- a/src/ai.cmake +++ b/src/ai.cmake @@ -1,13 +1,7 @@ -set(EMPTY_SOURCE ${CMAKE_CURRENT_BINARY_DIR}/ai_empty.cpp) -file(GENERATE OUTPUT ${EMPTY_SOURCE} CONTENT " - #include \"ai_elements.h\" - namespace docwire::ai - { - // This dummy function is required to ensure that the shared library is created. 
- DOCWIRE_AI_EXPORT void dummy_function_for_docwire_ai() {} - } -") -add_library(docwire_ai SHARED ${EMPTY_SOURCE}) + +add_library(docwire_ai SHARED model_chain_element.cpp ai_summarize.cpp ai_translate.cpp ai_embed.cpp ai_task.cpp) + +target_link_libraries(docwire_ai PUBLIC docwire_core) target_compile_features(docwire_ai PUBLIC cxx_std_20) if(MSVC) @@ -27,4 +21,4 @@ endif() include(GenerateExportHeader) generate_export_header(docwire_ai EXPORT_FILE_NAME ai_export.h) target_include_directories(docwire_ai PUBLIC $) -install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ai_export.h DESTINATION include/docwire) \ No newline at end of file +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ai_export.h DESTINATION include/docwire) diff --git a/src/ai_ct2.cmake b/src/ai_ct2.cmake new file mode 100644 index 000000000..cab93c5e1 --- /dev/null +++ b/src/ai_ct2.cmake @@ -0,0 +1,51 @@ + +message(STATUS "DOCWIRE_LOCAL_CT2 enabled: Building CT2 backend.") + +add_library(docwire_ai_ct2 SHARED ct2_runner.cpp tokenizer.cpp) + +target_compile_definitions(docwire_ai_ct2 PUBLIC DOCWIRE_LOCAL_CT2) + +find_package(Boost REQUIRED COMPONENTS filesystem system json) +find_package(ctranslate2 CONFIG REQUIRED) +find_library(sentencepiece_LIBRARIES sentencepiece REQUIRED) + +if(MSVC) + find_package(absl CONFIG REQUIRED) + list(APPEND sentencepiece_LIBRARIES + absl::strings + absl::flags + absl::flags_parse + absl::log + absl::check) + + find_package(protobuf CONFIG REQUIRED) + list(APPEND sentencepiece_LIBRARIES protobuf::libprotobuf-lite) +endif() + +target_link_libraries(docwire_ai_ct2 PRIVATE docwire_core docwire_ai Boost::filesystem Boost::system Boost::json CTranslate2::ctranslate2 ${sentencepiece_LIBRARIES}) + +docwire_find_resource(FLAN_T5_FULL_PATH REL_PATH "flan-t5-large-ct2-int8" REQUIRED) +docwire_target_resources(docwire_ai_ct2 "flan-t5-large-ct2-int8" SOURCE "${FLAN_T5_FULL_PATH}") + +docwire_find_resource(E5_MODEL_FULL_PATH REL_PATH "multilingual-e5-small-ct2-int8" REQUIRED) 
+docwire_target_resources(docwire_ai_ct2 "multilingual-e5-small-ct2-int8" SOURCE "${E5_MODEL_FULL_PATH}") + +if(MSVC) + install(FILES $ DESTINATION bin CONFIGURATIONS Debug) +endif() + +include(GenerateExportHeader) + +generate_export_header(docwire_ai_ct2 EXPORT_FILE_NAME ai_ct2_export.h) + +target_include_directories(docwire_ai_ct2 PUBLIC + $ + $ +) + +install(TARGETS docwire_ai_ct2 EXPORT docwire_targets) + +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/ai_ct2_export.h + DESTINATION include/docwire +) diff --git a/src/ai_embed.cpp b/src/ai_embed.cpp new file mode 100644 index 000000000..e81e9faaf --- /dev/null +++ b/src/ai_embed.cpp @@ -0,0 +1,59 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#include "ai_embed.h" +#include "ai_elements.h" +#include "data_source.h" +#include "error_tags.h" +#include "log_scope.h" +#include "serialization_message.h" // IWYU pragma: keep +#include +#include "throw_if.h" +#include + +namespace docwire +{ + +template <> +struct pimpl_impl : pimpl_impl_base +{ + std::shared_ptr m_model_runner; + std::string m_prefix; + + pimpl_impl(std::shared_ptr model_runner, std::string prefix) + : m_model_runner(std::move(model_runner)), m_prefix(std::move(prefix)) + { + } +}; + +namespace ai +{ + +embed::embed(std::shared_ptr model_runner, std::string prefix) + : with_pimpl(std::move(model_runner), std::move(prefix)) +{} + +continuation embed::operator()(message_ptr msg, const message_callbacks& emit_message) +{ + log_scope(msg); + if (!msg->is()) + return emit_message(std::move(msg)); + + const data_source& data = msg->get(); + throw_if(!data.has_highest_confidence_mime_type_in({mime_type{"text/plain"}}), "Input for ai::embed must be text/plain", errors::program_logic{}); + std::string data_str = data.string(); + + std::string prefixed_input = impl().m_prefix + data_str; + std::vector embedding_vector = impl().m_model_runner->embed(prefixed_input); + return emit_message(ai::embedding{std::move(embedding_vector)}); +} +} // namespace ai +} // namespace docwire diff --git a/src/ai_embed.h b/src/ai_embed.h new file mode 100644 index 000000000..ff69ac2b2 --- /dev/null +++ b/src/ai_embed.h @@ -0,0 +1,43 @@ 
+/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_EMBED_H +#define DOCWIRE_AI_EMBED_H + +#include "ai_export.h" +#include "ai_runner.h" +#include "chain_element.h" +#include "pimpl.h" + +namespace docwire::ai +{ + +class DOCWIRE_AI_EXPORT embed : public ChainElement, public with_pimpl +{ + public: + /** + * @brief Construct a local AI embed chain element with a specific model runner and prefix. + * + * @param model_runner The model runner to use for generating embeddings. + * @param prefix The string to prepend to the input text. Use an empty string for no prefix. 
+ */ + explicit embed(std::shared_ptr model_runner, std::string prefix); + continuation operator()(message_ptr msg, const message_callbacks& emit_message) override; + bool is_leaf() const override { return false; } + + private: + using with_pimpl::impl; + +}; + +} // namespace docwire::ai + +#endif // DOCWIRE_AI_EMBED_H diff --git a/src/ai_llama.cmake b/src/ai_llama.cmake new file mode 100644 index 000000000..876d2a403 --- /dev/null +++ b/src/ai_llama.cmake @@ -0,0 +1,26 @@ + +message(STATUS "DOCWIRE_LLAMA enabled: building llama backend") + +add_library(docwire_ai_llama SHARED llama_runner.cpp) + +find_package(llama CONFIG REQUIRED) + +target_link_libraries(docwire_ai_llama PRIVATE docwire_core docwire_ai llama) + +target_compile_definitions(docwire_ai_llama PUBLIC DOCWIRE_LLAMA) + +include(GenerateExportHeader) + +generate_export_header(docwire_ai_llama EXPORT_FILE_NAME ai_llama_export.h) + +target_include_directories(docwire_ai_llama PUBLIC + $ + $ +) + +install(TARGETS docwire_ai_llama EXPORT docwire_targets) + +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/ai_llama_export.h + DESTINATION include/docwire +) diff --git a/src/ai_runner.h b/src/ai_runner.h new file mode 100644 index 000000000..debec9490 --- /dev/null +++ b/src/ai_runner.h @@ -0,0 +1,74 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_RUNNER_H +#define DOCWIRE_AI_RUNNER_H + +#include "ai_export.h" +#include +#include +#include + +namespace docwire::ai { + +/** + * @brief Abstract interface for AI model runners. + * + * Implementations load / run / unload AI models and expose a minimal + * synchronous API for text processing and embedding generation. + * + * Thread-safety requirements (MANDATORY for all derived classes): + * - All public virtual methods (process, embed, unload) MUST be safe to + * call concurrently from multiple threads. + * - Implementations must internally synchronize access to shared resources + * (model handles, contexts, samplers, caches, global backend state, etc.). + * - unload() may be called concurrently with process()/embed(); implementations + * must either: + * * defer actual teardown until in-flight calls complete (reference counting, + * call guards, condition variables), or + * * make unload() idempotent and safe to call while other threads are active. + * - The destructor MUST NOT cause undefined behavior when other threads are making + * calls; prefer explicit lifetime management (guards) or documented external + * synchronization. + */ +class DOCWIRE_AI_EXPORT ai_runner { + public: + /** + * @brief Virtual destructor. + * + * Implementations should ensure safe destruction semantics in the presence + * of concurrent calls (see class-level thread-safety requirements). + */ + virtual ~ai_runner() = default; + + /** + * @brief Synchronously process input and return generated text. + * + * Must be thread-safe. 
+ */ + virtual std::string process(const std::string& input) = 0; + + /** + * @brief Generate an embedding for the given input. + * + * Must be thread-safe. + */ + virtual std::vector embed(const std::string&) = 0; + /** + * @brief Unload the model and free associated resources. + * Must be thread-safe and safe to call concurrently with process()/embed(). + */ + virtual void unload() = 0; +}; + +} // namespace docwire::ai + +#endif diff --git a/src/ai_summarize.cpp b/src/ai_summarize.cpp new file mode 100644 index 000000000..24844c2d1 --- /dev/null +++ b/src/ai_summarize.cpp @@ -0,0 +1,20 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ +#include "ai_summarize.h" + +namespace docwire::ai +{ + +summarize::summarize(std::shared_ptr runner, model_lifetime_policy lifetime) + : model_chain_element(summary_prompt, runner, lifetime) +{ +} +} // namespace docwire::ai diff --git a/src/ai_summarize.h b/src/ai_summarize.h new file mode 100644 index 000000000..8f2d914e7 --- /dev/null +++ b/src/ai_summarize.h @@ -0,0 +1,34 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & + * Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text + * extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_SUMMARIZE_H +#define DOCWIRE_AI_SUMMARIZE_H + +#include "ai_export.h" +#include "model_chain_element.h" + +namespace docwire::ai +{ + +class DOCWIRE_AI_EXPORT summarize : public model_chain_element +{ + public: + explicit summarize(std::shared_ptr runner, model_lifetime_policy lifetime = model_lifetime_policy::persistent); + protected: + static constexpr const char* summary_prompt = + "Your task is to summarize the text:\n\n"; +}; + +} // namespace docwire::ai + +#endif // DOCWIRE_AI_SUMMARIZE_H diff --git a/src/ai_task.cpp b/src/ai_task.cpp new file mode 100644 index 000000000..f0ee741ef --- /dev/null +++ b/src/ai_task.cpp @@ -0,0 +1,21 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#include "ai_task.h" + +namespace docwire::ai +{ + +task::task(const std::string& prompt, std::shared_ptr runner,model_lifetime_policy lifetime) + : model_chain_element(prompt, runner, lifetime) +{ +} +} // namespace docwire::ai diff --git a/src/ai_task.h b/src/ai_task.h new file mode 100644 index 000000000..644125592 --- /dev/null +++ b/src/ai_task.h @@ -0,0 +1,29 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_TASK_H +#define DOCWIRE_AI_TASK_H + +#include "ai_export.h" +#include "model_chain_element.h" + +namespace docwire::ai +{ + +class DOCWIRE_AI_EXPORT task : public model_chain_element +{ + public: + explicit task(const std::string& prompt, std::shared_ptr runner, model_lifetime_policy lifetime = model_lifetime_policy::persistent); +}; + +} // namespace docwire::ai + +#endif // DOCWIRE_AI_TASK_H diff --git a/src/ai_translate.cpp b/src/ai_translate.cpp new file mode 100644 index 000000000..bd0f85ea9 --- /dev/null +++ b/src/ai_translate.cpp @@ -0,0 +1,23 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#include "ai_translate.h" +#include "model_chain_element.h" + +namespace docwire::ai +{ + +translate::translate(const std::string& language, std::shared_ptr runner, model_lifetime_policy lifetime) + : model_chain_element( + "Your task is to translate every message to " + language + " language.\n\n", runner, lifetime) +{ +} +} // namespace docwire::ai diff --git a/src/ai_translate.h b/src/ai_translate.h new file mode 100644 index 000000000..3716f9be0 --- /dev/null +++ b/src/ai_translate.h @@ -0,0 +1,29 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_TRANSLATE_H +#define DOCWIRE_AI_TRANSLATE_H + +#include "ai_export.h" +#include "model_chain_element.h" + +namespace docwire::ai +{ + +class DOCWIRE_AI_EXPORT translate : public model_chain_element +{ + public: + explicit translate(const std::string& language, std::shared_ptr runner, model_lifetime_policy lifetime = model_lifetime_policy::persistent); +}; + +} // namespace docwire::ai + +#endif // DOCWIRE_AI_TRANSLATE_H diff --git a/src/cli.cmake b/src/cli.cmake index 77f6cd40f..d189572bd 100644 --- a/src/cli.cmake +++ b/src/cli.cmake @@ -2,6 +2,10 @@ add_executable(docwire docwire.cpp) find_package(Boost REQUIRED COMPONENTS program_options) target_link_libraries(docwire PRIVATE docwire_core docwire_office_formats docwire_mail docwire_ocr docwire_archives - docwire_local_ai docwire_openai docwire_ai docwire_content_type docwire_http Boost::program_options) + docwire_openai docwire_ai docwire_content_type docwire_http Boost::program_options) + +if(DOCWIRE_LOCAL_CT2 OR DOCWIRE_LLAMA) + target_link_libraries(docwire PRIVATE docwire_local_ai) +endif() install(TARGETS docwire DESTINATION bin) diff --git a/src/cosine_similarity.cpp b/src/cosine_similarity.cpp index e19ea5a8c..96f93043b 100644 --- a/src/cosine_similarity.cpp +++ b/src/cosine_similarity.cpp @@ -33,7 +33,7 @@ double cosine_similarity(const std::vector& a, const std::vector // Use a practical epsilon for the squared norm to check for zero vectors. // This threshold is aligned with the one used for L2 normalization in - // model_runner.cpp (1e-6f). 
The squared value is 1e-12. + // ct2_runner.cpp (1e-6f). The squared value is 1e-12. // Returning 0.0 is a common and practical approach, implying orthogonality. constexpr double zero_vector_threshold_sq = 1e-12; if (norm_a < zero_vector_threshold_sq || norm_b < zero_vector_threshold_sq) diff --git a/src/model_runner.cpp b/src/ct2_runner.cpp similarity index 57% rename from src/model_runner.cpp rename to src/ct2_runner.cpp index a9f0095aa..fa05b8eae 100644 --- a/src/model_runner.cpp +++ b/src/ct2_runner.cpp @@ -9,7 +9,7 @@ /* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ /*********************************************************************************************************************************************/ -#include "model_runner.h" +#include "ct2_runner.h" #include #include @@ -31,65 +31,76 @@ namespace docwire namespace { -std::variant load_model(const std::filesystem::path& model_data_path) +std::variant, + std::shared_ptr> +load_model(const std::filesystem::path& model_data_path) { log_scope(model_data_path); - try - { + try { log_scope(); - return std::variant{ - std::in_place_type, - ctranslate2::models::ModelLoader{model_data_path.string()}}; - } - catch (const std::exception& translator_error) - { + return std::make_shared( + ctranslate2::models::ModelLoader{model_data_path.string()}); + } catch (const std::exception& translator_error) { log_scope(translator_error); - try - { + try { log_scope(); - return std::variant{ - std::in_place_type, - ctranslate2::models::ModelLoader{model_data_path.string()}}; - } - catch (const std::exception& encoder_error) - { - std::throw_with_nested(make_error("Failed to load model as either Translator or Encoder", model_data_path, errors::program_corrupted{})); + return std::make_shared( + ctranslate2::models::ModelLoader{model_data_path.string()}); + } catch (const std::exception& encoder_error) { + std::throw_with_nested( + make_error("Failed to load model as either Translator or Encoder", 
model_data_path, + errors::program_corrupted{})); } } } } // anonymous namespace -template<> -struct pimpl_impl : pimpl_impl_base +template <> struct pimpl_impl : pimpl_impl_base { - std::variant m_model; - local_ai::tokenizer m_tokenizer; + std::mutex model_mutex; + std::variant, + std::shared_ptr> + m_model; + ai::ct2::tokenizer m_tokenizer; + std::filesystem::path m_model_path; pimpl_impl(const std::filesystem::path& model_data_path) - : m_model(load_model(model_data_path)), + : m_model_path(model_data_path), m_model(load_model(model_data_path)), m_tokenizer(model_data_path) - {} + { + } std::vector process(const std::vector& input_tokens) { log_scope(); - throw_if(!std::holds_alternative(m_model), "Model is not a Translator, cannot process.", errors::program_logic{}); - auto& translator = std::get(m_model); + std::shared_ptr translator_ptr; + + { + std::lock_guard lock(model_mutex); + if (std::holds_alternative(m_model)) + m_model = load_model(m_model_path); + + throw_if(!std::holds_alternative>(m_model), + "Model is not a Translator, cannot process.", errors::program_logic{}); + translator_ptr = std::get>(m_model); + } + auto& translator = *translator_ptr; ctranslate2::TranslationOptions options{}; - options.max_decoding_length = 1024; - options.sampling_temperature = 0.0; - options.beam_size = 1; + options.max_decoding_length = 1024; + options.sampling_temperature = 0.0; + options.beam_size = 1; options.disable_unk = true; - options.callback = [](ctranslate2::GenerationStepResult step_result)->bool - { + options.callback = [](ctranslate2::GenerationStepResult step_result) -> bool { log_entry(step_result.token); - return false; - }; - auto results = translator.translate_batch_async({ input_tokens }, options); - throw_if (results.size() != 1, "Unexpected number of results", results.size(), errors::program_logic{}); + return false; + }; + auto results = translator.translate_batch_async({input_tokens}, options); + throw_if(results.size() != 1, "Unexpected number 
of results", results.size(), + errors::program_logic{}); auto result = results[0].get(); - throw_if (result.hypotheses.size() != 1, "Unexpected number of hypotheses", result.hypotheses.size(), errors::program_logic{}); + throw_if(result.hypotheses.size() != 1, "Unexpected number of hypotheses", + result.hypotheses.size(), errors::program_logic{}); auto hypothesis = result.hypotheses[0]; return hypothesis; } @@ -97,14 +108,28 @@ struct pimpl_impl : pimpl_impl_base std::vector embed(const std::string& input) { log_scope(input); - throw_if(!std::holds_alternative(m_model), "Model is not an Encoder, cannot embed.", errors::program_logic{}); - auto& encoder = std::get(m_model); + std::shared_ptr encoder_ptr; + + { + std::lock_guard lock(model_mutex); + if (std::holds_alternative(m_model)) + m_model = load_model(m_model_path); + throw_if(!std::holds_alternative>(m_model), + "Model is not an Encoder, cannot embed.", errors::program_logic{}); + + encoder_ptr = std::get>(m_model); + + + } + + auto& encoder = *encoder_ptr; // 1. Tokenize - std::vector> tokens_batch = { m_tokenizer.tokenize(input) }; + std::vector> tokens_batch = {m_tokenizer.tokenize(input)}; // 2. Forward through the encoder - std::future future = encoder.forward_batch_async(tokens_batch); + std::future future = + encoder.forward_batch_async(tokens_batch); ctranslate2::EncoderForwardOutput encoder_output = future.get(); ctranslate2::StorageView& last_hidden_state = encoder_output.last_hidden_state; @@ -114,7 +139,8 @@ struct pimpl_impl : pimpl_impl_base // We create a new view of the hidden state that is "shrunk" to the // actual sequence length. This assumes a batch size of 1. 
const size_t batch_size = tokens_batch.size(); - throw_if(batch_size != 1, "Embedding function currently supports only batch size 1", errors::program_logic{}); + throw_if(batch_size != 1, "Embedding function currently supports only batch size 1", + errors::program_logic{}); const size_t actual_length = tokens_batch[0].size(); const size_t hidden_size = last_hidden_state.dim(-1); @@ -122,15 +148,16 @@ struct pimpl_impl : pimpl_impl_base // into the existing buffer without copying data. // We must cast the dimensions to int64_t to avoid a narrowing conversion error, // as the StorageView shape constructor expects signed integers. - ctranslate2::StorageView effective_hidden_state({1, static_cast(actual_length), static_cast(hidden_size)}, - last_hidden_state.data(), - last_hidden_state.device()); + ctranslate2::StorageView effective_hidden_state( + {1, static_cast(actual_length), static_cast(hidden_size)}, + last_hidden_state.data(), last_hidden_state.device()); ctranslate2::StorageView pooled_result; ctranslate2::ops::Mean(1)(effective_hidden_state, pooled_result); // 4. Manually perform L2 normalization. // This is required for sentence-transformer models like E5. - const ctranslate2::StorageView pooled_result_float = pooled_result.to(ctranslate2::DataType::FLOAT32); + const ctranslate2::StorageView pooled_result_float = + pooled_result.to(ctranslate2::DataType::FLOAT32); const float* pooled_data = pooled_result_float.data(); // Convert to double-precision vector *before* normalization to maintain accuracy. 
std::vector embedding_values(pooled_data, pooled_data + pooled_result_float.size()); @@ -153,14 +180,14 @@ struct pimpl_impl : pimpl_impl_base } }; -namespace local_ai +namespace ai::ct2 { -model_runner::model_runner(const std::filesystem::path& model_data_path) - : with_pimpl(model_data_path) -{} +ct2_runner::ct2_runner(const std::filesystem::path& model_data_path) : with_pimpl(model_data_path) +{ +} -std::string model_runner::process(const std::string& input) +std::string ct2_runner::process(const std::string& input) { log_scope(input); std::vector input_tokens = impl().m_tokenizer.tokenize(input); @@ -168,11 +195,18 @@ std::string model_runner::process(const std::string& input) return impl().m_tokenizer.detokenize(output_tokens); } -std::vector model_runner::embed(const std::string& input) +std::vector ct2_runner::embed(const std::string& input) { log_scope(input); return impl().embed(input); } -} // namespace local_ai +void ct2_runner::unload() +{ + log_scope(); + std::lock_guard lock(impl().model_mutex); + impl().m_model.emplace(); +} + +} // namespace ai::ct2 } // namespace docwire diff --git a/src/model_runner.h b/src/ct2_runner.h similarity index 74% rename from src/model_runner.h rename to src/ct2_runner.h index 4b10a0705..2687c1ef2 100644 --- a/src/model_runner.h +++ b/src/ct2_runner.h @@ -9,49 +9,56 @@ /* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ /*********************************************************************************************************************************************/ -#ifndef DOCWIRE_LOCAL_AI_MODEL_RUNNER_H -#define DOCWIRE_LOCAL_AI_MODEL_RUNNER_H +#ifndef DOCWIRE_AI_CT2_RUNNER_H +#define DOCWIRE_AI_CT2_RUNNER_H -#include "local_ai_export.h" +#include "ai_ct2_export.h" #include "pimpl.h" #include #include #include +#include "ai_runner.h" -namespace docwire::local_ai +namespace docwire::ai::ct2 { /** - * @brief Class representing the AI model loaded to memory. 
+ * @brief Class representing the C2Translate AI model loaded to memory. * * Constructor loads model to memory and makes it ready for usage. * Destructor frees memory used by model. * It is important not to duplicate the object because memory consumption can be high. */ -class DOCWIRE_LOCAL_AI_EXPORT model_runner : public with_pimpl +class DOCWIRE_AI_CT2_EXPORT ct2_runner : public ai_runner, public with_pimpl { public: /** * @brief Constructor. Loads model to memory. * @param model_data_path Path to the folder containing model files. */ - model_runner(const std::filesystem::path& model_data_path); + ct2_runner(const std::filesystem::path& model_data_path); /** * @brief Process input text using the model. * @param input Text to process. * @return Processed text. */ - std::string process(const std::string& input); + std::string process(const std::string& input) override; /** * @brief Create embedding for the input text using the model. * @param input Text to process. * @return Vector of embedding values. */ - std::vector embed(const std::string& input); + std::vector embed(const std::string& input) override; + + /** + * @brief Unload the model and free associated resources. + * --!Must be thread-safe!-- and safe to call concurrently with process()/embed(). 
+ */ + virtual void unload() override; }; -} // namespace docwire::local_ai +} // namespace docwire::ai::ct2 -#endif // DOCWIRE_LOCAL_AI_MODEL_RUNNER_H +#endif // DOCWIRE_AI_CT2_RUNNER_H diff --git a/src/docwire.cpp b/src/docwire.cpp index 24afd30e5..000879b2e 100755 --- a/src/docwire.cpp +++ b/src/docwire.cpp @@ -8,7 +8,20 @@ /* */ /* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ /*********************************************************************************************************************************************/ +#include "ai_runner.h" +#include "model_chain_element.h" +#include "model_inference_config.h" +#ifdef DOCWIRE_LOCAL_CT2 +#include "ct2_runner.h" +#endif +#ifdef DOCWIRE_LLAMA +#include "llama_runner.h" +#endif +#ifdef DOCWIRE_LOCAL_AI +#include "local_ai_embed.h" +#include "local_ai_task.h" +#endif #include "ai_elements.h" #include #include @@ -20,7 +33,8 @@ #include "archives_parser.h" #include "detect_sentiment.h" #include "embed.h" -#include "local_ai_embed.h" + + #include "extract_entities.h" #include "extract_keywords.h" #include "find.h" @@ -32,7 +46,6 @@ #include #include "mail_parser.h" #include "meta_data_exporter.h" -#include "model_chain_element.h" #include "ocr_parser.h" #include "office_formats_parser.h" #include "output.h" @@ -90,7 +103,7 @@ using magic_enum::istream_operators::operator>>; using magic_enum::ostream_operators::operator<<; enum class OutputType { plain_text, html, csv, metadata }; - +enum class EmbedPrefixType { none, query, passage }; template std::string enum_names_str() { @@ -104,6 +117,35 @@ std::string enum_names_str() } return names_str; } +#ifdef DOCWIRE_LOCAL_AI +static std::shared_ptr +create_local_runner(const boost::program_options::variables_map& vm, + const std::string& default_model) +{ + if (vm.count("local-ai-model")) + { + std::string model_path = vm["local-ai-model"].as(); + if (model_path.ends_with(".gguf")) + { + #ifdef DOCWIRE_LLAMA + ai::model_inference_config config; 
+ config.model_path = model_path; + config.n_ctx = ai::context_size{4096}; + config.n_threads = ai::thread_count{4}; + return std::make_shared(config); + #else + throw std::runtime_error("GGUF model support requires the llama-engine feature"); + #endif + } + + return std::make_shared(model_path); + } + + return std::make_shared( + resource_path(default_model) + ); +} +#endif int main(int argc, char* argv[]) { @@ -117,12 +159,12 @@ int main(int argc, char* argv[]) po::options_description desc("Allowed options"); desc.add_options() ("help", "display help message") - ("version", "display DocWire version") + ("version", "display DocWire version") ("input-file", po::value()->required(), "path to file to process") ("output_type", po::value()->default_value(OutputType::plain_text), enum_names_str().c_str()) ("http-post", po::value(), "url to process data via http post") ("local-ai-prompt", po::value(), "prompt to process text via local AI model") - ("local-ai-embed", po::value()->implicit_value(""), "generate embedding of text via local AI model. Optional argument is a prefix (e.g. \"passage: \" or \"query: \").") + ("local-ai-embed", po::value()->implicit_value(EmbedPrefixType::none), "generate embedding of text via local AI model. Optional argument selects the prefix type: (e.g. \"passage: \" or \"query: \" or \"none: \").") ("local-ai-model", po::value(), "path to local AI model data (build-in default model is used if not specified)") ("openai-chat", po::value(), "prompt to process text and images via OpenAI") ("openai-extract-entities", "extract entities from text and images via OpenAI") @@ -378,19 +420,16 @@ int main(int argc, char* argv[]) vm.count("openai-temperature") ? vm["openai-temperature"].as() : 0, image_detail); } - + #ifdef DOCWIRE_LOCAL_CT2 if (vm.count("local-ai-prompt")) { try { std::string prompt = vm["local-ai-prompt"].as(); - auto model_runner = vm.count("local-ai-model") ? 
- std::make_shared(vm["local-ai-model"].as()) : - std::make_shared(resource_path("flan-t5-large-ct2-int8")); - + auto runner = create_local_runner(vm, "flan-t5-large-ct2-int8"); chain |= - local_ai::model_chain_element(prompt, model_runner); + ai::local::task(prompt, runner); } catch(const std::exception& e) { @@ -403,12 +442,15 @@ int main(int argc, char* argv[]) { try { - std::string prefix = vm["local-ai-embed"].as(); - auto model_runner = vm.count("local-ai-model") ? - std::make_shared(vm["local-ai-model"].as()) : - std::make_shared(resource_path("multilingual-e5-small-ct2-int8")); - - chain |= local_ai::embed(model_runner, prefix); + EmbedPrefixType prefix_type = vm["local-ai-embed"].as(); + if (prefix_type == EmbedPrefixType::query) + { + chain |= ai::local::query::embedder(); + } else if (prefix_type == EmbedPrefixType::passage) { + chain |= ai::local::passage::embedder(); + } else { + chain |= ai::local::passage::embedder(); + } chain |= [](message_ptr msg, const message_callbacks& emit_message) -> continuation { if (msg->is()) { @@ -432,6 +474,16 @@ int main(int argc, char* argv[]) return 1; } } + #else + if (vm.count("local-ai-prompt") || vm.count("local-ai-embed")) + { + std::cerr << "Error: Local AI features requested, but this build does not include " + "DOCWIRE_LOCAL_CT2 support.\n" + "Rebuild with DOCWIRE_LOCAL_CT2 enabled to use --local-ai-prompt or " + "--local-ai-embed." 
<< std::endl; + return 1; + } + #endif if (vm.count("openai-find")) { diff --git a/src/docwire.h b/src/docwire.h index 5cf3de964..a8b0508d6 100644 --- a/src/docwire.h +++ b/src/docwire.h @@ -13,6 +13,25 @@ #define DOCWIRE_DOCWIRE_H // IWYU pragma: begin_exports +#include "ai_runner.h" +#include "ai_summarize.h" +#include "ai_translate.h" +#include "ai_embed.h" +#include "ai_task.h" +#include "model_chain_element.h" +#ifdef DOCWIRE_LOCAL_CT2 +#include "ct2_runner.h" +#endif +#ifdef DOCWIRE_LLAMA +#include "llama_runner.h" +#include "model_inference_config.h" +#endif +#ifdef DOCWIRE_LOCAL_AI +#include "local_ai_summarize.h" +#include "local_ai_translate.h" +#include "local_ai_embed.h" +#include "local_ai_task.h" +#endif #include "ai_elements.h" #include "classify.h" #include "concepts.h" @@ -33,12 +52,10 @@ #include "find.h" #include "fuzzy_match.h" #include "input.h" -#include "local_ai_embed.h" #include "log.h" #include "output.h" #include "mail_elements.h" #include "mail_parser.h" -#include "model_chain_element.h" #include "ocr_parser.h" #include "office_formats_parser.h" #include "plain_text_exporter.h" diff --git a/src/llama_handler.h b/src/llama_handler.h new file mode 100644 index 000000000..63ad1ef00 --- /dev/null +++ b/src/llama_handler.h @@ -0,0 +1,96 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge + * Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost + * efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and + * confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_LLAMA_HANDLER_H +#define DOCWIRE_AI_LLAMA_HANDLER_H + +#include +#include + +namespace docwire::ai::llama { +/* + * @brief A generic deleter to delete various Llama objects created on Heap memory; + */ +template struct llama_deleter; + +// handler for llama_model +template <> struct llama_deleter { + void operator()(llama_model *ptr) const noexcept { + if (ptr) + llama_model_free(ptr); + } +}; + +// handler for llama_context +template <> struct llama_deleter { + void operator()(llama_context *ptr) const noexcept { + if (ptr) + llama_free(ptr); + } +}; + +// handler for llama_sampler +template <> struct llama_deleter { + void operator()(llama_sampler *ptr) const noexcept { + if (ptr) + llama_sampler_free(ptr); + } +}; + +// all future handlers come below + +/** + * @brief Takes in llama model related heap initializers and attaches + * them to their respective deleters defined above. 
+ * + * Responsibilities: + * - Passes refernces for llama heap objects to llama_deleter + * - resets and release them as needed + * - get() provides value as needed + */ +template class llama_handle { +public: + using pointer = T *; + + llama_handle() noexcept = default; + + explicit llama_handle(pointer ptr) noexcept : ptr_(ptr) {} + + ~llama_handle() = default; + + llama_handle(llama_handle &&) noexcept = default; + llama_handle &operator=(llama_handle &&) noexcept = default; + + llama_handle(const llama_handle &) = delete; + llama_handle &operator=(const llama_handle &) = delete; + + pointer get() const noexcept { return ptr_.get(); } + + pointer release() noexcept { return ptr_.release(); } + + void reset(pointer p = nullptr) noexcept { ptr_.reset(p); } + + explicit operator bool() const noexcept { return static_cast(ptr_); } + + pointer operator->() const noexcept { return ptr_.get(); } + + T &operator*() const noexcept { return *ptr_; } + +private: + std::unique_ptr> ptr_; +}; + +} // namespace docwire::ai::llama + +#endif diff --git a/src/llama_runner.cpp b/src/llama_runner.cpp new file mode 100644 index 000000000..c1bfc8e04 --- /dev/null +++ b/src/llama_runner.cpp @@ -0,0 +1,411 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#include "llama_runner.h" +#include "error_tags.h" +#include "llama_handler.h" +#include "throw_if.h" +#include +#include +#include +#include +#include + +namespace docwire +{ + +namespace +{ + +std::mutex llama_backend_mutex; +std::condition_variable llama_backend_cv; +std::size_t runner_count = 0; +std::size_t active_calls = 0; +bool g_verbose = false; +/** + * @brief Manages global lifetime of llama.cpp backend. + * + * llama.cpp requires explicit global initialization and teardown via: + * - llama_backend_init() + * - llama_backend_free() + * + * This guard implements a reference-counted lifetime model: + * + * - The first live llama_runner initializes the backend. + * - The last destroyed llama_runner frees the backend. + * + * Thread Safety: + * - Protected by a global mutex. + * - Teardown waits for all active inference calls to complete. + * + * This prevents undefined behavior if a runner is destroyed while + * another thread is still performing inference. 
+ */ +struct llama_backend_guard +{ + llama_backend_guard() + { + std::lock_guard lock(llama_backend_mutex); + if (runner_count++ == 0) { + llama_backend_init(); + } + } + + ~llama_backend_guard() + { + std::unique_lock lock(llama_backend_mutex); + if (--runner_count == 0) { + llama_backend_cv.wait(lock, [] { return active_calls == 0; }); + llama_backend_free(); + } + } + + static void acquire_call() + { + std::lock_guard lock(llama_backend_mutex); + ++active_calls; + } + + static void release_call() + { + std::lock_guard lock(llama_backend_mutex); + if (--active_calls == 0 && runner_count == 0) { + llama_backend_cv.notify_all(); + } + } +}; +/** + * @brief Tracks active inference calls. + * + * Each call to llama_runner::process() creates a llama_call_guard. + * + * Responsibilities: + * - Increments the global active_calls counter on entry. + * - Decrements it on exit. + * + * This ensures that backend teardown is deferred until all + * in-flight llama_decode() calls have completed. + * + * Used together with llama_backend_guard to provide safe + * concurrent inference and deterministic backend shutdown. 
+ */ + +struct llama_call_guard +{ + llama_call_guard() { llama_backend_guard::acquire_call(); } + ~llama_call_guard() { llama_backend_guard::release_call(); } +}; + +} // anonymous namespace + +template <> struct pimpl_impl : pimpl_impl_base +{ + std::mutex model_mutex; + llama_backend_guard llama_backend; + ai::model_inference_config config; + ai::llama::llama_handle model; + ai::llama::llama_handle ctx; + ai::llama::llama_handle sampler; + const llama_vocab* vocab = nullptr; + + static void llamaLogCallback(ggml_log_level level, const char* text, void* /*user*/) + { + if (g_verbose || level == GGML_LOG_LEVEL_ERROR) { + std::cerr << text; + } + } + pimpl_impl(const ai::model_inference_config& cfg) : config(cfg) + { + g_verbose = config.verbose; + // Redirect llama.cpp's logs through our callback + llama_log_set(llamaLogCallback, nullptr); + } + + void ensure_model_loaded() + { + std::lock_guard lock(model_mutex); + if (model) + return; + + llama_model_params model_params = llama_model_default_params(); + + model = docwire::ai::llama::llama_handle( + llama_model_load_from_file(config.model_path.c_str(), model_params)); + + throw_if(!model, "Failed to load llama model.", errors::program_corrupted{}); + vocab = llama_model_get_vocab(model.get()); + + llama_context_params ctx_params = llama_context_default_params(); + + ctx_params.n_ctx = config.n_ctx.get(); + ctx_params.n_batch = config.n_batch.get(); + ctx_params.n_threads = config.n_threads.get(); + ctx_params.embeddings = true; + + ctx = docwire::ai::llama::llama_handle(llama_init_from_model(model.get(), ctx_params)); + + throw_if(!ctx, "Failed to create llama context.", errors::program_corrupted{}); + + llama_sampler_chain_params sp = llama_sampler_chain_default_params(); + + sampler = docwire::ai::llama::llama_handle(llama_sampler_chain_init(sp)); + + throw_if(!sampler, "Failed to create sampler.", errors::program_corrupted{}); + + llama_sampler_chain_add(sampler.get(), + 
llama_sampler_init_min_p(config.min_probability.get(), 1)); + + llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(config.temp.get())); + + llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + if (!config.grammar.empty()) { + const llama_vocab* vocab = llama_model_get_vocab(model.get()); + llama_sampler_chain_add(sampler.get(), + llama_sampler_init_grammar(vocab, config.grammar.c_str(), + config.grammar_root.c_str())); + } + } + + void reset() + { + // Get the memory handle first + llama_memory_t mem = llama_get_memory(ctx.get()); + // Then clear all sequences + llama_memory_seq_rm(mem, -1, -1, -1); + llama_sampler_reset(sampler.get()); + } + + void llama_unload() + { + std::lock_guard lock(model_mutex); + sampler.reset(); + ctx.reset(); + model.reset(); + } + + /** + * @brief Builds the chat-template prompt by combining + * system prompt from config and user prompt + * @param user_input prompt given by user for a certain task + */ + std::string build_prompt(const std::string& user_input) const + { + std::vector messages = {{"system", config.system_prompt.c_str()}, + {"user", user_input.c_str()}}; + + std::string prompt; + const char* tmpl = llama_model_chat_template(model.get(), nullptr); + + if (tmpl) + { + int32_t req = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, nullptr, 0); + throw_if(req <= 0, "Template size query failed", errors::program_logic{}); + + std::vector formatted(req + 1, '\0'); + int32_t written = llama_chat_apply_template(tmpl, messages.data(), messages.size(), + true, formatted.data(), formatted.size()); + throw_if(written <= 0, "Template formatting failed", errors::program_logic{}); + + prompt.assign(formatted.data(), written); + } + else + { + // Fallback ChatML template + prompt = "<|im_start|>system\n" + config.system_prompt + + "\n" + "<|im_end|>\n" + "<|im_start|>user\n" + + user_input + + "\n" + "<|im_end|>\n" + "<|im_start|>assistant\n"; + } + return prompt; + } + 
+ /** + * @brief Tokenizes entire prompt and return the token vector. + * @param prompt + */ + std::vector tokenize(const std::string& prompt) const + { + int n_tokens = llama_tokenize(vocab, prompt.c_str(), static_cast(prompt.size()), + nullptr, 0, false, true); + if (n_tokens < 0) + n_tokens = -n_tokens; + + throw_if(n_tokens == 0, "Empty tokenization result", errors::program_logic{}); + + std::vector tokens(n_tokens); + int written = llama_tokenize(vocab, prompt.c_str(), static_cast(prompt.size()), + tokens.data(), tokens.size(), false, true); + + throw_if(written != n_tokens, "Tokenization mismatch.", errors::program_logic{}); + throw_if(static_cast(n_tokens) > config.n_ctx.get(), + "Input exceeds context window.", errors::program_logic{}); + + return tokens; + } + + + /** + * @brief This function feeds tokens into the context in batches + * @param tokens + */ + void decode_prompt(const std::vector& tokens) + { + const int32_t n_batch = static_cast(config.n_batch.get()); + llama_pos pos = 0; + + for (size_t start = 0; start < tokens.size(); start += n_batch) { + int32_t len = std::min(n_batch, static_cast(tokens.size() - start)); + + llama_batch batch = + llama_batch_get_one(const_cast(tokens.data() + start), len); + + throw_if(llama_decode(ctx.get(), batch) != 0, "Initial decode failed", + errors::program_logic{}); + + pos += len; + } + } + + /** + * @brief This function generates response from the model and returns + */ + std::string generate() + { + std::string output; + const int max_tokens = static_cast(config.max_tokens.get()); + + for (int i = 0; i < max_tokens; ++i) { + llama_token token = llama_sampler_sample(sampler.get(), ctx.get(), -1); + llama_sampler_accept(sampler.get(), token); + + if (llama_vocab_is_eog(vocab, token)) + break; + + char buf[256]; + int n = llama_token_to_piece(vocab, token, buf, sizeof(buf), 0, true); + if (n > 0) { + output.append(buf, n); + } + + llama_batch batch = llama_batch_get_one(&token, 1); + + if 
(llama_decode(ctx.get(), batch) != 0) + break; + } + return output; + } + + /** + * @brief This function provides abstraction for actual process function + */ + std::string process(const std::string& user_input) + { + ensure_model_loaded(); + reset(); + + std::string prompt = build_prompt(user_input); + auto tokens = tokenize(prompt); + + decode_prompt(tokens); + return generate(); + } +}; + +namespace ai::llama +{ +llama_runner::llama_runner(const model_inference_config& config) : with_pimpl(config) {} + +void llama_runner::unload() { impl().llama_unload(); } + +/* + * This function runs inference on the given model provided to Llama + */ +std::string llama_runner::process(const std::string& input) +{ + llama_call_guard guard; + return impl().process(input); +} + +/** + * @brief Generates an embedding vector for the given input string. + */ +std::vector llama_runner::embed(const std::string& input) +{ + llama_call_guard guard; + auto& impl = this->impl(); + + impl.ensure_model_loaded(); + impl.reset(); + + throw_if(llama_model_n_embd(impl.model.get()) <= 0, "Model has no embedding dimension.", + errors::program_logic{}); + + const llama_vocab* vocab = llama_model_get_vocab(impl.model.get()); + + int n_tokens = llama_tokenize(vocab, input.c_str(), input.length(), nullptr, 0, true, false); + + throw_if(n_tokens <= 0, "Cannot embed empty input.", errors::program_logic{}); + + std::vector tokens(n_tokens); + + int written = llama_tokenize(vocab, input.c_str(), input.length(), tokens.data(), tokens.size(), + true, false); + + throw_if(written != n_tokens, "Tokenization mismatch.", errors::program_logic{}); + + llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); + + throw_if(llama_decode(impl.ctx.get(), batch) != 0, "Decode failed during embedding.", + errors::program_logic{}); + + const float* all_embeddings = llama_get_embeddings(impl.ctx.get()); + + throw_if(!all_embeddings, "Embeddings not available from model.", errors::program_logic{}); + + 
const int n_embd = llama_model_n_embd(impl.model.get()); + + std::vector result(n_embd, 0.0); + + // Mean pooling + for (int t = 0; t < n_tokens; ++t) { + const float* token_emb = all_embeddings + t * n_embd; + + for (int i = 0; i < n_embd; ++i) { + result[i] += static_cast(token_emb[i]); + } + } + + for (double& v : result) { + v /= static_cast(n_tokens); + } + + // L2 normalization + double norm = 0.0; + for (double v : result) + norm += v * v; + + norm = std::sqrt(norm); + + if (norm > 1e-6) { + for (double& v : result) + v /= norm; + } + + return result; +} + +} // namespace ai::llama + +} // namespace docwire diff --git a/src/llama_runner.h b/src/llama_runner.h new file mode 100644 index 000000000..0aac39c93 --- /dev/null +++ b/src/llama_runner.h @@ -0,0 +1,47 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge + * Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost + * efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and + * confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_LLAMA_RUNNER_H +#define DOCWIRE_AI_LLAMA_RUNNER_H + +#include "ai_runner.h" +#include "ai_llama_export.h" +#include "model_inference_config.h" +#include "pimpl.h" + +namespace docwire::ai::llama +{ +/** + * @brief This class is intended to load a Llama model with its correct model path and + * respective configuration and run inference on the prompt supplied along with + * the model configuration. + */ +class DOCWIRE_AI_LLAMA_EXPORT llama_runner : public ai_runner, public with_pimpl +{ + private: + bool supports_embeddings() const; + + public: + explicit llama_runner(const model_inference_config& config); + + std::string process(const std::string& input) override; + + std::vector embed(const std::string&) override; + + virtual void unload() override; +}; + +} // namespace docwire::ai::llama + +#endif diff --git a/src/local_ai.cmake b/src/local_ai.cmake index 5deaa35a1..27712aeb9 100644 --- a/src/local_ai.cmake +++ b/src/local_ai.cmake @@ -1,32 +1,46 @@ -add_library(docwire_local_ai SHARED local_ai_embed.cpp model_chain_element.cpp model_runner.cpp tokenizer.cpp) -find_package(Boost REQUIRED COMPONENTS filesystem system json) -find_package(ctranslate2 CONFIG REQUIRED) -find_library(sentencepiece_LIBRARIES sentencepiece REQUIRED) -if(MSVC) - find_package(absl CONFIG REQUIRED) - list(APPEND sentencepiece_LIBRARIES - absl::strings - absl::flags - absl::flags_parse - absl::log - absl::check) - find_package(protobuf CONFIG REQUIRED) - list(APPEND sentencepiece_LIBRARIES protobuf::libprotobuf-lite) +add_library(docwire_local_ai 
SHARED + local_ai_summarize.cpp + local_ai_embed.cpp + local_ai_translate.cpp + local_ai_task.cpp +) +target_link_libraries(docwire_local_ai PUBLIC docwire_ai) +target_compile_definitions(docwire_local_ai PUBLIC DOCWIRE_LOCAL_AI) + +if(DOCWIRE_LLAMA) + target_link_libraries(docwire_local_ai PUBLIC docwire_ai_llama) + target_compile_definitions(docwire_local_ai PUBLIC DOCWIRE_LLAMA) endif() -target_link_libraries(docwire_local_ai PRIVATE docwire_core docwire_ai Boost::filesystem Boost::json CTranslate2::ctranslate2 ${sentencepiece_LIBRARIES}) -install(TARGETS docwire_local_ai EXPORT docwire_targets) -if(MSVC) - install(FILES $ DESTINATION bin CONFIGURATIONS Debug) +if(DOCWIRE_LOCAL_CT2) + target_link_libraries(docwire_local_ai PUBLIC docwire_ai_ct2) + target_compile_definitions(docwire_local_ai PUBLIC DOCWIRE_LOCAL_CT2) endif() +target_include_directories(docwire_local_ai + PUBLIC + $ + $ +) + +target_compile_features(docwire_local_ai PUBLIC cxx_std_20) + + include(GenerateExportHeader) + generate_export_header(docwire_local_ai EXPORT_FILE_NAME local_ai_export.h) -install(FILES ${CMAKE_CURRENT_BINARY_DIR}/local_ai_export.h DESTINATION include/docwire) -docwire_find_resource(FLAN_T5_FULL_PATH REL_PATH "flan-t5-large-ct2-int8" REQUIRED) -docwire_target_resources(docwire_local_ai "flan-t5-large-ct2-int8" SOURCE "${FLAN_T5_FULL_PATH}") +target_include_directories(docwire_local_ai PUBLIC $) + +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/local_ai_export.h + DESTINATION include/docwire +) + +install(TARGETS docwire_local_ai EXPORT docwire_targets) -docwire_find_resource(E5_MODEL_FULL_PATH REL_PATH "multilingual-e5-small-ct2-int8" REQUIRED) -docwire_target_resources(docwire_local_ai "multilingual-e5-small-ct2-int8" SOURCE "${E5_MODEL_FULL_PATH}") +if(MSVC) + install(FILES $ + DESTINATION bin CONFIGURATIONS Debug) +endif() diff --git a/src/local_ai_embed.cpp b/src/local_ai_embed.cpp index c4153d52d..0b733d8e9 100644 --- a/src/local_ai_embed.cpp +++ b/src/local_ai_embed.cpp 
@@ -10,59 +10,34 @@ /*********************************************************************************************************************************************/ #include "local_ai_embed.h" - -#include "ai_elements.h" -#include "data_source.h" -#include "error_tags.h" -#include "log_scope.h" +#include "ct2_runner.h" #include "resource_path.h" -#include "serialization_message.h" // IWYU pragma: keep #include -#include "throw_if.h" -#include - -namespace docwire -{ -template<> -struct pimpl_impl : pimpl_impl_base +namespace { - std::shared_ptr m_model_runner; - std::string m_prefix; - pimpl_impl(std::shared_ptr model_runner, std::string prefix) - : m_model_runner(std::move(model_runner)), m_prefix(std::move(prefix)) - {} -}; +constexpr std::string_view default_passage_prefix = "passage: "; +constexpr std::string_view default_query_prefix = "query: "; -namespace local_ai +std::shared_ptr make_default_runner() { + return std::make_shared( + docwire::resource_path("multilingual-e5-small-ct2-int8")); +} -const std::string embed::e5_passage_prefix = "passage: "; -const std::string embed::e5_query_prefix = "query: "; - -embed::embed(std::shared_ptr model_runner, std::string prefix) - : with_pimpl(std::move(model_runner), std::move(prefix)) -{} +} // anonymous namespace -embed::embed(std::string prefix) - : with_pimpl(std::make_shared(resource_path("multilingual-e5-small-ct2-int8")), std::move(prefix)) +namespace docwire::ai::local::passage +{ +embedder::embedder() + : docwire::ai::embed(make_default_runner(), std::string{default_passage_prefix}) {} +} // namespace docwire::ai::local::passage -continuation embed::operator()(message_ptr msg, const message_callbacks& emit_message) +namespace docwire::ai::local::query { - log_scope(msg); - if (!msg->is()) - return emit_message(std::move(msg)); - - const data_source& data = msg->get(); - throw_if(!data.has_highest_confidence_mime_type_in({mime_type{"text/plain"}}), "Input for local_ai::embed must be text/plain", 
errors::program_logic{}); - std::string data_str = data.string(); - - std::string prefixed_input = impl().m_prefix + data_str; - std::vector embedding_vector = impl().m_model_runner->embed(prefixed_input); - return emit_message(ai::embedding{std::move(embedding_vector)}); -} - -} // namespace local_ai -} // namespace docwire +embedder::embedder() + : docwire::ai::embed(make_default_runner(), std::string{default_query_prefix}) +{} +} // namespace docwire::ai::local::query diff --git a/src/local_ai_embed.h b/src/local_ai_embed.h index 4dbbaafa6..3cf2415d0 100644 --- a/src/local_ai_embed.h +++ b/src/local_ai_embed.h @@ -12,55 +12,39 @@ #ifndef DOCWIRE_LOCAL_AI_EMBED_H #define DOCWIRE_LOCAL_AI_EMBED_H -#include "chain_element.h" +#include "ai_embed.h" #include "local_ai_export.h" -#include "model_runner.h" -#include "pimpl.h" -#include -namespace docwire::local_ai +namespace docwire::ai::local::passage { /** - * @brief A chain element that generates embeddings for input text using a local AI model. - * - * This class is a chain element that takes a model_runner to generate a vector - * embedding for a given text. It is designed to work with sentence-transformer - * models like `multilingual-e5-small`. + * @brief Embeds a passage (document chunk) using the local AI model's default passage prefix. + * The appropriate prefix for the underlying model (e.g. "passage: " for multilingual-e5-small) + * is applied automatically. No model-specific knowledge required at the call site. */ -class DOCWIRE_LOCAL_AI_EXPORT embed : public ChainElement, public with_pimpl +class DOCWIRE_LOCAL_AI_EXPORT embedder : public docwire::ai::embed { -public: - /// Common prefix for passage embeddings with E5 models. - static const std::string e5_passage_prefix; - /// Common prefix for query embeddings with E5 models. - static const std::string e5_query_prefix; - - /** - * @brief Construct a local AI embed chain element with a specific model runner and prefix. 
- * - * @param model_runner The model runner to use for generating embeddings. - * @param prefix The string to prepend to the input text. Use an empty string for no prefix. - */ - explicit embed(std::shared_ptr model_runner, std::string prefix); - - /** - * @brief Construct a local AI embed chain element with a default model runner and prefix. - * - * This constructor initializes the embedder with a default `model_runner` - * configured to use the `multilingual-e5-small-ct2-int8` model. - * @param prefix The string to prepend to the input text. Use an empty string for no prefix. - */ - explicit embed(std::string prefix); - - continuation operator()(message_ptr msg, const message_callbacks& emit_message) override; + public: + embedder(); +}; +} // namespace docwire::ai::local::passage - bool is_leaf() const override { return false; } +namespace docwire::ai::local::query +{ +/** + * @brief Embeds a search query (search input) using the local AI model's default query prefix. + * + * The appropriate prefix for the underlying model (e.g. "query: " for multilingual-e5-small) + * is applied automatically. No model-specific knowledge required at the call site. -private: - using with_pimpl::impl; + */ +class DOCWIRE_LOCAL_AI_EXPORT embedder : public docwire::ai::embed +{ + public: + embedder(); }; -} // namespace docwire::local_ai +} // namespace docwire::ai::local::query #endif // DOCWIRE_LOCAL_AI_EMBED_H diff --git a/src/local_ai_summarize.cpp b/src/local_ai_summarize.cpp new file mode 100644 index 000000000..9271fb1dd --- /dev/null +++ b/src/local_ai_summarize.cpp @@ -0,0 +1,28 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. 
Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#include "ct2_runner.h" +#include "resource_path.h" +#include "local_ai_summarize.h" + +namespace docwire::ai::local +{ + +summarize::summarize(model_lifetime_policy lifetime) + : docwire::ai::summarize( + std::make_shared(resource_path("flan-t5-large-ct2-int8")), lifetime) +{} + +summarize::summarize(std::shared_ptr runner, model_lifetime_policy lifetime) + : docwire::ai::summarize(runner, lifetime) +{} + +} diff --git a/src/local_ai_summarize.h b/src/local_ai_summarize.h new file mode 100644 index 000000000..290da255f --- /dev/null +++ b/src/local_ai_summarize.h @@ -0,0 +1,30 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_LOCAL_AI_SUMMARIZE_H +#define DOCWIRE_LOCAL_AI_SUMMARIZE_H + +#include "local_ai_export.h" +#include "ai_summarize.h" + +namespace docwire::ai::local +{ + +class DOCWIRE_LOCAL_AI_EXPORT summarize : public docwire::ai::summarize +{ +public: + summarize(model_lifetime_policy lifetime = model_lifetime_policy::persistent); + explicit summarize(std::shared_ptr runner, model_lifetime_policy lifetime = model_lifetime_policy::persistent); +}; + +} // namespace docwire::ai::local + +#endif // DOCWIRE_LOCAL_AI_SUMMARIZE_H diff --git a/src/local_ai_task.cpp b/src/local_ai_task.cpp new file mode 100644 index 000000000..0fe42f20f --- /dev/null +++ b/src/local_ai_task.cpp @@ -0,0 +1,30 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#include "local_ai_task.h" +#include "ct2_runner.h" +#include "resource_path.h" + +namespace docwire::ai::local +{ + +task::task(const std::string& prompt, model_lifetime_policy lifetime) + : docwire::ai::task(prompt, std::make_shared( + resource_path("flan-t5-large-ct2-int8")), lifetime) +{ +} + +task::task(const std::string& prompt, std::shared_ptr runner, model_lifetime_policy lifetime) + : docwire::ai::task(prompt, runner, lifetime) +{ +} + +} // namespace docwire::ai::local diff --git a/src/local_ai_task.h b/src/local_ai_task.h new file mode 100644 index 000000000..0478d964b --- /dev/null +++ b/src/local_ai_task.h @@ -0,0 +1,30 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_LOCAL_AI_TASK_H +#define DOCWIRE_LOCAL_AI_TASK_H + +#include "local_ai_export.h" +#include "ai_task.h" + +namespace docwire::ai::local +{ + +class DOCWIRE_LOCAL_AI_EXPORT task : public docwire::ai::task +{ +public: + explicit task(const std::string& prompt, model_lifetime_policy lifetime = model_lifetime_policy::persistent); + explicit task(const std::string& prompt, std::shared_ptr runner, model_lifetime_policy lifetime = model_lifetime_policy::persistent); +}; + +} // namespace docwire::ai::local + +#endif // DOCWIRE_LOCAL_AI_TASK_H diff --git a/src/local_ai_translate.cpp b/src/local_ai_translate.cpp new file mode 100644 index 000000000..2fb3cd1f8 --- /dev/null +++ b/src/local_ai_translate.cpp @@ -0,0 +1,28 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & + * Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text + * extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ +#include "local_ai_translate.h" +#include "ct2_runner.h" +#include "resource_path.h" + +namespace docwire::ai::local +{ + +translate::translate(const std::string& language) + : docwire::ai::translate(language, + std::make_shared(resource_path("flan-t5-large-ct2-int8"))) +{} + +translate::translate(const std::string& language, std::shared_ptr runner) + : docwire::ai::translate(language, runner) +{} +} // namespace docwire::ai::local diff --git a/src/local_ai_translate.h b/src/local_ai_translate.h new file mode 100644 index 000000000..02d5c31f3 --- /dev/null +++ b/src/local_ai_translate.h @@ -0,0 +1,32 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & + * Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text + * extraction, web data extraction, data mining, */ +/* document analysis. 
Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_TRANSLATE_H +#define DOCWIRE_TRANSLATE_H + +#include "ai_translate.h" +#include "local_ai_export.h" + +namespace docwire::ai::local +{ + +class DOCWIRE_LOCAL_AI_EXPORT translate : public docwire::ai::translate +{ + public: + translate(const std::string& language); + explicit translate(const std::string& language, std::shared_ptr runner); +}; + +} // namespace docwire::ai::local + +#endif // DOCWIRE_TRANSLATE_H diff --git a/src/model_chain_element.cpp b/src/model_chain_element.cpp index 8a8daf071..3a01c8301 100644 --- a/src/model_chain_element.cpp +++ b/src/model_chain_element.cpp @@ -10,30 +10,45 @@ /*********************************************************************************************************************************************/ #include "model_chain_element.h" - #include "data_source.h" #include "error_tags.h" #include "resource_path.h" #include "throw_if.h" -namespace docwire::local_ai +namespace docwire::ai { -model_chain_element::model_chain_element(const std::string& prompt) - : docwire::local_ai::model_chain_element(prompt, std::make_shared(resource_path("flan-t5-large-ct2-int8"))) -{} +// model_chain_element::model_chain_element(const std::string& prompt, model_lifetime_policy lifetime) +// : docwire::local_ai::model_chain_element( +// prompt, std::make_shared(resource_path("flan-t5-large-ct2-int8")), +// lifetime) +// { +// } + +/** + * @brief constructor to run llama models + */ +model_chain_element::model_chain_element(const std::string& prompt, + std::shared_ptr runner, + model_lifetime_policy lifetime) + : 
/**
 * @brief Run the prompt plus the incoming text through the model runner.
 *
 * Messages that are not data sources are forwarded unchanged. For a data
 * source, the prompt and the text (separated by a newline) are handed to the
 * runner and its output is emitted as a new data source.
 *
 * With model_lifetime_policy::unload_after_use the runner is unloaded once
 * inference has finished, including when inference fails, so a throwing
 * process() call cannot leave the model resident.
 *
 * @param msg          The input message to process.
 * @param emit_message Callback used to emit derived messages downstream.
 * @return The continuation returned by @p emit_message.
 * @throws errors::program_logic (via throw_if) if the input is not text/plain
 *         or the element was constructed without a runner.
 */
continuation model_chain_element::operator()(message_ptr msg, const message_callbacks& emit_message)
{
    if (!msg->is<data_source>())
        return emit_message(std::move(msg));

    const data_source& data = msg->get<data_source>();
    throw_if(!data.has_highest_confidence_mime_type_in({mime_type{"text/plain"}}),
             errors::program_logic{});
    // Fail with a diagnosable error instead of dereferencing a null pointer.
    throw_if(!m_model_runner, "model_chain_element requires a non-null runner",
             errors::program_logic{});

    const std::string input = m_prompt + "\n" + data.string();
    const bool unload_when_done = (m_model_lifetime == model_lifetime_policy::unload_after_use);

    std::string output;
    try
    {
        output = m_model_runner->process(input);
    }
    catch (...)
    {
        if (unload_when_done)
        {
            // Best effort only: the inference failure is the error the caller
            // needs to see, so a secondary unload failure must not replace it.
            try
            {
                m_model_runner->unload();
            }
            catch (...)
            {
            }
        }
        throw;
    }

    if (unload_when_done)
        m_model_runner->unload();

    return emit_message(data_source{std::move(output)});
}
the model in memory, which makes it availabel for next pipeline usage + unload_after_use // unloads after current usage +}; + /** * @brief A model chain element that processes input text using a model runner. * @@ -27,60 +36,58 @@ namespace docwire::local_ai * passing the text to the model runner. The output of the model runner is * then emitted as a new message_ptr object. */ -class DOCWIRE_LOCAL_AI_EXPORT model_chain_element : public ChainElement +class DOCWIRE_AI_EXPORT model_chain_element : public ChainElement { -public: - /** - * @brief Construct a model chain element. - * - * @param prompt The prompt to append to the input text. - * @param model_runner The model runner to use for processing the text. - */ - model_chain_element(const std::string& prompt, std::shared_ptr model_runner) - : m_prompt{prompt}, m_model_runner{model_runner} - {} + public: + /** + * @brief Construct a model chain element. + * + * @param prompt The prompt to append to the input text. + * @param ai_runner The model runner to use for processing the text. + * @param model_lifetime_policy Option to decide whether to unload model after usage or keep it persistent + */ + model_chain_element(const std::string& prompt, std::shared_ptr runner, model_lifetime_policy lifetime = model_lifetime_policy::persistent); - /** - * @brief Construct a model chain element with a default model runner. - * - * This constructor initializes the model chain element with a default - * `model_runner` configured to use the `flan-t5-large-ct2-int8` model. - * - * @param prompt The prompt to append to the input text. - */ - model_chain_element(const std::string& prompt); + /** + * @brief Construct a model chain element with a default model runner. + * + * This constructor initializes the model chain element with a default + * `model_runner` configured to use the `flan-t5-large-ct2-int8` model. + * + * @param prompt The prompt to append to the input text. 
+ * @param model_lifetime_policy Option to decide whether to unload model after usage or keep it persistent + */ + model_chain_element(const std::string& prompt, model_lifetime_policy lifetime = model_lifetime_policy::persistent); - /** - * @brief Process the input. - * - * If the input is not a data source, emit the input and return. If the - * input is a data source, append the prompt to the input text and then - * pass the text to the model runner. The output of the model runner is - * then emitted as a new message_ptr object. - * - * @param msg The input message to process. - * @param emit_message Callback used to emit derived messages downstream. - */ - continuation operator()(message_ptr msg, const message_callbacks& emit_message) override; + /** + * @brief Process the input. + * + * If the input is not a data source, emit the input and return. If the + * input is a data source, append the prompt to the input text and then + * pass the text to the model runner. The output of the model runner is + * then emitted as a new message_ptr object. + * + * @param msg The input message to process. + * @param emit_message Callback used to emit derived messages downstream. + */ + continuation operator()(message_ptr msg, const message_callbacks& emit_message) override; - /** - * @brief Check if the model chain element is a leaf. - * - * The model chain element is never a leaf, so this function always returns - * false. - * - * @return false Always false. - */ - bool is_leaf() const override - { - return false; - } + /** + * @brief Check if the model chain element is a leaf. + * + * The model chain element is never a leaf, so this function always returns + * false. + * + * @return false Always false. 
+ */ + bool is_leaf() const override { return false; } -private: - std::string m_prompt; - std::shared_ptr m_model_runner; + private: + std::string m_prompt; + std::shared_ptr m_model_runner; + model_lifetime_policy m_model_lifetime; }; -} // namespace docwire::local_ai +} // namespace docwire::ai #endif // DOCWIRE_AI_MODEL_CHAIN_ELEMENT_H diff --git a/src/model_inference_config.h b/src/model_inference_config.h new file mode 100644 index 000000000..eb2706caa --- /dev/null +++ b/src/model_inference_config.h @@ -0,0 +1,39 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_MODEL_INFERENCE_CONFIG_H +#define DOCWIRE_AI_MODEL_INFERENCE_CONFIG_H +#include "model_inference_config_type.h" +#include + +namespace docwire::ai +{ +/* + * @brief Handles configuration for llama model initialization and paramters + */ +struct model_inference_config +{ + std::string model_path; + context_size n_ctx{4096}; + batch_size n_batch{1024}; + thread_count n_threads{4}; + token_limit max_tokens{512}; + temperature temp{0.2f}; + min_p min_probability{0.05f}; + bool verbose = false; + std::string system_prompt = "Do NOT repeat the input. 
Answer concisely and directly."; + std::string grammar{}; + std::string grammar_root = "root"; +}; + +} // namespace docwire::ai + +#endif diff --git a/src/model_inference_config_type.h b/src/model_inference_config_type.h new file mode 100644 index 000000000..4fcb0714a --- /dev/null +++ b/src/model_inference_config_type.h @@ -0,0 +1,48 @@ +/*********************************************************************************************************************************************/ +/* DocWire SDK: Award-winning modern data processing in C++20. SourceForge Community Choice & Microsoft support. AI-driven processing. */ +/* Supports nearly 100 data formats, including email boxes and OCR. Boost efficiency in text extraction, web data extraction, data mining, */ +/* document analysis. Offline processing possible for security and confidentiality */ +/* */ +/* Copyright (c) SILVERCODERS Ltd, http://silvercoders.com */ +/* Project homepage: https://github.com/docwire/docwire */ +/* */ +/* SPDX-License-Identifier: GPL-2.0-only OR LicenseRef-DocWire-Commercial */ +/*********************************************************************************************************************************************/ + +#ifndef DOCWIRE_AI_MODEL_INFERENCE_CONFIG_TYPE_H +#define DOCWIRE_AI_MODEL_INFERENCE_CONFIG_TYPE_H + +#include "strong_type.h" +#include + +namespace docwire::ai +{ +struct context_size_tag +{ +}; +struct thread_count_tag +{ +}; +struct token_limit_tag +{ +}; +struct temperature_tag +{ +}; +struct min_p_tag +{ +}; +struct batch_size_tag +{ +}; + +using batch_size = strong_type; +using context_size = strong_type; +using thread_count = strong_type; +using token_limit = strong_type; + +using temperature = strong_type; +using min_p = strong_type; +} // namespace docwire::ai + +#endif diff --git a/src/strong_type.h b/src/strong_type.h new file mode 100644 index 000000000..c6bbb846b --- /dev/null +++ b/src/strong_type.h @@ -0,0 +1,35 @@ 
#ifndef DOCWIRE_STRONG_TYPE_H
#define DOCWIRE_STRONG_TYPE_H

#include <type_traits>
#include <utility>

namespace docwire
{
/**
 * @brief Wrapper that gives a plain value a distinct, non-interchangeable type.
 *
 * Two instantiations with the same value type but different tags are unrelated
 * types, so values with different meanings cannot be swapped by accident.
 *
 * @tparam T   The wrapped value type.
 * @tparam Tag An empty tag type used only to tell instantiations apart.
 */
template <typename T, typename Tag>
class strong_type
{
  public:
    using value_type = T;

    /**
     * @brief Wrap a value. Explicit, so raw values never convert implicitly.
     *
     * noexcept only when moving a T cannot throw; an unconditional noexcept
     * would turn a throwing move into std::terminate.
     */
    explicit constexpr strong_type(T v) noexcept(std::is_nothrow_move_constructible_v<T>)
        : value_(std::move(v))
    {
    }

    /**
     * @brief Return a copy of the wrapped value.
     *
     * noexcept only when copying a T cannot throw (e.g. arithmetic types), so
     * that a failing copy of a heavier T propagates as a normal exception.
     */
    constexpr T get() const noexcept(std::is_nothrow_copy_constructible_v<T>) { return value_; }

  private:
    T value_;
};

} // namespace docwire

#endif
GPL-2.0-only OR LicenseRef-DocWire-Commercial */ /*********************************************************************************************************************************************/ -#ifndef DOCWIRE_LOCAL_AI_TOKENIZER_H -#define DOCWIRE_LOCAL_AI_TOKENIZER_H +#ifndef DOCWIRE_AI_TOKENIZER_H +#define DOCWIRE_AI_TOKENIZER_H -#include "local_ai_export.h" +#include "ai_ct2_export.h" #include "pimpl.h" #include #include #include -namespace docwire::local_ai +namespace docwire::ai::ct2 { -class DOCWIRE_LOCAL_AI_EXPORT tokenizer : public with_pimpl +class DOCWIRE_AI_CT2_EXPORT tokenizer : public with_pimpl { public: tokenizer(const std::filesystem::path& model_data_path); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index db820a2d4..084fa1d4e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,8 +2,13 @@ find_package(GTest CONFIG REQUIRED) add_executable(api_tests api_tests.cpp) target_link_libraries(api_tests PRIVATE docwire_core docwire_office_formats docwire_mail docwire_ocr docwire_archives - docwire_fuzzy_match docwire_base64 docwire_content_type docwire_http docwire_local_ai docwire_openai + docwire_fuzzy_match docwire_base64 docwire_content_type docwire_http docwire_openai GTest::gtest GTest::gmock) +if(TARGET docwire_local_ai) + target_link_libraries(api_tests PRIVATE docwire_local_ai) +endif() + + target_compile_definitions(api_tests PRIVATE DOCWIRE_ENABLE_SHORT_MACRO_NAMES) find_package(Boost REQUIRED COMPONENTS json) @@ -178,36 +183,53 @@ while (TRUE) endwhile() -foreach(example IN ITEMS - handling_errors_and_warnings - file_type_determination - path_to_text_stream - local_embedding_similarity - filter_emails_by_subject - stream_to_html - parse_archives - local_ai_classify - openai_classify - local_ai_translate - openai_translate - local_ai_sentiment - openai_sentiment - local_ai_summary - openai_voice_summary - openai_transcribe_summary - local_ai_find - openai_find_image - openai_embedding - reuse_chain - join_transformers 
# Examples that only depend on libraries linked for every example.
set(common_examples
    handling_errors_and_warnings
    file_type_determination
    path_to_text_stream
    filter_emails_by_subject
    stream_to_html
    parse_archives
    openai_classify
    openai_translate
    openai_sentiment
    openai_voice_summary
    openai_transcribe_summary
    openai_find_image
    openai_embedding
    reuse_chain
    join_transformers
    xml_parsing_example
)

# Examples that use the ai::local:: API with its default CTranslate2-backed models.
set(local_ai_examples
    local_embedding_similarity
    local_ai_classify
    local_ai_translate
    local_ai_sentiment
    local_ai_summary
    local_ai_find
)

set(all_examples ${common_examples})

# The local examples need both the ai::local wrappers (docwire_local_ai) and the
# CTranslate2 backend (docwire_ai_ct2). Checking only the backend would add
# examples that cannot build when docwire_local_ai is not part of the build.
if(TARGET docwire_ai_ct2 AND TARGET docwire_local_ai)
    list(APPEND all_examples ${local_ai_examples})
endif()
# Adds one local-AI integration test: builds <test_name>.cpp, links it against
# the DocWire libraries it uses and registers it with the model-runner labels.
function(docwire_add_local_ai_integration_test test_name)
    add_executable(${test_name} ${test_name}.cpp)
    target_include_directories(${test_name} PRIVATE ../src)
    target_link_libraries(${test_name} PRIVATE
        docwire_core
        docwire_office_formats
        docwire_mail
        docwire_ocr
        docwire_archives
        docwire_content_type
        docwire_ai
        docwire_local_ai
        docwire_openai
        docwire_fuzzy_match
        docwire_http
    )
    add_test(NAME ${test_name} COMMAND ${test_name})
    set_property(TEST ${test_name} PROPERTY LABELS "is_example;uses_model_runner")
endfunction()

# Both integration tests link docwire_local_ai, so that target must exist in
# addition to the backend under test.
if(TARGET docwire_ai_ct2 AND TARGET docwire_local_ai)
    message(STATUS "Adding CTranslate2 integration test")
    docwire_add_local_ai_integration_test(local_ai_ct2_integration)
endif()

if(TARGET docwire_ai_llama AND TARGET docwire_local_ai)
    message(STATUS "Adding Llama integration test")
    docwire_add_local_ai_integration_test(local_ai_llama_integration)
endif()
std::istreambuf_iterator{ifs}, std::istreambuf_iterator{}}; @@ -384,7 +386,7 @@ TEST_P(HTMLWriterTest, ParseFromPathTest) office_formats_parser{} | mail_parser{} | OCRParser{} | HtmlExporter() | output_stream; - + // THEN EXPECT_EQ(expected_text, output_stream.str()); } @@ -420,7 +422,7 @@ TEST_P(PasswordProtectedTest, MajorTestingModule) // WHEN std::ostringstream output_stream{}; - try + try { std::filesystem::path{file_name} | content_type::by_file_extension::detector{} | @@ -433,7 +435,7 @@ TEST_P(PasswordProtectedTest, MajorTestingModule) { ASSERT_TRUE(errors::contains_type(ex)) << "Thrown exception diagnostic message:\n" << errors::diagnostic_message(ex); - } + } } INSTANTIATE_TEST_SUITE_P( @@ -703,7 +705,7 @@ class HttpServerTest : public ::testing::Test { http::server(addr, port, create_routes()); ScopedServer server_runner{std::move(server)}; - + std::ostringstream response_stream; std::string expected_response_body; const std::filesystem::path doc_path{"1.doc"}; @@ -733,7 +735,7 @@ class HttpServerTest : public ::testing::Test { { FAIL() << "Client pipeline threw an exception: " << errors::diagnostic_message(e); } - + EXPECT_EQ(response_stream.str(), expected_response_body); } @@ -1012,7 +1014,7 @@ std::string sanitize_expected_log_text(const std::string& orig_log_text) // to match the actual output on these platforms. static const std::regex re{R"x("function":".*?(\w+)\s*\([^"]*\)")x"}; return std::regex_replace(orig_log_text, re, R"y("function":"$1")y"); - } + } else { return orig_log_text; @@ -1163,7 +1165,7 @@ TEST(Logging, CerrLogRedirection) log::state_saver saver; log::set_sink(log::json_stream_sink(log_stream)); log::set_filter("*"); - + // Redirect cerr to a stringstream to verify that nothing is written to it. 
std::streambuf* original_cerr_buf = std::cerr.rdbuf(captured_cerr_stream.rdbuf()); @@ -1327,7 +1329,7 @@ TEST(Logging, Filtering) { log::state_saver saver; log::set_sink(log::json_stream_sink(log_stream)); - + // Filter to only include 'include_me' and 'scope_exit' tags, but exclude 'special'. log::set_filter("include_me,scope_exit,-special"); log_entry("this is excluded"); @@ -1772,7 +1774,7 @@ TEST(Input, path_ref) TEST(Input, path_temp) { - std::ostringstream output_stream{}; + std::ostringstream output_stream{}; std::filesystem::path{"1.doc"} | content_type::by_file_extension::detector{} | DOCParser{} | PlainTextExporter{} | output_stream; ASSERT_EQ(output_stream.str(), read_test_file("1.doc.out")); @@ -1790,7 +1792,7 @@ TEST(Input, vector_ref) TEST(Input, vector_temp) { - std::ostringstream output_stream{}; + std::ostringstream output_stream{}; std::string str = read_binary_file("1.doc"); std::vector{reinterpret_cast(str.data()), reinterpret_cast(str.data()) + str.size()} | content_type::by_signature::detector{} | @@ -1810,7 +1812,7 @@ TEST(Input, span_ref) TEST(Input, span_temp) { - std::ostringstream output_stream{}; + std::ostringstream output_stream{}; std::string str = read_binary_file("1.doc"); std::span{reinterpret_cast(str.data()), str.size()} | content_type::by_signature::detector{} | @@ -1829,7 +1831,7 @@ TEST(Input, string_ref) TEST(Input, string_temp) { - std::ostringstream output_stream{}; + std::ostringstream output_stream{}; read_binary_file("1.doc") | content_type::by_signature::detector{} | DOCParser{} | PlainTextExporter{} | output_stream; ASSERT_EQ(output_stream.str(), read_test_file("1.doc.out")); @@ -1847,7 +1849,7 @@ TEST(Input, string_view_ref) TEST(Input, string_view_temp) { - std::ostringstream output_stream{}; + std::ostringstream output_stream{}; std::string_view{read_binary_file("1.doc")} | content_type::by_signature::detector{} | DOCParser{} | PlainTextExporter{} | output_stream; ASSERT_EQ(output_stream.str(), 
read_test_file("1.doc.out")); @@ -2026,7 +2028,7 @@ TEST(TXTParser, paragraphs) MessagePtrWith(testing::Field(&document::Text::text, StrEq("Line"))), MessagePtrWith(_), MessagePtrWith(_) - )); + )); } TEST(HTMLParser, table) @@ -2256,7 +2258,7 @@ TEST(OCRParser, leptonica_stderr_capturer) { try { - data_source{std::string{"Incorrect image data"}, + data_source{std::string{"Incorrect image data"}, mime_type{"image/jpeg"}, confidence::highest} | OCRParser{} | std::vector{}; FAIL() << "OCRParser should have thrown an exception"; @@ -2618,10 +2620,10 @@ TEST(stringification, enums) { ASSERT_EQ(stringify(confidence::very_high), "very_high"); } - +#ifdef DOCWIRE_LOCAL_CT2 TEST(tokenizer, flan_t5) { - docwire::local_ai::tokenizer tokenizer { resource_path("flan-t5-large-ct2-int8") }; + docwire::ai::ct2::tokenizer tokenizer { resource_path("flan-t5-large-ct2-int8") }; // Test case for an empty input string. It should return only the end of sequence token. ASSERT_THAT(tokenizer.tokenize(""), @@ -2652,7 +2654,7 @@ TEST(tokenizer, multilingual_e5) { try { - docwire::local_ai::tokenizer tokenizer { resource_path("multilingual-e5-small-ct2-int8") }; + docwire::ai::ct2::tokenizer tokenizer { resource_path("multilingual-e5-small-ct2-int8") }; // Test case for an empty input string. It should return only the end of sequence token. 
ASSERT_THAT(tokenizer.tokenize(""), @@ -2683,7 +2685,7 @@ TEST(tokenizer, multilingual_e5) FAIL() << errors::diagnostic_message(e); } } - +#endif TEST(Convert, Chrono) { using namespace docwire::serialization; @@ -2891,9 +2893,9 @@ TEST(Serialization, TypedSummaryComplex) const auto& obj_pair = std::get(v_pair); ASSERT_EQ(obj_pair.v.size(), 2); // "typeid" and "value" EXPECT_EQ(std::get(obj_pair.v.at("typeid")), "std::pair"); - + const auto& nested_pair_value_obj = std::get(obj_pair.v.at("value")); // This is the object containing "first" and "second" - + // Check "first" element of the pair (std::string) const auto& first_elem_typed_summary = std::get(nested_pair_value_obj.v.at("first")); EXPECT_EQ(std::get(first_elem_typed_summary.v.at("typeid")), "std::string"); @@ -2902,9 +2904,9 @@ TEST(Serialization, TypedSummaryComplex) // Check that the nested data_source object is also typed const auto& second_elem_typed_summary = std::get(nested_pair_value_obj.v.at("second")); EXPECT_EQ(std::get(second_elem_typed_summary.v.at("typeid")), "docwire::data_source"); - + const auto& nested_ds_value_obj = std::get(second_elem_typed_summary.v.at("value")); // This is the object containing "path" and "file_extension" from the nested data_source - + // Check path within the nested data_source const auto& nested_path_typed_summary = std::get(nested_ds_value_obj.v.at("path")); EXPECT_EQ(std::get(nested_path_typed_summary.v.at("typeid")), "std::optional"); diff --git a/tests/integration_example.cmake b/tests/integration_example.cmake index aa64a8acd..e90a4637e 100644 --- a/tests/integration_example.cmake +++ b/tests/integration_example.cmake @@ -18,7 +18,7 @@ endif() # Link against the specific DocWire libraries your application needs. # docwire_core is always required. Other modules are optional based on usage. 
-target_link_libraries(integration_example PRIVATE docwire_core docwire_content_type docwire_office_formats docwire_local_ai) +target_link_libraries(integration_example PRIVATE docwire_core docwire_content_type docwire_office_formats) include(GNUInstallDirs) diff --git a/tests/integration_example.cpp b/tests/integration_example.cpp index 05cbed527..c70e44b43 100644 --- a/tests/integration_example.cpp +++ b/tests/integration_example.cpp @@ -32,7 +32,6 @@ int main(int argc, char* argv[]) content_type::detector{} | office_formats_parser() | PlainTextExporter() | - local_ai::model_chain_element("Write a short summary for this text:\n\n") | out_stream; } catch (const std::exception& e) diff --git a/tests/local_ai_classify.cpp b/tests/local_ai_classify.cpp index 67a54bf19..443f2faaa 100644 --- a/tests/local_ai_classify.cpp +++ b/tests/local_ai_classify.cpp @@ -8,7 +8,7 @@ int main(int argc, char* argv[]) try { - std::filesystem::path("document_processing_market_trends.odt") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | local_ai::model_chain_element("Classify to one of the following categories and answer with exact category name: agreement, invoice, report, legal, user manual, other:\n\n") | out_stream; + std::filesystem::path("document_processing_market_trends.odt") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | ai::local::task("Classify to one of the following categories and answer with exact category name: agreement, invoice, report, legal, user manual, other:\n\n") | out_stream; ensure(out_stream.str()) == "report"; } catch (const std::exception& e) diff --git a/tests/local_ai_ct2_integration.cpp b/tests/local_ai_ct2_integration.cpp new file mode 100644 index 000000000..a854f65b4 --- /dev/null +++ b/tests/local_ai_ct2_integration.cpp @@ -0,0 +1,26 @@ +#include "docwire.h" +#include "local_ai_task.h" +#include +#include + +// Integration test: exports a .doc file to plain text, pipes it through ai::local::task with a +// summary prompt and prints the result. Prints diagnostics and returns 1 if the pipeline throws. +int main(int argc, char* argv[]) +{ + using namespace docwire; + std::stringstream summary; + try { + std::filesystem::path("data_processing_definition.doc") + | content_type::detector{} + | office_formats_parser() + | PlainTextExporter() + | ai::local::task("Write a short summary for this text:\n\n") + | summary; + } catch (const std::exception& ex) { + std::cerr << errors::diagnostic_message(ex) << std::endl; + return 1; + } + std::cout << summary.str() << std::endl; + + return 0; +} diff --git a/tests/local_ai_find.cpp b/tests/local_ai_find.cpp index 2d8c6d5ce..6096af78a 100644 --- a/tests/local_ai_find.cpp +++ b/tests/local_ai_find.cpp @@ -10,7 +10,7 @@ int main(int argc, char* argv[]) try { - std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | local_ai::model_chain_element("Find sentence about \"data conversion\" in the following text:\n\n") | out_stream; + std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | ai::local::task("Find sentence about \"data conversion\" in the following text:\n\n") | out_stream; ensure(out_stream.str()).is_one_of({ "Data processing refers to the activities performed on raw data to convert it into meaningful information.", "Data processing is the activities performed on raw data to convert it into meaningful information."
diff --git a/tests/local_ai_llama_integration.cpp b/tests/local_ai_llama_integration.cpp new file mode 100644 index 000000000..43b5218db --- /dev/null +++ b/tests/local_ai_llama_integration.cpp @@ -0,0 +1,37 @@ +#include "docwire.h" +#include +#include +#include + +// Integration test: runs ai::local::task with an explicitly configured model runner (GGUF model +// under ../models) on a short text. Prints diagnostics and returns 1 if any step throws. +int main(int argc, char* argv[]) +{ + using namespace docwire; + std::stringstream out_stream; + + try { + // Configure and create the runner inside the try block so that an exception thrown + // during setup is reported below instead of escaping main and terminating the process. + docwire::ai::model_inference_config config; + config.model_path = "../models/qwen2-7b-instruct-q4_k_m.gguf"; + config.max_tokens = docwire::ai::token_limit{256}; + config.n_ctx = docwire::ai::context_size{4096}; + config.n_threads = docwire::ai::thread_count{4}; + config.temp = docwire::ai::temperature{0.2f}; + config.min_probability = docwire::ai::min_p{0.05f}; + auto runner = std::make_shared(config); + + std::ofstream ofs("output.txt"); + data_source(std::string("LLMs help process long documents."), mime_type{"text/plain"}, + confidence::highest) | + ai::local::task("Summarize:\n\n", runner) | out_stream | ofs; + std::cout << "Text exported to output.txt" << std::endl; + + } catch (const std::exception& e) { + std::cerr << errors::diagnostic_message(e) << std::endl; + return 1; + } + + return 0; +} diff --git a/tests/local_ai_sentiment.cpp b/tests/local_ai_sentiment.cpp index e7c06ca0c..064df1fe5 100644 --- a/tests/local_ai_sentiment.cpp +++ b/tests/local_ai_sentiment.cpp @@ -8,7 +8,7 @@ int main(int argc, char* argv[]) try { - std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | local_ai::model_chain_element("Detect sentiment:\n\n") | out_stream; + std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | ai::local::task("Detect sentiment:\n\n") | out_stream; ensure(out_stream.str()) == "positive"; } catch (const std::exception& e) diff --git a/tests/local_ai_summary.cpp b/tests/local_ai_summary.cpp index 7607c4463..fac9c447d 100644 ---
a/tests/local_ai_summary.cpp +++ b/tests/local_ai_summary.cpp @@ -8,7 +8,7 @@ int main(int argc, char* argv[]) try { - std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | local_ai::model_chain_element("Write a short summary for this text:\n\n") | out_stream; + std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | ai::local::summarize() | out_stream; ensure(out_stream.str()).is_one_of({ "Data processing is the collection, organization, analysis, and interpretation of data to extract useful insights and support decision-making.", "Data processing is the process of transforming raw data into meaningful information." diff --git a/tests/local_ai_translate.cpp b/tests/local_ai_translate.cpp index 9c3cdc0d5..4b00f3338 100644 --- a/tests/local_ai_translate.cpp +++ b/tests/local_ai_translate.cpp @@ -9,7 +9,7 @@ int main(int argc, char* argv[]) try { - std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | local_ai::model_chain_element("Translate to spanish:\n\n") | out_stream; + std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | ai::local::translate("spanish") | out_stream; ensure(fuzzy_match::ratio(out_stream.str(), "La procesación de datos se refiere a las actividades realizadas en el ámbito de los datos en materia de información. Se trata de recoger, organizar, analizar y interpretar los datos para extraer inteligencias y apoyar el procesamiento de decisión. Esto puede incluir tareas como la etiqueta, la filtración, la summarización y la transformación de los datos a través de diversos métodos compuestos y estadounidenses. 
El procesamiento de datos es esencial en diversos ámbitos, incluyendo el negocio, la ciencia y la tecnologàa, pues permite a las empresas a extraer conocimientos valiosos de grans de datos, hacer decisiones indicadas y mejorar la eficiencia global.")) > 80; } catch (const std::exception& e) diff --git a/tests/local_embedding_similarity.cpp b/tests/local_embedding_similarity.cpp index 0628393b7..238cd4bc9 100644 --- a/tests/local_embedding_similarity.cpp +++ b/tests/local_embedding_similarity.cpp @@ -1,4 +1,5 @@ #include "docwire.h" +#include "local_ai_embed.h" #include #include #include @@ -12,7 +13,7 @@ int main(int argc, char* argv[]) { // 1. Create an embedding for the document (passage) using the default prefix std::vector passage_msgs; - std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | local_ai::embed(local_ai::embed::e5_passage_prefix) | passage_msgs; + std::filesystem::path("data_processing_definition.doc") | content_type::detector{} | office_formats_parser{} | PlainTextExporter() | ai::local::passage::embedder{} | passage_msgs; ensure(passage_msgs.size()) == 1; ensure(passage_msgs[0]->is()) == true; auto passage_embedding = passage_msgs[0]->get(); @@ -20,21 +21,21 @@ int main(int argc, char* argv[]) // 2. Create an embedding for a similar query using the query prefix std::vector similar_query_msgs; - docwire::data_source{std::string{"What is data processing?"}, mime_type{"text/plain"}, confidence::highest} | local_ai::embed(local_ai::embed::e5_query_prefix) | similar_query_msgs; + docwire::data_source{std::string{"What is data processing?"}, mime_type{"text/plain"}, confidence::highest} | ai::local::query::embedder{} | similar_query_msgs; ensure(similar_query_msgs.size()) == 1; ensure(similar_query_msgs[0]->is()) == true; auto similar_query_embedding = similar_query_msgs[0]->get(); // 3. 
Create an embedding for a partially related query std::vector partial_query_msgs; - docwire::data_source{std::string{"How can data analysis improve business efficiency?"}, mime_type{"text/plain"}, confidence::highest} | local_ai::embed(local_ai::embed::e5_query_prefix) | partial_query_msgs; + docwire::data_source{std::string{"How can data analysis improve business efficiency?"}, mime_type{"text/plain"}, confidence::highest} | ai::local::query::embedder{} | partial_query_msgs; ensure(partial_query_msgs.size()) == 1; ensure(partial_query_msgs[0]->is()) == true; auto partial_query_embedding = partial_query_msgs[0]->get(); // 4. Create an embedding for a dissimilar query std::vector dissimilar_query_msgs; - docwire::data_source{std::string{"What is the best C++ IDE?"}, mime_type{"text/plain"}, confidence::highest} | local_ai::embed(local_ai::embed::e5_query_prefix) | dissimilar_query_msgs; + docwire::data_source{std::string{"What is the best C++ IDE?"}, mime_type{"text/plain"}, confidence::highest} | ai::local::query::embedder{} | dissimilar_query_msgs; ensure(dissimilar_query_msgs.size()) == 1; ensure(dissimilar_query_msgs[0]->is()) == true; auto dissimilar_query_embedding = dissimilar_query_msgs[0]->get();