From a0e93cd5fa2115e45b7c140ea177c4fcf4a4f69e Mon Sep 17 00:00:00 2001
From: Maciej Janicki
Date: Thu, 23 Apr 2026 17:59:32 +0300
Subject: [PATCH] Configurable URL for local models + remove Aitta-specific code

This introduces new config options for setting the URL of local models:
- `llm_local_endpoint_url` (for LLMs),
- `rag_settings.local_endpoint_url` (for the RAG embedding model).

All references to Aitta have been removed from the code; Aitta can still be
used by simply setting the above URLs.

The RAG DB creation script has also been adjusted:
- removed a redundant embedding computation,
- documents are now inserted in batches for better performance.
---
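Note: both new options accept any OpenAI-compatible base URL (vLLM,
llama.cpp server, Aitta, ...). An illustrative config excerpt; the
localhost values are the defaults added to config.yml, not a requirement:

    llm_local_endpoint_url: "http://localhost:8000/v1"
    rag_settings:
      embedding_model_type: "local"
      local_endpoint_url: "http://localhost:8000/v1"

The matching API key is read from the OPENAI_API_KEY_LOCAL environment
variable.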
 api/routes/analysis.py                    |  2 +-
 config.yml                                | 24 +++----
 frontend/src/components/SettingsPanel.tsx |  2 +
 pyproject.toml                            |  1 -
 rag/db_generation.py                      | 81 ++++++++---------------
 src/climsight/climsight_engine.py         | 21 ++----
 src/climsight/embedding_utils.py          | 72 --------------------
 src/climsight/rag.py                      | 55 +++++++++------
 src/climsight/smart_agent.py              | 48 ++++----------
 src/climsight/streamlit_interface.py      |  1 -
 test/test_rag.py                          |  2 +-
 11 files changed, 96 insertions(+), 213 deletions(-)
 delete mode 100644 src/climsight/embedding_utils.py

diff --git a/api/routes/analysis.py b/api/routes/analysis.py
index 3c577b0..c52e8ba 100644
--- a/api/routes/analysis.py
+++ b/api/routes/analysis.py
@@ -278,7 +278,7 @@ def _run_analysis(
     if "llm_combine" not in config:
         config["llm_combine"] = {}
     config["llm_combine"]["model_name"] = model_name
-    if "gpt" in model_name or "o1" in model_name or "o3" in model_name:
+    if model_name.startswith("gpt") or "o1" in model_name or "o3" in model_name:
         config["llm_combine"]["model_type"] = "openai"
     else:
         config["llm_combine"]["model_type"] = "local"
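The switch from substring to prefix matching matters because the locally
served model ids added to the frontend below (e.g. 'openai/gpt-oss-120b')
contain "gpt" without being OpenAI models. A quick illustration, not part
of the patch:

    >>> name = "openai/gpt-oss-120b"
    >>> "gpt" in name, name.startswith("gpt")
    (True, False)

With the old check such a model would have been routed to the "openai"
backend; now it falls through to "local".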
diff --git a/config.yml b/config.yml
index 67a183a..6694637 100644
--- a/config.yml
+++ b/config.yml
@@ -1,13 +1,13 @@
 #model_names: gpt-4o, o1-preview, o1-mini
-#model_type: "openai" #"openai / local / aitta
+#model_type: "openai" #"openai / local
 llm_rag:
-  model_type: "openai" 
+  model_type: "openai"
   model_name: "gpt-5-mini" # used only for RAGs
 llm_smart: #used only in smart_agent
-  model_type: "openai" 
+  model_type: "openai"
   model_name: "gpt-5.2" # used only for smart agent
 llm_combine: #used only in combine_agent and intro
-  model_type: "openai" 
+  model_type: "openai"
   model_name: "gpt-5.2" # used only for combine agent ("mkchaou/climsight-calm_ft_Q3_13k")
 llm_dataanalysis: #used only in data_analysis_agent
   model_type: "openai"
@@ -20,6 +20,7 @@ use_smart_agent: true
 use_era5_data: true # Download ERA5 time series from CDS API (requires credentials)
 use_destine_data: true # Download DestinE projections via HDA API (requires DESP credentials)
 use_powerful_data_analysis: true
+llm_local_endpoint_url: "http://localhost:8000/v1"
 
 # ERA5 Climatology Configuration (pre-computed observational baseline)
 era5_climatology:
@@ -210,23 +211,22 @@ ecocrop:
   data_path: "./data/ecocrop/ecocrop_database/"
 rag_settings:
   rag_activated: True
-  # Which embedding backend to use: openai, aitta, mistral, etc.
-  embedding_model_type: "openai" # options: openai, aitta, mistral
+  # Which embedding backend to use: openai, local, mistral, etc.
+  embedding_model_type: "openai" # options: openai, local, mistral
+  local_endpoint_url: "http://localhost:8000/v1"
   # Embedding model name for each backend
   embedding_model_openai: "text-embedding-3-large"
-  embedding_model_aitta: "lightonai/modernbert-embed-large"
+  embedding_model_local: "lightonai/modernbert-embed-large"
   # Add more as needed, e.g.:
   # embedding_model_mistral: "mistral-embed-xyz"
   # Chroma DB paths for each backend
   chroma_path_ipcc_openai: "rag_db/ipcc_reports_openai"
-  chroma_path_ipcc_aitta: "rag_db/ipcc_reports_aitta"
+  chroma_path_ipcc_local: "rag_db/ipcc_reports_local"
   # chroma_path_ipcc_mistral: "rag_db/ipcc_reports_mistral"
   chroma_path_general_openai: "rag_db/general_reports_openai"
-  chroma_path_general_aitta: "rag_db/general_reports_aitta"
+  chroma_path_general_local: "rag_db/general_reports_local"
   # chroma_path_general_mistral: "rag_db/general_reports_mistral"
-  # AITTA configuration for open models (optional, only needed for aitta)
-  aitta_url: "https://api-climatedt-aitta.2.rahtiapp.fi"
-  document_path: './data/general_reports/' # or ipcc_text_reports
+  document_path: './data/ipcc_text_reports/' # or general_reports
   chunk_size: 2000
   chunk_overlap: 200
   separators: [" ", ",", "\n"]

diff --git a/frontend/src/components/SettingsPanel.tsx b/frontend/src/components/SettingsPanel.tsx
index 97f5846..2a962ba 100644
--- a/frontend/src/components/SettingsPanel.tsx
+++ b/frontend/src/components/SettingsPanel.tsx
@@ -10,6 +10,8 @@ const MODEL_OPTIONS = [
   'gpt-4.1-nano',
   'gpt-4.1-mini',
   'gpt-4.1',
+  'openai/gpt-oss-120b',
+  'meta-llama/Llama-3.3-70B-Instruct',
 ];
 
 const CLIMATE_SOURCES = [

diff --git a/pyproject.toml b/pyproject.toml
index 16f232c..df874e6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,7 +75,6 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-aitta = ["aitta-client"]
 dev = ["pytest", "flake8"]
 
 [build-system]
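Taken together, these config and dependency changes let the standard
OpenAI client stack talk to a self-hosted server with no extra packages.
A minimal sketch of the runtime pattern (assumes config.yml is loaded as
below; the model id is one of the options added to the frontend above):

    import os
    import yaml
    from langchain_openai import ChatOpenAI

    with open("config.yml") as f:
        config = yaml.safe_load(f)

    llm = ChatOpenAI(
        openai_api_base=config["llm_local_endpoint_url"],
        model_name="openai/gpt-oss-120b",
        openai_api_key=os.getenv("OPENAI_API_KEY_LOCAL"),
    )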
diff --git a/rag/db_generation.py b/rag/db_generation.py
index fd39384..30b8393 100644
--- a/rag/db_generation.py
+++ b/rag/db_generation.py
@@ -3,6 +3,8 @@
 import os
 import logging
+import tqdm
 import yaml
 import re
+from langchain_openai import OpenAIEmbeddings
@@ -20,7 +22,6 @@
 # Import the new embedding utility
 import sys
 sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
-from climsight.embedding_utils import create_embeddings
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(
@@ -141,22 +142,18 @@ def split_docs(documents, chunk_size=2000, chunk_overlap=200, separators=[" ", "
     return docs
 
 
-def chunk_and_embed_documents(document_path, embedding_model, openai_api_key, aitta_url, model_type, chunk_size=2000, chunk_overlap=200, separators=[" ", ",", "\n"]):
+def chunk_documents(document_path, chunk_size=2000, chunk_overlap=200, separators=[" ", ",", "\n"]):
     """
-    Chunks and embeds documents from the specified directory using provided embedding function.
+    Chunks documents from the specified directory.
 
     Args:
     - document_path (str): The path to the directory containing the documents.
-    - embedding_model (str): The embedding model name to use for generating embeddings.
-    - openai_api_key (str): OpenAI API key for OpenAI models.
-    - aitta_url (str): AITTA API URL for open models.
-    - model_type (str): The type of embedding model backend (e.g., 'openai', 'aitta').
    - chunk_size (int): maximum number of characters per chunk. Default: 2000.
     - chunk_overlap (int): number of characters to overlap per chunk. Default: 200.
     - separators (list): list of characters where text can be split. Default: [" ", ",", "\n"]
 
     Returns:
-    - list: A list of documents with embeddings.
+    - list: A list of chunked documents.
     """
     # load documents
     file_names = get_file_names(document_path)
@@ -167,7 +164,7 @@ def chunk_and_embed_documents(document_path, embedding_model, openai_api_key, ai
         all_documents.extend(documents)  # save all of them into one
 
     if not all_documents:
-        logger.info("No documents found for chunking and embedding.")
+        logger.info("No documents found for chunking.")
         return []
 
     # Chunk documents
@@ -180,27 +177,7 @@ def chunk_and_embed_documents(document_path, embedding_model, openai_api_key, ai
 
     logger.info(f"Chunked documents into {len(chunked_docs)} pieces.")
 
-    # Create embedding model using the utility function
-    try:
-        aitta_api_key = os.getenv('AITTA_API_KEY')
-        embedding_item = create_embeddings(
-            embedding_model=embedding_model,
-            openai_api_key=openai_api_key,
-            aitta_api_key=aitta_api_key,
-            aitta_url=aitta_url,
-            model_type=model_type
-        )
-        # embedding documents
-        embedded_docs = []
-        for doc in chunked_docs:
-            embedding = embedding_item.embed_documents([doc.page_content])[0]  # embed_documents returns a list, so we take the first element
-            embedded_docs.append({"text": doc.page_content, "embedding": embedding, "metadata": doc.metadata})
-    except Exception as e:
-        logger.error(f"Failed to embed document chunks: {e}")
-        return []
-
-    logger.info(f"Embedded {len(embedded_docs)} document chunks.")
-    return embedded_docs
+    return chunked_docs
 
 
 def initialize_rag(config):
@@ -222,9 +199,9 @@ def initialize_rag(config):
     if embedding_model_type == 'openai':
         embedding_model = rag_settings.get('embedding_model_openai')
         chroma_path = rag_settings.get('chroma_path_ipcc_openai')
-    elif embedding_model_type == 'aitta':
-        embedding_model = rag_settings.get('embedding_model_aitta')
-        chroma_path = rag_settings.get('chroma_path_ipcc_aitta')
+    elif embedding_model_type == 'local':
+        embedding_model = rag_settings.get('embedding_model_local')
+        chroma_path = rag_settings.get('chroma_path_ipcc_local')
     # Add more types here as needed
     # elif embedding_model_type == 'mistral':
     #     embedding_model = rag_settings.get('embedding_model_mistral')
@@ -233,8 +210,7 @@ def initialize_rag(config):
         raise ValueError(f"Unknown embedding_model_type: {embedding_model_type}")
 
     openai_api_key = os.getenv('OPENAI_API_KEY')
-    aitta_api_key = os.getenv('AITTA_API_KEY')
-    aitta_url = rag_settings.get('aitta_url', os.getenv('AITTA_URL', 'https://api-climatedt-aitta.2.rahtiapp.fi'))
+    local_api_key = os.getenv('OPENAI_API_KEY_LOCAL')
     document_path = rag_settings['document_path']
     chunk_size = rag_settings['chunk_size']
     chunk_overlap = rag_settings['chunk_overlap']
@@ -244,8 +220,8 @@ def initialize_rag(config):
     if embedding_model_type == 'openai' and not openai_api_key:
         logger.warning("No OpenAI API Key found. Skipping RAG initialization.")
         return False
-    if embedding_model_type == 'aitta' and not aitta_api_key:
-        logger.warning("No AITTA API Key found. Skipping RAG initialization.")
+    if embedding_model_type == 'local' and not local_api_key:
+        logger.warning("No local API Key found. Skipping RAG initialization.")
         return False
 
     # check if documents are present and valid
@@ -255,24 +231,24 @@ def initialize_rag(config):
 
     # Perform chunking and embedding
     try:
-        langchain_ef = create_embeddings(
-            embedding_model=embedding_model,
-            openai_api_key=openai_api_key,
-            aitta_api_key=aitta_api_key,
-            aitta_url=aitta_url,
-            model_type=embedding_model_type
-        )
-        documents = chunk_and_embed_documents(document_path, embedding_model, openai_api_key, aitta_url, embedding_model_type, chunk_size, chunk_overlap, separators)
-        converted_documents = [
-            Document(page_content=doc['text'], metadata=doc['metadata'])
-            for doc in documents
-        ]
-        rag_db = Chroma.from_documents(
-            documents=converted_documents,
+        if config['rag_settings']['embedding_model_type'] == 'local':
+            langchain_ef = OpenAIEmbeddings(
+                api_key=local_api_key,
+                base_url=rag_settings.get('local_endpoint_url'),
+                model=embedding_model,
+                tiktoken_enabled=False
+            )
+        else:
+            langchain_ef = OpenAIEmbeddings(api_key=openai_api_key, model=embedding_model)
+        documents = chunk_documents(document_path, chunk_size, chunk_overlap, separators)
+        rag_db = Chroma(
+            collection_name="ipcc_collection",
             persist_directory=chroma_path,
-            embedding=langchain_ef,
-            collection_name="ipcc_collection"
+            embedding_function=langchain_ef
         )
+        batch_size = 32
+        for i in tqdm.tqdm(range(0, len(documents), batch_size)):
+            rag_db.add_documents(documents[i:i+batch_size])
         rag_ready = True
         logger.info(f"RAG ready: {rag_ready}")
         logger.info("RAG database has been initialized and documents embedded.")
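The embedding step now happens inside Chroma's add_documents() call, so
each chunk is embedded exactly once. A standalone usage sketch of the
reworked flow (paths and sizes mirror config.yml; rag_db is the Chroma
handle created as above):

    docs = chunk_documents("./data/ipcc_text_reports/",
                           chunk_size=2000, chunk_overlap=200)
    batch_size = 32  # a tuning knob bounding request size, not an API limit
    for i in range(0, len(docs), batch_size):
        rag_db.add_documents(docs[i:i + batch_size])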
diff --git a/src/climsight/climsight_engine.py b/src/climsight/climsight_engine.py
index 5a02a0a..5ecf16e 100644
--- a/src/climsight/climsight_engine.py
+++ b/src/climsight/climsight_engine.py
@@ -69,7 +69,7 @@
     write_climate_data_manifest,
 )
 # import smart_agent
-from smart_agent import get_aitta_chat_model, smart_agent
+from smart_agent import smart_agent
 # import data_analysis_agent
 from data_analysis_agent import data_analysis_agent
 # import predefined data preparation functions
@@ -663,7 +663,7 @@ def agent_llm_request(content_message, input_params, config, api_key, api_key_lo
     logger.info(f"start agent_request")
     if config['llm_combine']['model_type'] == "local":
         llm_combine_agent = ChatOpenAI(
-            openai_api_base="http://localhost:8000/v1",
+            openai_api_base=config['llm_local_endpoint_url'],
             model_name=config['llm_combine']['model_name'],  # Match the exact model name you used
             openai_api_key=api_key_local,
             max_tokens=16000,
@@ -688,12 +688,6 @@ def agent_llm_request(content_message, input_params, config, api_key, api_key_lo
             max_tokens=16000,
         )
         llm_intro = llm_combine_agent
-    elif config['llm_combine']['model_type'] == 'aitta':
-        llm_combine_agent = get_aitta_chat_model(
-            config['llm_combine']['model_name'],
-            max_completion_tokens=4096
-        )
-        llm_intro = llm_combine_agent
 
     # Data analysis LLM (separate from combine step).
     llm_dataanalysis_cfg = config.get("llm_dataanalysis")
@@ -701,7 +695,7 @@ def agent_llm_request(content_message, input_params, config, api_key, api_key_lo
         raise RuntimeError("Missing llm_dataanalysis configuration.")
     if llm_dataanalysis_cfg.get("model_type") == "local":
         llm_dataanalysis_agent = ChatOpenAI(
-            openai_api_base="http://localhost:8000/v1",
+            openai_api_base=config["llm_local_endpoint_url"],
             model_name=llm_dataanalysis_cfg.get("model_name"),
             openai_api_key=api_key_local,
             max_tokens=16000,
@@ -712,11 +706,6 @@ def agent_llm_request(content_message, input_params, config, api_key, api_key_lo
             model_name=llm_dataanalysis_cfg.get("model_name"),
             max_tokens=16000,
         )
-    elif llm_dataanalysis_cfg.get("model_type") == "aitta":
-        llm_dataanalysis_agent = get_aitta_chat_model(
-            llm_dataanalysis_cfg.get("model_name"),
-            max_completion_tokens=4096
-        )
     else:
         llm_dataanalysis_agent = llm_combine_agent
 
@@ -1187,7 +1176,7 @@ class routeResponse(BaseModel):
         # Pass the dictionary to invoke
         input = {"user_text": state.user}
         response = chain.invoke(input)
-    elif config['llm_combine']['model_type'] in ("local", "aitta"):
+    elif config['llm_combine']['model_type'] == "local":
         prompt_text = intro_prompt.format(user_text=state.user)
         response_raw = llm_intro.invoke(prompt_text)
         import re, json
@@ -1275,7 +1264,7 @@ def combine_agent(state: AgentState):
         state.content_message += "\n ECOCROP Search Response: {ecocrop_search_response} "
         logger.info(f"Ecocrop_search_response: {state.ecocrop_search_response}")
 
-    if config['llm_combine']['model_type'] in ("local", "aitta"):
+    if config['llm_combine']['model_type'] == "local":
         system_message_prompt = SystemMessagePromptTemplate.from_template(config['system_role'])
     elif config['llm_combine']['model_type'] == "openai":
         if "o1" in config['llm_combine']['model_name']:
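The local/openai branch above now repeats in several places (combine,
data analysis, and smart_agent.py below). A hypothetical helper that
could factor it out; a sketch, not part of this patch:

    from langchain_openai import ChatOpenAI

    def make_chat_model(cfg, config, api_key, api_key_local, **kwargs):
        """Build a ChatOpenAI for either the hosted or the local backend."""
        if cfg["model_type"] == "local":
            return ChatOpenAI(
                openai_api_base=config["llm_local_endpoint_url"],
                model_name=cfg["model_name"],
                openai_api_key=api_key_local,
                **kwargs,
            )
        return ChatOpenAI(model_name=cfg["model_name"], api_key=api_key, **kwargs)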
diff --git a/src/climsight/embedding_utils.py b/src/climsight/embedding_utils.py
deleted file mode 100644
index adac49d..0000000
--- a/src/climsight/embedding_utils.py
+++ /dev/null
@@ -1,72 +0,0 @@
-import os
-import logging
-from typing import Optional
-from langchain_openai.embeddings import OpenAIEmbeddings
-
-logger = logging.getLogger(__name__)
-
-def create_embeddings(
-    model_type: str,
-    embedding_model: str,
-    openai_api_key: Optional[str] = None,
-    aitta_api_key: Optional[str] = None,
-    aitta_url: Optional[str] = None,
-    model_name: Optional[str] = None
-) -> OpenAIEmbeddings:
-    """
-    Creates an embedding model instance based on the configuration.
-
-    Args:
-        model_type (str): Which backend to use ('openai', 'aitta', ...)
-        embedding_model (str): The embedding model name or type
-        openai_api_key (str, optional): OpenAI API key for OpenAI models
-        aitta_api_key (str, optional): AITTA API key for open models
-        aitta_url (str, optional): AITTA API URL for open models
-        model_name (str, optional): Specific model name for open models
-
-    Returns:
-        OpenAIEmbeddings: Configured embedding model instance
-
-    Raises:
-        ValueError: If required parameters are missing or configuration is invalid
-    """
-
-    if model_type == 'openai':
-        if not openai_api_key:
-            raise ValueError("OPENAI_API_KEY is required for OpenAI models")
-        return OpenAIEmbeddings(
-            api_key=openai_api_key,  # type: ignore
-            model=embedding_model
-        )
-    elif model_type == 'aitta':
-        if not aitta_api_key:
-            raise ValueError("AITTA_API_KEY is required for aitta models")
-        if not aitta_url:
-            raise ValueError("AITTA URL is required for aitta models")
-        if not model_name:
-            model_name = embedding_model
-
-        try:
-            # Import aitta_client only when needed
-            from aitta_client import Model, Client
-            from aitta_client.authentication import APIKeyAccessTokenSource
-
-            client = Client(aitta_url, APIKeyAccessTokenSource(aitta_api_key, aitta_url))
-            model = Model.load(model_name, client)
-
-            return OpenAIEmbeddings(
-                model=model.id,
-                api_key=client.access_token_source.get_access_token(),  # type: ignore
-                base_url=model.openai_api_url,
-                tiktoken_enabled=False
-            )
-        except ImportError:
-            raise ImportError("aitta_client is required for aitta models. Install with: pip install aitta-client")
-        except Exception as e:
-            logger.error(f"Failed to create aitta model embeddings: {e}")
-            raise ValueError(f"Failed to create aitta model embeddings: {e}")
-    # elif model_type == 'mistral':
-    #     # Add logic for mistral here
-    #     pass
-    else:
-        raise ValueError(f"Unknown model_type: {model_type}")
\ No newline at end of file
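With an OpenAI-compatible endpoint, the deleted helper collapses to a
single constructor call, which is why it could be removed. A sketch of
the replacement pattern (model name and URL are the defaults from
config.yml; any compatible server works):

    import os
    from langchain_openai import OpenAIEmbeddings

    emb = OpenAIEmbeddings(
        api_key=os.getenv("OPENAI_API_KEY_LOCAL"),
        base_url="http://localhost:8000/v1",
        model="lightonai/modernbert-embed-large",
        tiktoken_enabled=False,  # the model does not use an OpenAI tokenizer
    )
    vectors = emb.embed_documents(["some chunk of report text"])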
diff --git a/src/climsight/rag.py b/src/climsight/rag.py
index 788ee69..0995944 100644
--- a/src/climsight/rag.py
+++ b/src/climsight/rag.py
@@ -14,8 +14,6 @@
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from langchain_core.documents.base import Document
 
-from embedding_utils import create_embeddings
-
 logger = logging.getLogger(__name__)
 logging.basicConfig(
     filename='climsight.log',
@@ -43,14 +41,14 @@ def is_valid_rag_db(rag_db_path):
 
     chroma_file = os.path.join(rag_db_path, 'chroma.sqlite3')
     if not os.path.exists(chroma_file):
+        logger.warning(f"RAG DB validation failed: '{chroma_file}' does not exist.")
         return False
     folder_name = get_folder_name(rag_db_path)
     if folder_name is None:
+        logger.warning(f"RAG DB validation failed: '{rag_db_path}' is not a recognized type.")
         return False
-    folder_path = os.path.join(rag_db_path, folder_name)
-    if os.path.isdir(folder_path) and os.listdir(folder_path):  # check if folder is non-empty
-        return True
-    return False
+    logger.info(f"RAG DB validation: '{chroma_file}' exists.")
+    return True
 
 
 def load_rag(config, openai_api_key=None, db_type='ipcc'):
@@ -73,9 +71,9 @@ def load_rag(config, openai_api_key=None, db_type='ipcc'):
     if embedding_model_type == 'openai':
         embedding_model = rag_settings.get('embedding_model_openai')
         chroma_path = rag_settings.get(f'chroma_path_{db_type}_openai')
-    elif embedding_model_type == 'aitta':
-        embedding_model = rag_settings.get('embedding_model_aitta')
-        chroma_path = rag_settings.get(f'chroma_path_{db_type}_aitta')
+    elif embedding_model_type == 'local':
+        embedding_model = rag_settings.get('embedding_model_local')
+        chroma_path = rag_settings.get(f'chroma_path_{db_type}_local')
     # Add more types here as needed
     # elif embedding_model_type == 'mistral':
     #     embedding_model = rag_settings.get('embedding_model_mistral')
@@ -83,10 +81,6 @@ def load_rag(config, openai_api_key=None, db_type='ipcc'):
     else:
         raise ValueError(f"Unknown embedding_model_type: {embedding_model_type}")
 
-    # Use the openai_api_key parameter as-is (do not overwrite)
-    aitta_api_key = os.getenv('AITTA_API_KEY')
-    aitta_url = rag_settings.get('aitta_url', os.getenv('AITTA_URL', 'https://api-climatedt-aitta.2.rahtiapp.fi'))
-
     rag_ready = False
     valid_rag_db = is_valid_rag_db(chroma_path)
     if not valid_rag_db:
@@ -95,13 +89,18 @@ def load_rag(config, openai_api_key=None, db_type='ipcc'):
         return rag_ready, rag_db
 
     try:
-        langchain_ef = create_embeddings(
-            model_type=embedding_model_type,
-            embedding_model=embedding_model,
-            openai_api_key=openai_api_key,
-            aitta_api_key=aitta_api_key,
-            aitta_url=aitta_url
-        )
+        if config['rag_settings']['embedding_model_type'] == 'local':
+            langchain_ef = OpenAIEmbeddings(
+                api_key=os.getenv('OPENAI_API_KEY_LOCAL'),  # type: ignore
+                base_url=config['rag_settings']['local_endpoint_url'],
+                model=config['rag_settings']['embedding_model_local'],
+                tiktoken_enabled=False
+            )
+        else:
+            langchain_ef = OpenAIEmbeddings(
+                api_key=openai_api_key,
+                model=config['rag_settings']['embedding_model_openai'],
+            )
         rag_db = Chroma(persist_directory=chroma_path, embedding_function=langchain_ef, collection_name="ipcc_collection")
         logger.info(f"RAG database loaded with {rag_db._collection.count()} documents.")
         rag_ready = True
@@ -176,7 +175,10 @@ def inspect(state):
         return state
 
     # First, retrieve documents to get sources
-    docs = retriever.get_relevant_documents(input_params['user_message'])
+    try:
+        docs = retriever.get_relevant_documents(input_params['user_message'])
+    except Exception:  # newer LangChain retrievers drop get_relevant_documents; fall back to invoke()
+        docs = retriever.invoke(input_params['user_message'])
     sources_list = extract_sources(docs)
     # Get unique sources (filenames)
     unique_sources = list(set(sources_list))
@@ -186,11 +188,20 @@ def inspect(state):
     context = format_docs(docs)
 
     # Build the chain with pre-retrieved context
+    rag_llm = ChatOpenAI(
+        model=config['llm_rag']['model_name'],
+        api_key=openai_api_key,
+        base_url=(
+            config['rag_settings']['local_endpoint_url']
+            if config['llm_rag']['model_type'] == 'local'
+            else None
+        )
+    )
     rag_chain = (
         {"context": lambda _: context, "location": RunnableLambda(get_loci), "question": RunnablePassthrough()}
         | RunnableLambda(inspect)
         | custom_rag_prompt
-        | ChatOpenAI(model=config['llm_rag']['model_name'], api_key=openai_api_key)
+        | rag_llm
         | StrOutputParser()
     )
     rag_response = rag_chain.invoke(input_params['user_message'])
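For orientation, loading and querying the backend-specific DB now looks
roughly like this (a sketch; the query string is made up):

    import os

    rag_ready, rag_db = load_rag(config, openai_api_key=os.getenv("OPENAI_API_KEY"))
    if rag_ready:
        retriever = rag_db.as_retriever(search_kwargs={"k": 4})
        docs = retriever.invoke("projected sea level rise at the North Sea coast")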
diff --git a/src/climsight/smart_agent.py b/src/climsight/smart_agent.py
index 91a1fdf..60899d6 100644
--- a/src/climsight/smart_agent.py
+++ b/src/climsight/smart_agent.py
@@ -42,29 +42,12 @@
 #Import for working Path
 import uuid
 from pathlib import Path
-try:
-    from aitta_client import Model, Client
-    from aitta_client.authentication import APIKeyAccessTokenSource
-except:
-    pass
 
 # Import AgentState from climsight_classes
 from climsight_classes import AgentState
 import calendar
 import pandas as pd
 
-def get_aitta_chat_model(model_name, **kwargs):
-    aitta_url = 'https://api-climatedt-aitta.2.rahtiapp.fi'
-    aitta_api_key = os.environ['AITTA_API_KEY']
-    client = Client(aitta_url, APIKeyAccessTokenSource(aitta_api_key, aitta_url))
-    model = Model.load(model_name, client)
-    access_token = client.access_token_source.get_access_token()
-    return ChatOpenAI(
-        openai_api_key=access_token,
-        openai_api_base=model.openai_api_url,
-        model_name=model.id,
-        **kwargs
-    )
 
 def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handler):
 #def smart_agent(state: AgentState, config, api_key):
@@ -104,7 +87,7 @@ def smart_agent(state: AgentState, config, api_key, api_key_local, stream_handle
         " - ONLY use when the query explicitly mentions agriculture, crops, or food production.\n"
         " - Do NOT use for general climate, infrastructure, or energy queries.\n\n"
     )
-    if config['llm_smart']['model_type'] in ("local", "aitta"):
+    if config['llm_smart']['model_type'] == "local":
         prompt += (
             "**Tool use order:** Call tools one at a time, sequentially.\n\n"
         )
@@ -144,7 +127,7 @@ def process_wikipedia_article(query: str) -> str:
         # Initialize the LLM
         if config['llm_smart']['model_type'] == "local":
             llm = ChatOpenAI(
-                openai_api_base="http://localhost:8000/v1",
+                openai_api_base=config['llm_local_endpoint_url'],
                 model_name=config['llm_smart']['model_name'],  # Match the exact model name you used
                 openai_api_key=api_key_local,
                 temperature = temperature,
@@ -155,9 +138,6 @@ def process_wikipedia_article(query: str) -> str:
                 model_name=config['llm_smart']['model_name'],
                 temperature=temperature
             )
-        elif config['llm_smart']['model_type'] == "aitta":
-            llm = get_aitta_chat_model(
-                config['llm_smart']['model_name'], temperature = temperature)
 
         # Define your custom prompt template
         template = (
             "Read the provided Wikipedia article: {wikipage}\n\n"
@@ -287,7 +267,15 @@ def process_RAG_search(query: str) -> str:
         data_rag = config['rag_articles']['data_path']
 
         # Load the persisted vector store
-        embeddings = OpenAIEmbeddings(api_key=api_key)
+        if config['rag_settings']['embedding_model_type'] == 'local':
+            embeddings = OpenAIEmbeddings(
+                api_key=api_key_local,  # type: ignore
+                base_url=config['rag_settings']['local_endpoint_url'],
+                model=config['rag_settings']['embedding_model_local'],
+                tiktoken_enabled=False
+            )
+        else:
+            embeddings = OpenAIEmbeddings(api_key=api_key)
         vectorstore = Chroma(
             persist_directory=data_rag,
             embedding_function=embeddings
         )
@@ -345,7 +333,7 @@ def process_RAG_search(query: str) -> str:
         # Initialize the LLM
         if config['llm_smart']['model_type'] == "local":
             llm = ChatOpenAI(
-                openai_api_base="http://localhost:8000/v1",
+                openai_api_base=config['llm_local_endpoint_url'],
                 model_name=config['llm_smart']['model_name'],  # Match the exact model name you used
                 openai_api_key=api_key_local,
                 temperature = temperature,
@@ -356,9 +344,6 @@ def process_RAG_search(query: str) -> str:
                 model_name=config['llm_smart']['model_name'],
                 temperature=temperature
             )
-        elif config['llm_smart']['model_type'] == "aitta":
-            llm = get_aitta_chat_model(
-                config['llm_smart']['model_name'], temperature = temperature)
 
         # Create the chain with the prompt and LLM
         chain = prompt | llm
@@ -408,7 +393,7 @@ def process_ecocrop_search(query: str) -> str:
         # Initialize the LLM
         if config['llm_smart']['model_type'] == "local":
             llm = ChatOpenAI(
-                openai_api_base="http://localhost:8000/v1",
+                openai_api_base=config['llm_local_endpoint_url'],
                 model_name=config['llm_smart']['model_name'],  # Match the exact model name you used
                 openai_api_key=api_key_local,
                 temperature = 0,
@@ -419,8 +404,6 @@ def process_ecocrop_search(query: str) -> str:
                 model_name=config['llm_smart']['model_name'],
                 temperature=0.0
             )
-        elif config['llm_smart']['model_type'] == "aitta":
-            llm = get_aitta_chat_model(config['llm_smart']['model_name'], temperature = 0)
 
         # Create the prompt template
         prompt = ChatPromptTemplate.from_template("""
@@ -474,7 +457,7 @@ def process_ecocrop_search(query: str) -> str:
     # Initialize the LLM
     if config['llm_smart']['model_type'] == "local":
         llm = ChatOpenAI(
-            openai_api_base="http://localhost:8000/v1",
+            openai_api_base=config['llm_local_endpoint_url'],
             model_name=config['llm_smart']['model_name'],
             openai_api_key=api_key_local,
             temperature = 0,
@@ -486,9 +469,6 @@ def process_ecocrop_search(query: str) -> str:
             temperature=0.0
         )
 
-    elif config['llm_smart']['model_type'] == "aitta":
-        llm = get_aitta_chat_model(config['llm_smart']['model_name'], temperature = 0)
-
     # List of tools
     tools = [rag_tool, ecocrop_tool, wikipedia_tool]
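One constraint worth keeping in mind: the embedding model used at query
time must be the same one that built the vector store, which is why
config.yml keeps separate per-backend Chroma paths. The selection
pattern, shown in isolation (illustrative; rag.py spells the suffixes
out per branch):

    backend = config["rag_settings"]["embedding_model_type"]  # "openai" or "local"
    chroma_path = config["rag_settings"][f"chroma_path_ipcc_{backend}"]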
diff --git a/src/climsight/streamlit_interface.py b/src/climsight/streamlit_interface.py
index c7f23fd..a5dc533 100644
--- a/src/climsight/streamlit_interface.py
+++ b/src/climsight/streamlit_interface.py
@@ -20,7 +20,6 @@
 from data_container import DataContainer
 from climsight_engine import normalize_longitude, llm_request, forming_request, location_request
 from extract_climatedata_functions import plot_climate_data
-from embedding_utils import create_embeddings
 from climate_data_providers import get_available_providers
 from sandbox_utils import ensure_thread_id, ensure_sandbox_dirs, get_sandbox_paths, clean_sandbox

diff --git a/test/test_rag.py b/test/test_rag.py
index df133f7..90f26ba 100644
--- a/test/test_rag.py
+++ b/test/test_rag.py
@@ -21,7 +21,7 @@ class TestLoadRag(unittest.TestCase):
 
     @patch('rag.is_valid_rag_db', return_value=True)  # simulate case where db is valid
     @patch('rag.Chroma')
-    @patch('embedding_utils.OpenAIEmbeddings')
+    @patch('rag.OpenAIEmbeddings')
     def test_load_rag_when_ready(self, mock_openai_embeddings, mock_chroma, mock_is_valid_rag_db):
         # Mock the embeddings and chroma instances
         mock_embedding_instance = MagicMock()
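The retargeted @patch decorator follows the general unittest.mock rule:
patch a name where it is looked up, not where it is defined. Since rag.py
now binds OpenAIEmbeddings into its own namespace via a from-import, the
test must patch that copy (illustrative):

    from unittest.mock import patch

    # rag.py does `from langchain_openai import OpenAIEmbeddings`,
    # so tests patch the name bound in rag, not in langchain_openai
    with patch('rag.OpenAIEmbeddings') as mock_embeddings:
        ...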