diff --git a/integrate/obsidian_sync.py b/integrate/obsidian_sync.py new file mode 100644 index 00000000..de18c42f --- /dev/null +++ b/integrate/obsidian_sync.py @@ -0,0 +1,286 @@ +import os +import sys +import requests +import yaml +import logging +import argparse +import hashlib +import json +from pathlib import Path +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("ObsidianSync") + +# Constants +ALLOWED_EXTENSIONS = { + ".txt", ".pdf", ".md", + ".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff" +} +STATE_FILE = ".obsidian_sync_state.json" + +class ObsidianSync: + def __init__(self, api_url, vault_path): + self.api_url = api_url.rstrip("/") + self.vault_path = Path(vault_path) + self.state_file_path = self.vault_path / STATE_FILE + self.state = self._load_state() + + if not self.vault_path.exists(): + raise ValueError(f"Vault path does not exist: {vault_path}") + + def _load_state(self): + if self.state_file_path.exists(): + try: + with open(self.state_file_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.warning(f"Failed to load state file: {e}") + return {} + + def _save_state(self): + try: + with open(self.state_file_path, 'w', encoding='utf-8') as f: + json.dump(self.state, f, indent=2) + except Exception as e: + logger.error(f"Failed to save state file: {e}") + + def _get_remote_memories(self): + """Fetch all memories from the server.""" + try: + response = requests.get(f"{self.api_url}/api/memories/list") + response.raise_for_status() + data = response.json() + if data.get("code") == 200: + return data.get("data", []) + else: + logger.error(f"API Error: {data.get('message')}") + return [] + except Exception as e: + logger.error(f"Failed to list remote memories: {e}") + return [] + + def _parse_frontmatter(self, file_path): + """Parse frontmatter from markdown file manually to avoid extra dependencies.""" + metadata = {} + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + + if content.startswith("---\n"): + end_index = content.find("\n---\n", 4) + if end_index != -1: + frontmatter = content[4:end_index] + try: + metadata = yaml.safe_load(frontmatter) + if not isinstance(metadata, dict): + metadata = {} + except yaml.YAMLError: + pass + except Exception as e: + logger.warning(f"Failed to parse frontmatter for {file_path}: {e}") + + return metadata + + def _calculate_file_hash(self, file_path): + """Calculate MD5 hash of a file.""" + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + def _upload_file(self, file_path, metadata): + """Upload a file to the server.""" + url = f"{self.api_url}/api/memories/file" + + # Add source metadata + metadata['source'] = 'obsidian' + + # Convert metadata values to strings/JSON as needed for multipart/form-data + # Flatten metadata for simple key-value pairs + form_data = {} + for k, v in metadata.items(): + if isinstance(v, (dict, list)): + form_data[k] = json.dumps(v) + else: + form_data[k] = str(v) + + try: + with open(file_path, 'rb') as f: + files = {'file': (file_path.name, f)} + response = requests.post(url, files=files, data=form_data) + response.raise_for_status() + result = response.json() + if result.get("code") == 200: + logger.info(f"Successfully uploaded: {file_path.name}") + return True + else: + logger.error(f"Failed to upload {file_path.name}: {result.get('message')}") + return False + except Exception as e: + logger.error(f"Error uploading {file_path.name}: {e}") + return False + + def _delete_remote_file(self, filename): + """Delete a file from the server.""" + url = f"{self.api_url}/api/memories/file/{filename}" + try: + response = requests.delete(url) + response.raise_for_status() + result = response.json() + if result.get("code") == 200: + logger.info(f"Successfully deleted remote file: {filename}") + return True + else: + # If 404, it's already gone, which is fine + if "does not exist" in result.get("message", ""): + return True + logger.error(f"Failed to delete {filename}: {result.get('message')}") + return False + except Exception as e: + logger.error(f"Error deleting {filename}: {e}") + return False + + def sync(self): + logger.info("Starting sync...") + + # 1. Get remote files + remote_memories = self._get_remote_memories() + + # Filter remote files that are managed by obsidian sync + # We assume files with 'source': 'obsidian' in metadata are managed by us. + remote_obsidian_files = {} + for m in remote_memories: + meta = m.get("meta_data") or {} + # Check if source is obsidian. + # Note: The server might return meta_data as a dictionary or string depending on DB + if isinstance(meta, str): + try: + meta = json.loads(meta) + except: + meta = {} + + if meta.get("source") == "obsidian": + remote_obsidian_files[m["name"]] = m + + # 2. Scan local files + local_files = {} + for root, _, files in os.walk(self.vault_path): + for file in files: + if file == STATE_FILE: + continue + + file_path = Path(root) / file + if file_path.suffix.lower() not in ALLOWED_EXTENSIONS: + continue + + # We use filename as the key because the backend uses filename. + # WARNING: Duplicate filenames in different folders will cause conflict in the current backend. + # The backend flattens the directory structure. + # We will just warn for now or process the first one found. + if file in local_files: + logger.warning(f"Duplicate filename found: {file}. Backend does not support folders. Skipping {file_path}") + continue + + local_files[file] = file_path + + # 3. Process Uploads and Updates + for filename, file_path in local_files.items(): + try: + # Calculate hash and size + current_hash = self._calculate_file_hash(file_path) + current_size = file_path.stat().st_size + + # Extract metadata + metadata = {} + if file_path.suffix.lower() == ".md": + metadata = self._parse_frontmatter(file_path) + + # Add creation date if not present + if 'created' not in metadata: + created_ts = file_path.stat().st_ctime + metadata['created'] = datetime.fromtimestamp(created_ts).isoformat() + + # Check against state + state_entry = self.state.get(filename) + needs_upload = False + + if filename not in remote_obsidian_files: + logger.info(f"New file found: {filename}") + needs_upload = True + else: + # File exists remotely. Check if it needs update. + # We check if hash changed or if state says it's different + if not state_entry: + # No local state, but exists remotely. Trust remote? + # Or check if remote size matches. + # Ideally we assume if we don't have state, we might need to sync. + # But to save bandwidth, we can check remote size. + remote_size = remote_obsidian_files[filename].get("size") + if remote_size != current_size: + logger.info(f"Size mismatch for {filename}. Local: {current_size}, Remote: {remote_size}") + needs_upload = True + else: + # Assume synced if size matches and we have no state (initial sync of existing folder) + # Update state + self.state[filename] = {"hash": current_hash, "size": current_size} + else: + if state_entry.get("hash") != current_hash: + logger.info(f"File changed: {filename}") + needs_upload = True + + if needs_upload: + # If it exists remotely, delete it first (to update) + if filename in remote_obsidian_files: + self._delete_remote_file(filename) + + if self._upload_file(file_path, metadata): + self.state[filename] = {"hash": current_hash, "size": current_size} + self._save_state() + + except Exception as e: + logger.error(f"Error processing local file {filename}: {e}") + + # 4. Process Deletions + for filename in remote_obsidian_files: + if filename not in local_files: + logger.info(f"File deleted locally, removing from server: {filename}") + if self._delete_remote_file(filename): + if filename in self.state: + del self.state[filename] + self._save_state() + + logger.info("Sync completed.") + +def main(): + parser = argparse.ArgumentParser(description="Sync Obsidian vault to LPM Memories") + parser.add_argument("--api-url", help="LPM API URL (e.g. http://localhost:5000)") + parser.add_argument("--vault-path", help="Path to Obsidian vault") + + args = parser.parse_args() + + api_url = args.api_url or os.environ.get("LPM_API_URL") + vault_path = args.vault_path or os.environ.get("OBSIDIAN_VAULT_PATH") + + if not api_url: + logger.error("API URL must be provided via --api-url or LPM_API_URL env var") + sys.exit(1) + + if not vault_path: + logger.error("Vault path must be provided via --vault-path or OBSIDIAN_VAULT_PATH env var") + sys.exit(1) + + try: + syncer = ObsidianSync(api_url, vault_path) + syncer.sync() + except Exception as e: + logger.error(f"Fatal error: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/integrate/requirements.txt b/integrate/requirements.txt index c086a74c..644fc090 100644 --- a/integrate/requirements.txt +++ b/integrate/requirements.txt @@ -2,4 +2,6 @@ wxpy==0.3.9.8 python-dotenv==0.19.0 torch>=1.8.0 transformers>=4.5.0 -numpy>=1.19.0 +numpy>=1.19.0 +requests>=2.25.0 +PyYAML>=5.4.0 diff --git a/lpm_kernel/L2/data_pipeline/data_prep/diversity/diversity_data_generator.py b/lpm_kernel/L2/data_pipeline/data_prep/diversity/diversity_data_generator.py index ccc76f72..38358f8b 100644 --- a/lpm_kernel/L2/data_pipeline/data_prep/diversity/diversity_data_generator.py +++ b/lpm_kernel/L2/data_pipeline/data_prep/diversity/diversity_data_generator.py @@ -14,6 +14,7 @@ from lpm_kernel.configs.config import Config from lpm_kernel.L2.data_pipeline.data_prep.diversity.utils import remove_similar_dicts import lpm_kernel.L2.data_pipeline.data_prep.diversity.template_diversity as template_diversity +from lpm_kernel.common.gemini_client import GeminiClient from lpm_kernel.configs.logging import get_train_process_logger logger = get_train_process_logger() @@ -58,11 +59,18 @@ def __init__(self, preference_language: str, is_cot: bool = True): self.model_name = None else: self.model_name = user_llm_config.chat_model_name - - self.client = openai.OpenAI( - api_key=user_llm_config.chat_api_key, - base_url=user_llm_config.chat_endpoint, - ) + + if user_llm_config.provider_type == 'gemini': + logger.info("Initializing Gemini client for DiversityData generation") + self.client = GeminiClient( + api_key=user_llm_config.chat_api_key, + base_url=user_llm_config.chat_endpoint + ) + else: + self.client = openai.OpenAI( + api_key=user_llm_config.chat_api_key, + base_url=user_llm_config.chat_endpoint, + ) self.preference_language = preference_language self.max_workers = os.environ.get("concurrency_threads", 2) self.data_synthesis_mode = os.environ.get("DATA_SYNTHESIS_MODE", "low") @@ -74,6 +82,12 @@ def __init__(self, preference_language: str, is_cot: bool = True): self.base_url = user_llm_config.thinking_endpoint if self.model_name.startswith("deepseek"): self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) + elif user_llm_config.provider_type == 'gemini': + # For Gemini, assume thinking models are used if model name indicates it or just reuse main logic + if self.model_name: + logger.info(f"Using thinking model for Gemini: {self.model_name}") + # Client already initialized with Gemini + pass else: logger.error(f"Error model_name, longcot data generating model_name: deepseek series") raise diff --git a/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py b/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py index c7a25ce0..d7fcc5d7 100644 --- a/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py +++ b/lpm_kernel/L2/data_pipeline/data_prep/preference/preference_QA_generate.py @@ -16,6 +16,7 @@ EN_SYS_TEMPLATES, EN_SYS_COT_TEMPLATES ) import traceback +from lpm_kernel.common.gemini_client import GeminiClient from lpm_kernel.configs.logging import get_train_process_logger logger = get_train_process_logger() @@ -76,11 +77,19 @@ def __init__(self, filename: str, bio: str, preference_language: str, is_cot: bo self.model_name = None else: self.model_name = user_llm_config.chat_model_name - - self.client = openai.OpenAI( - api_key=user_llm_config.chat_api_key, - base_url=user_llm_config.chat_endpoint, - ) + + if user_llm_config.provider_type == 'gemini': + logger.info("Initializing Gemini client for PreferenceQA generation") + self.client = GeminiClient( + api_key=user_llm_config.chat_api_key, + base_url=user_llm_config.chat_endpoint + ) + else: + self.client = openai.OpenAI( + api_key=user_llm_config.chat_api_key, + base_url=user_llm_config.chat_endpoint, + ) + if self.is_cot: logger.info("generate pereference data in longcot pattern!!!") self.model_name = user_llm_config.thinking_model_name @@ -88,6 +97,17 @@ def __init__(self, filename: str, bio: str, preference_language: str, is_cot: bo self.base_url = user_llm_config.thinking_endpoint if self.model_name.startswith("deepseek"): self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) + elif user_llm_config.provider_type == 'gemini': + # For Gemini, assume thinking models are used if model name indicates it or just reuse main logic + # Gemini doesn't strictly need a separate client if api key is same, but to be consistent with logic: + # If thinking model is specified and provider is Gemini, reuse GeminiClient with thinking credentials? + # Actually if provider is gemini, user might not have set 'thinking_model_name' to deepseek. + # If they did, then it's a conflict. + # Assuming if is_cot is true and provider is gemini, we use the Gemini client with thinking model name if provided. + if self.model_name: + logger.info(f"Using thinking model for Gemini: {self.model_name}") + # Client already initialized with Gemini + pass else: logger.error(f"Error model_name, longcot data generating model_name: deepseek series") raise diff --git a/lpm_kernel/L2/data_pipeline/data_prep/selfqa/selfqa_generator.py b/lpm_kernel/L2/data_pipeline/data_prep/selfqa/selfqa_generator.py index 8adc1f45..0fb765a0 100644 --- a/lpm_kernel/L2/data_pipeline/data_prep/selfqa/selfqa_generator.py +++ b/lpm_kernel/L2/data_pipeline/data_prep/selfqa/selfqa_generator.py @@ -11,6 +11,7 @@ ) from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService from lpm_kernel.configs.config import Config +from lpm_kernel.common.gemini_client import GeminiClient from lpm_kernel.configs.logging import get_train_process_logger logger = get_train_process_logger() @@ -75,10 +76,17 @@ def __init__( else: self.model_name = user_llm_config.chat_model_name - self.client = openai.OpenAI( - api_key=user_llm_config.chat_api_key, - base_url=user_llm_config.chat_endpoint, - ) + if user_llm_config.provider_type == 'gemini': + logger.info("Initializing Gemini client for SelfQA generation") + self.client = GeminiClient( + api_key=user_llm_config.chat_api_key, + base_url=user_llm_config.chat_endpoint + ) + else: + self.client = openai.OpenAI( + api_key=user_llm_config.chat_api_key, + base_url=user_llm_config.chat_endpoint, + ) self.max_workers = os.environ.get("concurrency_threads", 2) self.data_synthesis_mode = os.environ.get("DATA_SYNTHESIS_MODE", "low") if self.is_cot: @@ -88,6 +96,11 @@ def __init__( self.base_url = user_llm_config.thinking_endpoint if self.model_name.startswith("deepseek"): self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) + elif user_llm_config.provider_type == 'gemini': + if self.model_name: + logger.info(f"Using thinking model for Gemini: {self.model_name}") + # Client already initialized with Gemini + pass else: logger.error(f"Error model_name, longcot data generating model_name: deepseek series") raise diff --git a/lpm_kernel/api/domains/kernel2/services/chat_service.py b/lpm_kernel/api/domains/kernel2/services/chat_service.py index 5ce5d348..08cba4d9 100644 --- a/lpm_kernel/api/domains/kernel2/services/chat_service.py +++ b/lpm_kernel/api/domains/kernel2/services/chat_service.py @@ -17,6 +17,7 @@ RoleBasedStrategy, KnowledgeEnhancedStrategy, ) +from lpm_kernel.common.gemini_client import GeminiClient logger = logging.getLogger(__name__) @@ -340,6 +341,11 @@ def chat( # Call LLM API try: + # If using GeminiClient, we can use the same interface thanks to the adapter + # But we log it specifically as requested + if isinstance(current_client, GeminiClient): + logger.info("Using Gemini client for chat") + response = current_client.chat.completions.create(**api_params) if not stream: logger.info(f"Response: {response.json() if hasattr(response, 'json') else response}") diff --git a/lpm_kernel/api/domains/memories/routes.py b/lpm_kernel/api/domains/memories/routes.py index 8dbf5b1c..ec9986c7 100644 --- a/lpm_kernel/api/domains/memories/routes.py +++ b/lpm_kernel/api/domains/memories/routes.py @@ -4,12 +4,26 @@ from lpm_kernel.configs.config import Config from lpm_kernel.common.logging import logger from lpm_kernel.file_data.document_service import DocumentService +from lpm_kernel.common.repository.database_session import DatabaseSession +from lpm_kernel.models.memory import Memory +from sqlalchemy import select memories_bp = Blueprint("memories", __name__) storage_service = StorageService(Config.from_env()) # Allowed file formats -ALLOWED_EXTENSIONS = {"txt", "pdf", "md"} +ALLOWED_EXTENSIONS = { + "txt", + "pdf", + "md", + "png", + "jpg", + "jpeg", + "gif", + "webp", + "bmp", + "tiff", +} def allowed_file(filename): @@ -17,6 +31,45 @@ def allowed_file(filename): return "." in filename and filename.rsplit(".", 1)[1].lower() in ALLOWED_EXTENSIONS +@memories_bp.route("/api/memories/list", methods=["GET"]) +def list_memories(): + """ + List all memories + Optional query param: source (e.g., 'obsidian') to filter by metadata source + """ + try: + logger.info("Starting to list memories") + source_filter = request.args.get('source') + + db = DatabaseSession() + with db._session_factory() as session: + query = select(Memory) + # Apply basic filtering if source is provided. + # Note: filtering JSON in a DB-agnostic way via SQLAlchemy can be complex. + # For simplicity and compatibility, we'll fetch and filter in Python if needed, + # or try to use basic string matching if appropriate, but Python filtering + # is safer for JSON structure consistency across DBs (SQLite/PG/etc). + + result = session.execute(query) + memories = result.scalars().all() + + data = [] + for memory in memories: + if source_filter: + meta = memory.meta_data or {} + # Handle if meta_data is stored as string in some legacy cases or distinct DB behavior + if isinstance(meta, dict) and meta.get('source') == source_filter: + data.append(memory.to_dict()) + else: + data.append(memory.to_dict()) + + logger.info(f"Listed {len(data)} memories") + return APIResponse.success(data=data, message="Memories listed successfully") + except Exception as e: + logger.error(f"Error listing memories: {str(e)}", exc_info=True) + return APIResponse.error(message=f"Internal server error: {str(e)}", code=500) + + @memories_bp.route("/api/memories/file", methods=["POST"]) def upload_file(): """ diff --git a/lpm_kernel/api/domains/upload/client.py b/lpm_kernel/api/domains/upload/client.py index 6b024228..173e207c 100644 --- a/lpm_kernel/api/domains/upload/client.py +++ b/lpm_kernel/api/domains/upload/client.py @@ -1,15 +1,6 @@ -import aiohttp import logging from lpm_kernel.api.domains.upload.TrainingTags import TrainingTags -from lpm_kernel.configs import config -import websockets -import json -import asyncio from lpm_kernel.configs.config import Config -import time -import requests -from lpm_kernel.api.common.responses import ResponseHandler -from lpm_kernel.api.domains.loads.load_service import LoadService from typing import Optional, List, Dict logger = logging.getLogger(__name__) @@ -18,10 +9,10 @@ class HeartbeatConfig: """Heartbeat Configuration Class""" def __init__( self, - interval: int = 30, # Heartbeat interval (seconds) - timeout: int = 10, # Heartbeat timeout (seconds) - max_retries: int = 3, # Maximum retry count - retry_interval: int = 5 # Retry interval (seconds) + interval: int = 30, + timeout: int = 10, + max_retries: int = 3, + retry_interval: int = 5 ): self.interval = interval self.timeout = timeout @@ -29,546 +20,69 @@ def __init__( self.retry_interval = retry_interval class RegistryClient: + """ + RegistryClient - Now disconnected from external servers. + Most methods are no-ops or return dummy data. + """ def __init__(self, heartbeat_config: HeartbeatConfig = None): - config = Config.from_env() - self.server_url = config.get("REGISTRY_SERVICE_URL") - # Convert HTTP URL to WebSocket URL - self.ws_url = self.server_url.replace('http://', 'ws://').replace('https://', 'wss://') - # Store all active WebSocket connections + # We don't need server URL anymore + self.server_url = None + self.ws_url = None self.active_connections = {} - # Heartbeat configuration self.heartbeat_config = heartbeat_config or HeartbeatConfig() def _get_auth_header(self): - """ - Get the Authorization header for authenticated requests - - Returns: - dict: Authorization header or empty dict if no credentials - """ - current_load, error, _ = LoadService.get_current_load(with_password=True) - if not current_load or not current_load.instance_id or not current_load.instance_password: - logger.info("No credentials found for auth") - return {} - instance_id = current_load.instance_id - instance_password = current_load.instance_password - - logger.info(f"Using credentials for auth: {instance_id}:{instance_password}") - return { - "Authorization": f"Bearer {instance_id}:{instance_password}" - } + return {} def get_ws_url(self, instance_id: str, instance_password: str) -> str: - """ - Generate WebSocket URL for the specified instance - - Args: - instance_id: Instance ID - instance_password: Instance password - - Returns: - str: WebSocket URL - """ - return f"{self.ws_url}/api/ws/{instance_id}?password={instance_password}" + return "" def register_upload(self, upload_name: str, instance_id: str = None, description: str = None, email: str = None, tags: TrainingTags = None): - """ - Register Upload instance with the registry center - - Args: - upload_name: Upload name - instance_id: Instance ID (optional) - description: Description (optional) - email: User email (optional) - - Returns: - Registration data - """ - headers = self._get_auth_header() - tags_dict = tags.model_dump() if tags else None - response = requests.post( - f"{self.server_url}/api/upload/register", - headers=headers, - json={ - "upload_name": upload_name, - "instance_id": instance_id, - "description": description, - "email": email, - "tags": tags_dict - } - ) - return ResponseHandler.handle_response( - response, - success_log=f"Upload {upload_name} registered successfully in registry center, instance ID: {instance_id}", - error_prefix="Registration" - ) + logger.info(f"Mocking register_upload for {upload_name}") + return { + "instance_id": instance_id or "local_instance", + "upload_name": upload_name + } def unregister_upload(self, instance_id: str): - """Unregister Upload instance from registry center - - Args: - instance_id: Instance ID - - Returns: - dict: Unregistration result - """ - headers = self._get_auth_header() - response = requests.delete( - f"{self.server_url}/api/upload/{instance_id}", - headers=headers - ) - return ResponseHandler.handle_response( - response, - success_log=f"Upload instance {instance_id} unregistered successfully from registry center", - error_prefix="Unregistration" - ) + logger.info(f"Mocking unregister_upload for {instance_id}") + return {"status": "success"} async def connect_websocket(self, instance_id: str, instance_password: str): - """Connect to registry center WebSocket and start keep-alive - - Args: - instance_id: Instance ID - instance_password: Instance password - - Returns: - websockets.WebSocketClientProtocol: WebSocket connection - """ - # Check if connection already exists and is active - connection_key = f"{instance_id}" - if connection_key in self.active_connections: - existing_ws = self.active_connections[connection_key] - try: - # Check if connection is still active and send heartbeat - if await self.send_heartbeat(existing_ws): - logger.info(f"Using existing WebSocket connection: {connection_key}") - return existing_ws - raise Exception("Heartbeat failed") - except Exception: - # If heartbeat fails, connection is disconnected, remove from active connections - logger.warning(f"Existing WebSocket connection is disconnected, creating new connection: {connection_key}") - del self.active_connections[connection_key] - - # Create new connection - ws_uri = self.get_ws_url(instance_id, instance_password) - try: - logger.info(f"Connecting to WebSocket: {ws_uri}") - websocket = await websockets.connect(ws_uri) - logger.info(f"WebSocket connection established: {ws_uri}") - - # Add additional attributes to WebSocket connection - websocket.instance_id = instance_id - websocket.connection_key = connection_key - - # Store new connection - self.active_connections[connection_key] = websocket - - # Add lock to prevent concurrent message reception - websocket.recv_lock = asyncio.Lock() - - # Start heartbeat task - websocket.heartbeat_task = asyncio.create_task( - self._keep_alive_with_ping(websocket, instance_id), - name=f"heartbeat_{connection_key}" - ) - await self.handle_messages(websocket) - - return websocket - except Exception as e: - logger.error(f"WebSocket connection failed: {str(e)}", exc_info=True) - raise + logger.info("Mocking connect_websocket - doing nothing") + return None async def _keep_alive(self, websocket, instance_id: str): - """Keep WebSocket connection alive - - Args: - websocket: WebSocket connection - instance_id: Instance ID - """ - connection_key = f"{instance_id}" - logger.info(f"Starting heartbeat task: {connection_key}") - - retry_count = 0 - last_success_time = time.time() - - try: - while True: - try: - # Send heartbeat at configured interval - await asyncio.sleep(self.heartbeat_config.interval) - - # Check last successful heartbeat time - if time.time() - last_success_time > self.heartbeat_config.interval * 2: - logger.warning(f"Upload (ID: {instance_id}) heartbeat timeout") - raise websockets.exceptions.ConnectionClosed(1006, "Heartbeat timeout") - - success = await self.send_heartbeat(websocket) - if success: - retry_count = 0 # Reset retry count - last_success_time = time.time() - # logger.info(f"Upload (ID: {instance_id}) heartbeat sent") - else: - retry_count += 1 - if retry_count >= self.heartbeat_config.max_retries: - logger.error(f"Upload (ID: {instance_id}) heartbeat retry count exceeded") - raise websockets.exceptions.ConnectionClosed(1006, "Heartbeat retry count exceeded") - logger.warning(f"Upload (ID: {instance_id}) heartbeat send failed, retrying {retry_count} times") - await asyncio.sleep(self.heartbeat_config.retry_interval) - continue - - except websockets.exceptions.ConnectionClosed as e: - logger.warning(f"Upload (ID: {instance_id}) WebSocket connection closed: {str(e)}") - # Clean up connection - if connection_key in self.active_connections: - del self.active_connections[connection_key] - # Cancel related tasks - if hasattr(websocket, 'message_task'): - websocket.message_task.cancel() - break - - except Exception as e: - logger.error(f"Upload (ID: {instance_id}) send heartbeat failed: {str(e)}", exc_info=True) - retry_count += 1 - if retry_count >= self.heartbeat_config.max_retries: - logger.error(f"Upload (ID: {instance_id}) heartbeat retry count exceeded") - raise - await asyncio.sleep(self.heartbeat_config.retry_interval) - - except asyncio.CancelledError: - logger.info(f"Heartbeat task cancelled: {connection_key}") - raise - except Exception as e: - logger.error(f"Upload (ID: {instance_id}) keep alive task failed: {str(e)}") - # Clean up connection - if connection_key in self.active_connections: - del self.active_connections[connection_key] - raise + pass async def _keep_alive_with_ping(self, websocket, instance_id: str): - """Keep WebSocket connection alive using native ping/pong - - Args: - websocket: WebSocket connection - instance_id: Instance ID - """ - connection_key = f"{instance_id}" - logger.info(f"Starting ping task: {connection_key}") - - try: - while True: - try: - await asyncio.sleep(self.heartbeat_config.interval) - await websocket.ping() - # logger.debug(f"Ping sent successfully for {instance_id}") - - except websockets.exceptions.ConnectionClosed as e: - logger.warning(f"Upload (ID: {instance_id}) WebSocket connection closed: {str(e)}") - if connection_key in self.active_connections: - del self.active_connections[connection_key] - if hasattr(websocket, 'message_task'): - websocket.message_task.cancel() - break - - except Exception as e: - logger.error(f"Upload (ID: {instance_id}) ping failed: {str(e)}") - if connection_key in self.active_connections: - del self.active_connections[connection_key] - raise - - except asyncio.CancelledError: - logger.info(f"Ping task cancelled: {connection_key}") - raise - except Exception as e: - logger.error(f"Upload (ID: {instance_id}) keep alive task failed: {str(e)}") - if connection_key in self.active_connections: - del self.active_connections[connection_key] - raise + pass async def send_heartbeat(self, websocket): - """Send heartbeat message - - Args: - websocket: WebSocket connection - - Returns: - bool: Whether heartbeat was sent successfully - """ - try: - heartbeat_message = json.dumps({ - "type": "heartbeat", - "data": { - "timestamp": int(time.time()), - "instance_id": websocket.instance_id if hasattr(websocket, 'instance_id') else 'unknown', - "status": "alive" - }, - "version": "1.0" - }) - # logger.info(f"Preparing to send heartbeat message: {heartbeat_message}") - - # Set send timeout - async with asyncio.timeout(self.heartbeat_config.timeout): - await websocket.send(heartbeat_message) - # logger.info("Heartbeat message sent successfully") - return True - - except asyncio.TimeoutError: - logger.error("Sending heartbeat message timed out") - return False - except Exception as e: - logger.error(f"Sending heartbeat failed: {str(e)}", exc_info=True) - return False + return True async def handle_messages(self, websocket): - """Handle received WebSocket messages""" - try: - while True: - try: - # Use lock to ensure only one coroutine calls recv at a time - async with websocket.recv_lock: - message = await websocket.recv() - data = json.loads(message) - message_type = data.get("type") - - if message_type == "heartbeat_ack": - continue - elif message_type == "chat": - # Handle chat request - try: - request_data = data.get("request", {}) - logger.info(f"[Request details: {json.dumps(request_data, ensure_ascii=False)}") - - # Call chat interface - async with aiohttp.ClientSession() as session: - logger.info(f"Preparing to send request to chat interface") - config = Config.from_env() - kernel2_url = f"{config.KERNEL2_SERVICE_URL}/api/kernel2/chat" - async with session.post( - kernel2_url, - json=request_data, - headers={ - "Content-Type": "application/json", - "Accept": "text/event-stream", # Specify to accept SSE response - "Cache-Control": "no-cache", - "Connection": "keep-alive" - }, - timeout=aiohttp.ClientTimeout(total=None), # Disable timeout - chunked=True # Enable chunked transfer - ) as response: - # Check response status - logger.info(f"Response status code: {response.status}") - if response.status != 200: - error_text = await response.text() - logger.error(f"[request_id: {data.get('request_id')}] Failed to call chat interface: {error_text}") - await websocket.send(json.dumps({ - "type": "chat_response", - "request_id": data.get("request_id"), - "error": f"Failed to call chat interface: {error_text}" - })) - continue - - logger.debug(f"Starting to read streaming response") - message_count = 0 - - # Direct forwarding of streaming response - async for line in response.content: - if line: - try: - # Convert bytes to string - decoded_line = line.decode('utf-8') - - logger.debug(f"[request_id: {data.get('request_id')}] Received raw data: {decoded_line.strip()}") - - # Check if it's SSE format data - if decoded_line.startswith("data: "): - message_count += 1 - data_content = decoded_line[6:].strip() - # logger.info(f"[request_id: {data.get('request_id')}] Processing message {message_count}") - - # Check if it's a completion marker - if data_content == "[DONE]": - - logger.info(f"[request_id: {data.get('request_id')}] Received completion marker, processed {message_count} messages in total") - await websocket.send(json.dumps({ - "type": "chat_response", - "request_id": data.get("request_id"), - "done": True - })) - continue - - # Directly forward original SSE data - await websocket.send(json.dumps({ - "type": "chat_response", - "request_id": data.get("request_id"), - "raw_sse": data_content, # Contains original SSE data - "done": False - })) - logger.debug(f"[requestId: {data.get('request_id')}] Forwarded SSE message #{message_count}") - except UnicodeDecodeError as e: - logger.error(f"[requestId: {data.get('request_id')}] Failed to decode response data: {str(e)}") - except Exception as e: - logger.error(f"[requestId: {data.get('request_id')}] Error processing response data: {str(e)}, type: {type(e).__name__}") - - except Exception as e: - logger.error(f"Failed to process chat request: {str(e)}") - await websocket.send(json.dumps({ - "type": "chat_response", - "request_id": data.get("request_id"), - "error": f"Error processing chat request: {str(e)}" - })) - else: - logger.debug(f"Received unknown message type: {message}") - except websockets.exceptions.ConnectionClosed: - logger.error("WebSocket connection closed") - break - except json.JSONDecodeError: - logger.error(f"Invalid JSON message: {message}") - except Exception as e: - logger.error(f"Failed to process message: {str(e)}") - except Exception as e: - logger.error(f"Message processing loop failed: {str(e)}") - raise + pass def list_uploads(self, page_no: int = 1, page_size: int = 10, status: Optional[List[str]] = None): - """Get list of registered Upload instances with pagination and status filter - - Args: - page_no (int): Page number, starting from 1 - page_size (int): Number of items per page - status (Optional[List[str]]): List of status to filter by - - Returns: - dict: Dictionary containing information about Upload instances - """ - # headers = self._get_auth_header() - params = { - "page_no": page_no, - "page_size": page_size + logger.info("Mocking list_uploads") + return { + "total": 0, + "items": [] } - if status: - params["status"] = status - - response = requests.get( - f"{self.server_url}/api/upload/list", - # headers=headers, - params=params - ) - return ResponseHandler.handle_response( - response, - error_prefix="Failed to retrieve list" - ) def count_uploads(self): - """Get count of all registered Upload instances - - Returns: - dict: Dictionary containing count of Upload instances - """ - response = requests.get( - f"{self.server_url}/api/upload/count", - ) - return ResponseHandler.handle_response( - response, - error_prefix="Failed to retrieve count" - ) + return {"count": 0} def get_upload_detail(self, instance_id: str) -> Dict: - """Get detailed information of an Upload instance - - Args: - instance_id (str): Instance ID of the Upload - - Returns: - dict: Dictionary containing instance information with the following fields: - upload_name (str): Name of the upload - instance_id (str): Instance ID - status (str): Current status of the upload - description (str, optional): Description of the upload - email (str, optional): Associated email address - registration_time (datetime): Time when the instance was registered - last_heartbeat (datetime, optional): Time of the last heartbeat - is_connected (bool, optional): Connection status, defaults to False - instance_password (str, optional): Password for instance registration - """ - headers = self._get_auth_header() - response = requests.get( - f"{self.server_url}/api/upload/{instance_id}", - headers=headers - ) - return ResponseHandler.handle_response( - response, - error_prefix="Failed to retrieve upload details" - ) + logger.info(f"Mocking get_upload_detail for {instance_id}") + return None def update_upload(self, instance_id: str, upload_name: str = None, capabilities: dict = None, email: str = None): - """Update Upload instance information in the registry center - - Args: - instance_id: Instance ID - upload_name: New upload name (optional) - capabilities: New capability set (optional) - email: New user email (optional) - - Returns: - dict: Update result - """ - update_data = {} - if upload_name is not None: - update_data["upload_name"] = upload_name - if capabilities is not None: - update_data["capabilities"] = capabilities - if email is not None: - update_data["email"] = email - - if not update_data: - logger.warning("No update data provided for update_upload") - return {"message": "No update data provided"} - - headers = self._get_auth_header() - response = requests.put( - f"{self.server_url}/api/upload/{instance_id}", - headers=headers, - json=update_data - ) - return ResponseHandler.handle_response( - response, - success_log=f"Upload instance {instance_id} updated successfully", - error_prefix="Update" - ) + logger.info(f"Mocking update_upload for {instance_id}") + return {"status": "success"} def create_role(self, role_id, name, description, system_prompt, icon, instance_id, is_active=True, enable_l0_retrieval=True, enable_l1_retrieval=True): - """Create a new role in the registry center - - Args: - role_id: Role UUID - name: Role name - description: Role description - system_prompt: System prompt - icon: Icon URL - instance_id: Instance ID - enable_l0_retrieval: Enable L0 retrieval - enable_l1_retrieval: Enable L1 retrieval - - Returns: - dict: Created role data - """ - headers = self._get_auth_header() - response = requests.post( - f"{self.server_url}/api/roles", - headers=headers, - json={ - "role_id": role_id, - "instance_id": instance_id, - "name": name, - "description": description, - "system_prompt": system_prompt, - "is_active": is_active, - "icon": icon, - "enable_l0_retrieval": enable_l0_retrieval, - "enable_l1_retrieval": enable_l1_retrieval - } - ) - return ResponseHandler.handle_response( - response, - success_log=f"Role {name} created successfully in registry center", - error_prefix="Role creation" - ) \ No newline at end of file + logger.info(f"Mocking create_role for {name}") + return {"status": "success"} diff --git a/lpm_kernel/api/domains/upload/routes.py b/lpm_kernel/api/domains/upload/routes.py index f3840138..b17a532a 100644 --- a/lpm_kernel/api/domains/upload/routes.py +++ b/lpm_kernel/api/domains/upload/routes.py @@ -16,45 +16,31 @@ from lpm_kernel.api.domains.upload.TrainingTags import TrainingTags upload_bp = Blueprint("upload", __name__) +# Registry client is now disabled/dummy registry_client = RegistryClient() logger = logging.getLogger(__name__) @upload_bp.route("/api/upload/register", methods=["POST"]) def register_upload(): - """Register upload instance""" + """Register upload instance - LOCAL ONLY""" try: current_load, error, status_code = LoadService.get_current_load() + if error: + return jsonify(APIResponse.error(code=status_code, message=error)) - upload_name = current_load.name - instance_id = current_load.instance_id - email = current_load.email - description = current_load.description - params = TrainingParamsManager.get_latest_training_params() - model_name = params.get("model_name") - is_cot = params.get("is_cot") - document_count = len(document_service.list_documents()) - tags = TrainingTags( - model_name=model_name, - is_cot=is_cot, - document_count=document_count - ) - - result = registry_client.register_upload( - upload_name, instance_id, description, email, tags - ) + # Simulating a successful registration locally without external calls + result = { + "instance_id": current_load.instance_id, + "upload_name": current_load.name, + "status": "registered (local)" + } - instance_id_new = result.get("instance_id") - if not instance_id_new: - return jsonify(APIResponse.error( - code=400, message="Failed to register upload instance" - )) + # We don't generate a new password or update credentials from remote - instance_password = result.get("instance_password") - LoadService.update_instance_credentials(instance_id_new, instance_password) - return jsonify(APIResponse.success( - data=result + data=result, + message="Upload registered locally (remote disabled)" )) except Exception as e: @@ -66,52 +52,24 @@ def register_upload(): @upload_bp.route("/api/upload/connect", methods=["POST"]) async def connect_upload(): """ - Establish WebSocket connection for the specified Upload instance - - URL parameters: - instance_id: Instance ID - upload_name: Upload name - - Returns: - { - "code": int, - "message": str, - "data": { - "ws_url": str # WebSocket connection URL - } - } + Establish WebSocket connection for the specified Upload instance - DISABLED """ - try: - logger.info("Starting WebSocket connection process...") + logger.info("WebSocket connection requested but remote connection is disabled.") current_load, error, status_code = LoadService.get_current_load(with_password=True) if error: return jsonify(APIResponse.error( code=status_code, message=error )) - instance_id = current_load.instance_id - instance_password = current_load.instance_password - - - - # Use thread to establish WebSocket connection asynchronously - def connect_ws(): - asyncio.run(registry_client.connect_websocket(instance_id, instance_password)) - - - thread = threading.Thread(target=connect_ws) - thread.daemon = True # Set as daemon thread, so it will end automatically when main program exits - thread.start() - result = { - "instance_id": instance_id, + "instance_id": current_load.instance_id, "upload_name": current_load.name } return jsonify(APIResponse.success( data=result, - message="WebSocket connection task started" + message="WebSocket connection disabled (local mode only)" )) except Exception as e: @@ -120,32 +78,13 @@ def connect_ws(): message=f"Failed to establish WebSocket connection: {str(e)}", code=500 )) - finally: - logger.info("WebSocket connection process completed.") @upload_bp.route("/api/upload/status", methods=["GET"]) def get_upload_status(): """ - Get the status of the specified Upload instance - - Returns: - { - "code": int, - "message": str, - "data": { - "upload_name": str, - "instance_id": str, - "status": str, - "description": str, - "email": str, - "is_connected": bool, - "last_ws_check": str, - "connection_alive": bool - } - } + Get the status of the specified Upload instance - LOCAL ONLY """ try: - current_load, error, status_code = LoadService.get_current_load() if error: return jsonify(APIResponse.error( @@ -153,41 +92,22 @@ def get_upload_status(): )) instance_id = current_load.instance_id - - # Check if instance exists - detail = registry_client.get_upload_detail(instance_id) - - logger.info(f"Upload status: {detail}") - # Get basic information from local + # Only return local data upload_data = { "upload_name": current_load.name, "instance_id": instance_id, "description": current_load.description, - "email": current_load.email + "email": current_load.email, + "status": "offline", # Always offline as no remote connection + "last_heartbeat": None, + "is_connected": False, + "last_ws_check": None } - # Process remote data, provide default values if null - if detail: - # Merge remote data - upload_data.update({ - "status": "online" if detail.get("is_connected") else "offline", - "last_heartbeat": detail.get("last_heartbeat"), - "is_connected": detail.get("is_connected", False), - "last_ws_check": detail.get("last_ws_check") - }) - else: - # Provide default values - upload_data.update({ - "status": "unregistered", - "last_heartbeat": None, - "is_connected": False, - "last_ws_check": None - }) - return jsonify(APIResponse.success( data=upload_data, - message="Successfully retrieved Upload instance status" + message="Successfully retrieved Upload instance status (local)" )) except Exception as e: @@ -200,33 +120,20 @@ def get_upload_status(): @upload_bp.route("/api/upload", methods=["DELETE"]) def unregister_upload(): """ - API for unregistering Upload instance - - URL parameters: - instance_id: Instance ID - upload_name: Upload name - - Returns: - { - "code": int, - "message": str, - "data": { - "instance_id": str, - "upload_name": str - } - } + API for unregistering Upload instance - LOCAL ONLY """ try: current_load, error, status_code = LoadService.get_current_load() instance_id = current_load.instance_id - registry_client.unregister_upload(instance_id) + + # No remote call to unregister return jsonify(APIResponse.success( data={ "instance_id": instance_id, "upload_name": current_load.name }, - message="Upload instance unregistered successfully" + message="Upload instance unregistered locally" )) except Exception as e: @@ -239,41 +146,21 @@ def unregister_upload(): @upload_bp.route("/api/upload", methods=["GET"]) def list_uploads(): """ - List registered Upload instances with pagination and status filter - - Query Parameters: - page_no (int): Page number, starting from 1 - page_size (int): Number of items per page - status (List[str], optional): List of status to filter by - - Returns: - { - "code": int, - "message": str, - "data": { - "instance_id": { - "upload_name": str, - "description": str, - "email": str, - "status": str - } - } - } + List registered Upload instances - LOCAL ONLY (Mock/Empty) """ try: - page_no = request.args.get("page_no", 1, type=int) - page_size = request.args.get("page_size", 10, type=int) - status = request.args.getlist("status") + # Return empty list or just current load if we want to simulate + # But usually this endpoint lists ALL uploads from registry. + # Since we are disconnected, we return empty list or just valid structure. - result = registry_client.list_uploads( - page_no=page_no, - page_size=page_size, - status=status if status else None - ) + result = { + "total": 0, + "items": [] + } return jsonify(APIResponse.success( data=result, - message="Successfully retrieved Upload list" + message="Successfully retrieved Upload list (local empty)" )) except Exception as e: @@ -286,23 +173,14 @@ def list_uploads(): @upload_bp.route("/api/upload/count", methods=["GET"]) def count_uploads(): """ - Get the number of registered Upload instances - - Returns: - { - "code": int, - "message": str, - "data": { - "count": int - } - } + Get the number of registered Upload instances - LOCAL ONLY """ try: - result = registry_client.count_uploads() + result = {"count": 0} return jsonify(APIResponse.success( data=result, - message="Successfully retrieved Upload count" + message="Successfully retrieved Upload count (local)" )) except Exception as e: @@ -315,33 +193,13 @@ def count_uploads(): @upload_bp.route("/api/upload", methods=["PUT"]) def update_upload(): """ - API for updating Upload instance information - - URL parameters: - instance_id: Instance ID - - Request body: - { - "upload_name": str (optional), - "description": str (optional), - "email": str (optional) - } - - Returns: - { - "code": int, - "message": str, - "data": { - "instance_id": str, - "upload_name": str, - "description": str, - "email": str, - "status": str - } - } + API for updating Upload instance information - LOCAL ONLY """ try: current_load, error, status_code = LoadService.get_current_load() + if error: + return jsonify(APIResponse.error(code=status_code, message=error)) + instance_id = current_load.instance_id data = request.get_json() @@ -351,26 +209,18 @@ def update_upload(): code=400 )) - upload_name = data.get("upload_name") - description = data.get("description") - email = data.get("email") + # We could technically update local Load info here if we wanted to support local updates via this API + # But this API was meant for Registry updates. + # For now, just return success. - # At least one update field is required - if upload_name is None and description is None and email is None: - return jsonify(APIResponse.error( - message="At least one update field is required: upload_name, description, or email", - code=400 - )) + result = { + "instance_id": instance_id, + "status": "updated (local)" + } - result = registry_client.update_upload( - instance_id=instance_id, - upload_name=upload_name, - description=description, - email=email - ) return jsonify(APIResponse.success( data=result, - message="Upload instance updated successfully" + message="Upload instance updated successfully (local)" )) diff --git a/lpm_kernel/api/services/expert_llm_service.py b/lpm_kernel/api/services/expert_llm_service.py index 0574a820..981916f2 100644 --- a/lpm_kernel/api/services/expert_llm_service.py +++ b/lpm_kernel/api/services/expert_llm_service.py @@ -3,9 +3,10 @@ """ import logging from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union from openai import OpenAI from lpm_kernel.configs.config import Config +from lpm_kernel.common.gemini_client import GeminiClient logger = logging.getLogger(__name__) @@ -29,14 +30,23 @@ def __init__(self): # return self._config @property - def client(self) -> OpenAI: - """Get the OpenAI client for expert LLM""" + def client(self) -> Union[OpenAI, GeminiClient]: + """Get the LLM client (OpenAI or Gemini) for expert LLM""" if self._client is None: self.user_llm_config = self.user_llm_config_service.get_available_llm() - self._client = OpenAI( - api_key=self.user_llm_config.chat_api_key, - base_url=self.user_llm_config.chat_endpoint, - ) + + if self.user_llm_config.provider_type == 'gemini': + logger.info("Initializing Gemini client for Expert LLM service") + self._client = GeminiClient( + api_key=self.user_llm_config.chat_api_key, + base_url=self.user_llm_config.chat_endpoint + ) + else: + logger.info("Initializing OpenAI client for Expert LLM service") + self._client = OpenAI( + api_key=self.user_llm_config.chat_api_key, + base_url=self.user_llm_config.chat_endpoint, + ) return self._client def get_model_params(self) -> Dict[str, Any]: diff --git a/lpm_kernel/common/gemini_client.py b/lpm_kernel/common/gemini_client.py new file mode 100644 index 00000000..566150e9 --- /dev/null +++ b/lpm_kernel/common/gemini_client.py @@ -0,0 +1,220 @@ +import google.generativeai as genai +import logging +import time +import uuid +from typing import List, Dict, Any, Optional, Union, Iterator + +logger = logging.getLogger(__name__) + +class GeminiClient: + """ + Adapter for Google Gemini API to mimic OpenAI client interface. + """ + def __init__(self, api_key: str, base_url: Optional[str] = None): + if not api_key: + raise ValueError("API key is required for GeminiClient") + genai.configure(api_key=api_key) + self.base_url = base_url or "https://generativelanguage.googleapis.com" + self.chat = self.Chat(self) + + class Chat: + def __init__(self, client): + self.client = client + self.completions = self.Completions(client) + + class Completions: + def __init__(self, client): + self.client = client + + def create(self, **kwargs) -> Any: + return self.client._create_completion(**kwargs) + + def _convert_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, Any]]: + """Convert OpenAI messages to Gemini history format.""" + gemini_history = [] + system_instruction = None + + # Extract system prompt if present (Gemini supports system_instruction at model init) + # However, generate_content doesn't support system_instruction per call easily unless we use beta or specific models. + # Standard approach: Merge system prompt into first user message or use system_instruction if model supports it. + # For simplicity and broad support, we can prepend system prompt. + # UPDATE: Gemini 1.5 Pro/Flash supports system_instruction. + + # Let's separate system messages + system_messages = [m for m in messages if m.get("role") == "system"] + if system_messages: + system_instruction = "\n".join([m.get("content", "") for m in system_messages]) + + # Process user/assistant messages + # Gemini expects 'user' role as 'user' and 'assistant' as 'model'. + for msg in messages: + role = msg.get("role") + content = msg.get("content") + + if role == "system": + continue # Handled separately + + if role == "user": + gemini_history.append({"role": "user", "parts": [content]}) + elif role == "assistant": + gemini_history.append({"role": "model", "parts": [content]}) + + return gemini_history, system_instruction + + def _create_completion( + self, + messages: List[Dict[str, str]], + model: str, + stream: bool = False, + temperature: float = 0.7, + max_tokens: Optional[int] = None, + **kwargs + ) -> Any: + try: + # Handle model name (remove 'models/' prefix if present twice or ensure it matches Gemini format) + # Gemini models are usually "models/gemini-1.5-flash" or just "gemini-1.5-flash" + # If user provides "models/lpm", we might need to fallback or trust config. + # Assuming config provides valid Gemini model name. + + # Convert messages + history, system_instruction = self._convert_messages(messages) + + # Configure model + generation_config = genai.types.GenerationConfig( + temperature=temperature, + max_output_tokens=max_tokens + ) + + gemini_model = genai.GenerativeModel( + model_name=model, + system_instruction=system_instruction, + generation_config=generation_config + ) + + # Prepare chat or content generation + # If history is empty (only system prompt?), sending empty content might fail. + # If history has only one user message, use generate_content. + # If history has multiple, use start_chat. + + if not history: + # Should not happen in valid chat + raise ValueError("No user/model messages provided") + + last_message = history[-1] + if last_message["role"] != "user": + # OpenAI allows last message to be assistant (to continue?), Gemini expects user prompt last? + # Actually generate_content takes content. + # If we use chat, we need history + current message. + pass + + # We will use start_chat for history support + # Pop the last message as the new prompt + if history and history[-1]["role"] == "user": + prompt = history[-1]["parts"][0] + chat_history = history[:-1] + else: + # Fallback if last message is not user (e.g. continue generation? Not fully supported here) + prompt = " " # Empty prompt? + chat_history = history + + chat_session = gemini_model.start_chat(history=chat_history) + + response = chat_session.send_message(prompt, stream=stream) + + if stream: + return self._stream_response_adapter(response, model) + else: + return self._response_adapter(response, model) + + except Exception as e: + logger.error(f"Gemini API error: {str(e)}") + raise + + def _response_adapter(self, response, model): + """Adapt Gemini response to OpenAI format object.""" + # Wait for response completion + try: + text = response.text + except ValueError: + # Blocked content? + text = "" + if response.prompt_feedback: + logger.warning(f"Gemini prompt feedback: {response.prompt_feedback}") + + return OpenAIResponse({ + "id": f"chatcmpl-{uuid.uuid4()}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [{ + "index": 0, + "message": OpenAIMessage({ + "role": "assistant", + "content": text + }), + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 0, # Not easily available + "completion_tokens": 0, + "total_tokens": 0 + } + }) + + def _stream_response_adapter(self, response_iterator, model): + """Yield OpenAI-format chunks from Gemini stream.""" + response_id = f"chatcmpl-{uuid.uuid4()}" + created = int(time.time()) + + for chunk in response_iterator: + text = "" + try: + text = chunk.text + except ValueError: + continue + + yield OpenAIResponse({ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": { + "content": text + }, + "finish_reason": None + }] + }) + + # Yield finish reason + yield OpenAIResponse({ + "id": response_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }] + }) + +class OpenAIResponse(dict): + """Helper to allow dot access to dictionary.""" + def __init__(self, data): + super().__init__(data) + for k, v in data.items(): + if isinstance(v, dict): + self[k] = OpenAIResponse(v) + elif isinstance(v, list): + self[k] = [OpenAIResponse(i) if isinstance(i, dict) else i for i in v] + + def __getattr__(self, key): + try: + return self[key] + except KeyError: + raise AttributeError(key) + +class OpenAIMessage(OpenAIResponse): + pass diff --git a/lpm_kernel/common/strategy/classification.py b/lpm_kernel/common/strategy/classification.py index 415a34f6..f21f56ea 100644 --- a/lpm_kernel/common/strategy/classification.py +++ b/lpm_kernel/common/strategy/classification.py @@ -4,9 +4,13 @@ from typing import Optional import lpm_kernel.common.strategy.strategy_openai as openai import lpm_kernel.common.strategy.strategy_huggingface as huggingface +import lpm_kernel.common.strategy.strategy_gemini as gemini def strategy_classification(user_llm_config: Optional[UserLLMConfigDTO], chunked_texts): - if "sentence-transformers" in user_llm_config.embedding_endpoint: + if user_llm_config.provider_type == "gemini": + # Using Gemini strategy to generate embedding vectors + return gemini.gemini_strategy(user_llm_config, chunked_texts) + elif user_llm_config.embedding_endpoint and "sentence-transformers" in user_llm_config.embedding_endpoint: # Using Hugging Face strategy to generate embedding vectors return huggingface.huggingface_strategy(user_llm_config, chunked_texts) else: diff --git a/lpm_kernel/common/strategy/strategy_gemini.py b/lpm_kernel/common/strategy/strategy_gemini.py new file mode 100644 index 00000000..49705d2c --- /dev/null +++ b/lpm_kernel/common/strategy/strategy_gemini.py @@ -0,0 +1,70 @@ +from lpm_kernel.api.dto.user_llm_config_dto import UserLLMConfigDTO +from lpm_kernel.configs.logging import get_train_process_logger +import google.generativeai as genai +import os +import numpy as np +from typing import List, Union + +logger = get_train_process_logger() + +def gemini_strategy(user_llm_config: UserLLMConfigDTO, chunked_texts: List[str]) -> np.ndarray: + """ + Generate embeddings using Google Gemini API + + Args: + user_llm_config: User LLM Configuration + chunked_texts: List of text chunks to embed + + Returns: + numpy.ndarray: Array of embeddings + """ + try: + # Get API key from env var (preferred) or config + api_key = os.getenv("GEMINI_API_KEY") + if not api_key: + # Fallback to config if available, though requirements emphasize env var + api_key = user_llm_config.embedding_api_key or user_llm_config.key + + if not api_key: + raise ValueError("GEMINI_API_KEY environment variable not set") + + # Configure Gemini + genai.configure(api_key=api_key) + + # Get model name, default to text-embedding-004 if not specified + model_name = user_llm_config.embedding_model_name + if not model_name: + model_name = "models/text-embedding-004" + elif not model_name.startswith("models/"): + model_name = f"models/{model_name}" + + embeddings_list = [] + + # Iterate over chunks to get embedding for each + for text in chunked_texts: + try: + result = genai.embed_content( + model=model_name, + content=text, + task_type="retrieval_document", + title=None, + output_dimensionality=512 # Set default dimension to 512 as requested + ) + + if 'embedding' in result: + embeddings_list.append(result['embedding']) + else: + logger.warning(f"No embedding found for chunk: {text[:50]}...") + # Fallback or error? For now, if one fails, maybe we should raise + # But keeping consistent size is important. + # If empty, maybe add a zero vector? Or raise. + raise ValueError(f"Unexpected response from Gemini API: {result}") + except Exception as e: + logger.error(f"Error embedding chunk '{text[:50]}...': {str(e)}") + raise + + return np.array(embeddings_list) + + except Exception as e: + logger.error(f"Error generating Gemini embeddings: {str(e)}") + raise diff --git a/lpm_kernel/file_data/chunker.py b/lpm_kernel/file_data/chunker.py index 83010535..3c2a79d5 100644 --- a/lpm_kernel/file_data/chunker.py +++ b/lpm_kernel/file_data/chunker.py @@ -2,7 +2,7 @@ from lpm_kernel.L1.bio import Chunk import traceback import time -from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_text_splitters import RecursiveCharacterTextSplitter from lpm_kernel.configs.logging import get_train_process_logger logger = get_train_process_logger() diff --git a/lpm_kernel/utils.py b/lpm_kernel/utils.py index 6932ce8b..ccfcb422 100644 --- a/lpm_kernel/utils.py +++ b/lpm_kernel/utils.py @@ -3,7 +3,7 @@ import tiktoken import re from typing import Any, Optional, Union, Collection, AbstractSet, Literal, List -from langchain.text_splitter import TextSplitter +from langchain_text_splitters import TextSplitter import random import string from itertools import chain diff --git a/mcp/mcp_public.py b/mcp/mcp_public.py index 1f3bfa2d..fd7e72df 100644 --- a/mcp/mcp_public.py +++ b/mcp/mcp_public.py @@ -5,7 +5,8 @@ import requests mindverse = FastMCP("mindverse_public") -url = "app.secondme.io" +# External communication disabled +url = "localhost" messages =[] @@ -13,88 +14,23 @@ async def get_response(query:str, instance_id:str) -> str | None: """ Received a response based on public secondme model. + (Disabled: Returns mock response) Args: query (str): Questions raised by users regarding the secondme model. instance_id (str): ID used to identify the secondme model, or url used to identify the secondme model. """ - id = instance_id.split('/')[-1] - path = f"/api/chat/{id}" - headers = {"Content-Type": "application/json"} - messages.append({"role": "user", "content": query}) - - data = { - "messages": messages, - "metadata": { - "enable_l0_retrieval": False, - "role_id": "default_role" - }, - "temperature": 0.7, - "max_tokens": 2000, - "stream": True - } - - conn = http.client.HTTPSConnection(url) - - # Send the POST request - conn.request("POST", path, body=json.dumps(data), headers=headers) - - # Get the response - response = conn.getresponse() - - full_content = "" - - for line in response: - if line: - decoded_line = line.decode('utf-8').strip() - if decoded_line == 'data: [DONE]': - break - if decoded_line.startswith('data: '): - try: - json_str = decoded_line[6:] - chunk = json.loads(json_str) - content = chunk['choices'][0]['delta'].get('content', '') - if content: - full_content += content - except json.JSONDecodeError: - pass - - conn.close() - if full_content: - messages.append({"role": "system", "content": full_content}) - return full_content - else: - return None + return "Communication with external Second Me servers is disabled." @mindverse.tool() async def get_online_instances(): """ Check which secondme models are available for chatting online. + (Disabled: Returns empty list) """ - url = "https://app.secondme.io/api/upload/list?page_size=100" - response = requests.get(url) - - if response.status_code == 200: - data = response.json() - items = data.get("data", {}).get("items", []) - - online_items = [ - { - "upload_name": item["upload_name"], - "instance_id": item["instance_id"], - "description": item["description"] - } - for item in items if item.get("status") == "online" - ] - - return json.dumps(online_items, ensure_ascii=False, indent=2) - else: - raise Exception(f"Request failed with status code: {response.status_code}") + return json.dumps([], ensure_ascii=False, indent=2) if __name__ == "__main__": mindverse.run(transport='stdio') - - - diff --git a/pyproject.toml b/pyproject.toml index 2de8e583..4290e6b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,8 @@ sentence-transformers = "^2.6.0" # Development environment dependencies # Use 'poetry install --with dev' to install development dependencies +google-generativeai = "^0.8.5" +langchain-text-splitters = "^0.3.3" [tool.poetry.group.dev.dependencies] pytest = "7.4.4" ruff = "0.1.15"