diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 9d8b705..888fd81 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -28,18 +28,19 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('python/requirements.txt') }} + key: ${{ runner.os }}-pip-${{ hashFiles('python/requirements-ci.txt') }} restore-keys: | ${{ runner.os }}-pip- - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e .[dev] + pip install -r python/requirements-ci.txt - name: Run tests - run: | - pytest -v --cov-report=xml --cov-report=term + run: pytest tests -v --cov-report=xml --cov-report=term --ignore=tests/test_llama_cpp_backend.py --ignore=tests/test_vllm_backend.py + env: + PYTHONPATH: ${{ github.workspace }}/python - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 @@ -60,7 +61,7 @@ jobs: - name: Install package and linting tools run: | - pip install -e .[dev] + pip install -r python/requirements-ci.txt - name: Run black (format check) run: | diff --git a/README.md b/README.md index 7b79282..966550a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ nMaintained by [JustInternetAI](https://github.com/JustInternetAI) ### Core - **Godot C++ Module**: Deterministic tick loop, event bus, navigation, sensors, stable replay logs - **Agent Runtime**: Adapters for llama.cpp, TensorRT-LLM, vLLM with function-calling tool API +- **Model Management**: Automated LLM model downloading from Hugging Face Hub with caching and verification - **Tool System**: World querying (vision rays, inventories), pathfinding, crafting actions via JSON schemas - **Memory & RAG**: Short-term scratchpad + long-term vector store with episode summaries - **Benchmark Scenes**: 3 sandbox environments (foraging, crafting chain, team capture) with metrics @@ -130,6 +131,26 @@ agent-arena/ See [docs/quickstart.md](docs/quickstart.md) for a tutorial on creating your first agent-driven scene. +### Model Management + +Agent Arena includes a built-in tool to download and manage LLM models from Hugging Face Hub: + +```bash +# Download a model for testing +cd python +python -m tools.model_manager download tinyllama-1.1b-chat --format gguf --quant q4_k_m + +# List available models in registry +python -m tools.model_manager info + +# List downloaded models +python -m tools.model_manager list +``` + +Supported models include TinyLlama (1.1B), Phi-2 (2.7B), Llama-2 (7B/13B), Mistral (7B), Llama-3 (8B), and Mixtral (8x7B). + +For detailed documentation on model management, see [docs/model_management.md](docs/model_management.md). 
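Once a model is cached, you can resolve its path in code instead of hard-coding it. The snippet below is a minimal, untested sketch that mirrors the pattern used in `python/test_agent_gpu.py`; it assumes you run it from the `python/` directory (so the `tools` and `backends` packages are importable) and that the model above has already been downloaded.

```python
# Sketch: resolve a cached model with ModelManager and run a single generation.
from backends import BackendConfig, LlamaCppBackend
from tools.model_manager import ModelManager

manager = ModelManager()  # locates models/ and configs/models.yaml from the project root
model_path = manager.get_model_path("tinyllama-1.1b-chat", format="gguf", quantization="q4_k_m")
if model_path is None:
    raise SystemExit("Model not cached yet - download it with the CLI shown above")

config = BackendConfig(
    model_path=str(model_path),
    temperature=0.7,
    max_tokens=64,
    n_gpu_layers=0,  # set to -1 to offload all layers to the GPU
)
backend = LlamaCppBackend(config)
print(backend.generate("Hello! My name is", max_tokens=16).text)
backend.unload()
```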
+ ## Development Roadmap - [ ] Phase 1: Core infrastructure (deterministic sim, event bus, basic tools) diff --git a/configs/models.yaml b/configs/models.yaml new file mode 100644 index 0000000..0d93c21 --- /dev/null +++ b/configs/models.yaml @@ -0,0 +1,159 @@ +# Model Registry for Agent Arena +# This file defines available models and their Hugging Face Hub sources + +models: + # Small models for development and testing + tinyllama-1.1b-chat: + huggingface_id: "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" + description: "Extremely fast, basic capabilities, great for testing" + size_class: "tiny" + formats: + gguf: + q4_k_m: + file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + sha256: null # Checksums can be added for verification + q5_k_m: + file: "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf" + sha256: null + q8_0: + file: "tinyllama-1.1b-chat-v1.0.Q8_0.gguf" + sha256: null + + phi-2: + huggingface_id: "TheBloke/phi-2-GGUF" + description: "Fast, good reasoning for 2.7B size, excellent for development" + size_class: "small" + formats: + gguf: + q4_k_m: + file: "phi-2.Q4_K_M.gguf" + sha256: null + q5_k_m: + file: "phi-2.Q5_K_M.gguf" + sha256: null + q8_0: + file: "phi-2.Q8_0.gguf" + sha256: null + + # Production-ready 7B models + llama-2-7b-chat: + huggingface_id: "TheBloke/Llama-2-7B-Chat-GGUF" + description: "Good balance of speed and quality, widely tested" + size_class: "medium" + formats: + gguf: + q4_k_m: + file: "llama-2-7b-chat.Q4_K_M.gguf" + sha256: null + q5_k_m: + file: "llama-2-7b-chat.Q5_K_M.gguf" + sha256: null + q8_0: + file: "llama-2-7b-chat.Q8_0.gguf" + sha256: null + + mistral-7b-instruct-v0.2: + huggingface_id: "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" + description: "High quality instruction following, fast inference" + size_class: "medium" + formats: + gguf: + q4_k_m: + file: "mistral-7b-instruct-v0.2.Q4_K_M.gguf" + sha256: null + q5_k_m: + file: "mistral-7b-instruct-v0.2.Q5_K_M.gguf" + sha256: null + q8_0: + file: "mistral-7b-instruct-v0.2.Q8_0.gguf" + sha256: null + + llama-3-8b-instruct: + huggingface_id: "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF" + description: "Latest Llama 3, best quality in 8B class" + size_class: "medium" + formats: + gguf: + q4_k_m: + file: "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf" + sha256: null + q5_k_m: + file: "Meta-Llama-3-8B-Instruct.Q5_K_M.gguf" + sha256: null + q8_0: + file: "Meta-Llama-3-8B-Instruct.Q8_0.gguf" + sha256: null + + # Larger models for high quality + llama-2-13b-chat: + huggingface_id: "TheBloke/Llama-2-13B-Chat-GGUF" + description: "Better reasoning and instruction following than 7B" + size_class: "large" + formats: + gguf: + q4_k_m: + file: "llama-2-13b-chat.Q4_K_M.gguf" + sha256: null + q5_k_m: + file: "llama-2-13b-chat.Q5_K_M.gguf" + sha256: null + + mixtral-8x7b-instruct: + huggingface_id: "TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF" + description: "Mixture of Experts, excellent quality, 47B total parameters" + size_class: "xlarge" + formats: + gguf: + q4_k_m: + file: "mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf" + sha256: null + q5_k_m: + file: "mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf" + sha256: null + +# Quantization guide +quantization_info: + q4_k_m: + description: "4-bit quantization, good balance of size and quality" + quality: "Medium" + speed: "Fast" + size_factor: 0.25 # Approx 1/4 of original size + + q5_k_m: + description: "5-bit quantization, better quality than Q4" + quality: "Medium-High" + speed: "Medium-Fast" + size_factor: 0.31 + + q8_0: + description: "8-bit quantization, near original quality" + quality: "High" + speed: 
"Medium" + size_factor: 0.50 + +# Size class reference (unquantized sizes) +size_classes: + tiny: + description: "< 2B parameters" + ram_required: "2-4 GB" + use_case: "Testing, rapid iteration" + + small: + description: "2-4B parameters" + ram_required: "4-8 GB" + use_case: "Development, basic tasks" + + medium: + description: "7-8B parameters" + ram_required: "8-16 GB" + use_case: "Production, general purpose" + + large: + description: "13-14B parameters" + ram_required: "16-32 GB" + use_case: "High quality tasks" + + xlarge: + description: "30B+ parameters" + ram_required: "32+ GB" + use_case: "Highest quality, research" diff --git a/docs/model_management.md b/docs/model_management.md new file mode 100644 index 0000000..8c56375 --- /dev/null +++ b/docs/model_management.md @@ -0,0 +1,396 @@ +# Model Management Guide + +This guide explains how to download, manage, and use LLM models with Agent Arena. + +## Overview + +Agent Arena includes a Model Manager tool that automates downloading and managing models from Hugging Face Hub. The tool supports: + +- **GGUF models** for llama.cpp backend (CPU and GPU) +- **PyTorch/safetensors models** for vLLM backend (GPU) +- **Automatic caching** to avoid re-downloading +- **Checksum verification** for model integrity +- **Multiple quantization levels** for size/quality tradeoffs + +## Quick Start + +### 1. Install Dependencies + +First, ensure you have the LLM dependencies installed: + +```bash +# Activate your virtual environment +cd python +venv\Scripts\activate # Windows +# source venv/bin/activate # Linux/Mac + +# Install LLM dependencies (includes huggingface-hub) +pip install -e ".[llm]" +``` + +**Note:** The model manager automatically finds the project root, so you can run commands from any directory within the project. + +### 2. Download a Model + +Download a model using the command-line interface: + +```bash +# Download a small model for testing (TinyLlama 1.1B) +python -m tools.model_manager download tinyllama-1.1b-chat --format gguf --quant q4_k_m + +# Download a production model (Mistral 7B) +python -m tools.model_manager download mistral-7b-instruct-v0.2 --format gguf --quant q4_k_m +``` + +### 3. List Downloaded Models + +```bash +python -m tools.model_manager list +``` + +### 4. 
Use the Model + +Update your backend configuration to point to the downloaded model: + +```yaml +# configs/backend/llama_cpp.yaml +backend: + type: llama_cpp + model_path: "models/mistral-7b-instruct-v0.2/gguf/q4_k_m/mistral-7b-instruct-v0.2.Q4_K_M.gguf" + n_ctx: 4096 + n_gpu_layers: 0 # Set to -1 for full GPU offload +``` + +## Available Models + +### Small Models (Development/Testing) + +| Model | Size | RAM Required | Use Case | +|-------|------|--------------|----------| +| `tinyllama-1.1b-chat` | 1.1B | 2-4 GB | Testing, rapid iteration | +| `phi-2` | 2.7B | 4-8 GB | Development, good reasoning | + +### Production Models (7-8B) + +| Model | Size | RAM Required | Description | +|-------|------|--------------|-------------| +| `llama-2-7b-chat` | 7B | 8-16 GB | Balanced speed/quality, widely tested | +| `mistral-7b-instruct-v0.2` | 7B | 8-16 GB | High quality instruction following | +| `llama-3-8b-instruct` | 8B | 8-16 GB | Latest Llama, best quality in class | + +### Large Models (High Quality) + +| Model | Size | RAM Required | Description | +|-------|------|--------------|-------------| +| `llama-2-13b-chat` | 13B | 16-32 GB | Better reasoning than 7B | +| `mixtral-8x7b-instruct` | 47B | 32+ GB | Mixture of Experts, excellent quality | + +## Quantization Levels + +Quantization reduces model size with minimal quality loss. Choose based on your needs: + +| Quantization | Quality | Speed | Size Factor | Recommended For | +|--------------|---------|-------|-------------|-----------------| +| `q4_k_m` | Medium | Fast | 25% | General use, good balance | +| `q5_k_m` | Medium-High | Medium-Fast | 31% | Better quality, still fast | +| `q8_0` | High | Medium | 50% | Near-original quality | + +**Example:** A 7B model unquantized is ~14GB. With Q4_K_M quantization it's ~3.8GB. 
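The size factors behind this table come from `quantization_info` in `configs/models.yaml` (0.25, 0.31, and 0.50 of the unquantized fp16 size, which is roughly 2 bytes per parameter). The following is a back-of-the-envelope sketch for estimating disk usage; real GGUF files run somewhat larger because some tensors are kept at higher precision.

```python
# Rough size estimate: fp16 is ~2 bytes per parameter; scale by the registry's size_factor.
SIZE_FACTORS = {"q4_k_m": 0.25, "q5_k_m": 0.31, "q8_0": 0.50}

def estimate_size_gb(params_billions: float, quant: str) -> float:
    fp16_gb = params_billions * 2  # ~2 bytes per parameter
    return fp16_gb * SIZE_FACTORS[quant]

for quant in SIZE_FACTORS:
    print(f"7B model, {quant}: ~{estimate_size_gb(7, quant):.1f} GB")
# q4_k_m: ~3.5 GB (the actual 7B Q4_K_M file is ~3.8 GB), q5_k_m: ~4.3 GB, q8_0: ~7.0 GB
```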
+
+## CLI Commands
+
+### Download a Model
+
+```bash
+python -m tools.model_manager download <model_id> [options]
+
+Options:
+  --format FORMAT       Model format (default: gguf)
+  --quant QUANTIZATION  Quantization type (e.g., q4_k_m, q5_k_m)
+  --force               Force re-download even if exists
+```
+
+**Examples:**
+
+```bash
+# Download default quantization
+python -m tools.model_manager download llama-2-7b-chat --quant q4_k_m
+
+# Download higher quality version
+python -m tools.model_manager download llama-2-7b-chat --quant q8_0
+
+# Force re-download
+python -m tools.model_manager download mistral-7b-instruct-v0.2 --quant q4_k_m --force
+```
+
+### List Cached Models
+
+```bash
+python -m tools.model_manager list [--format FORMAT]
+
+# List all models
+python -m tools.model_manager list
+
+# Filter by format
+python -m tools.model_manager list --format gguf
+```
+
+Output example:
+```
+Cached Models:
+================================================================================
+llama-2-7b-chat           gguf      /q4_k_m        3.83 GB
+mistral-7b-instruct-v0.2  gguf      /q5_k_m        5.13 GB
+================================================================================
+Total storage: 8.96 GB
+```
+
+### Verify a Model
+
+Check if a downloaded model is valid:
+
+```bash
+python -m tools.model_manager verify <model_id> [options]
+
+Options:
+  --format FORMAT       Model format (default: gguf)
+  --quant QUANTIZATION  Quantization type
+```
+
+**Example:**
+
+```bash
+python -m tools.model_manager verify llama-2-7b-chat --format gguf --quant q4_k_m
+```
+
+### Remove a Model
+
+```bash
+python -m tools.model_manager remove <model_id> [options]
+
+Options:
+  --format FORMAT       Remove specific format only
+  --quant QUANTIZATION  Remove specific quantization only
+```
+
+**Examples:**
+
+```bash
+# Remove all versions of a model
+python -m tools.model_manager remove llama-2-7b-chat
+
+# Remove specific quantization
+python -m tools.model_manager remove llama-2-7b-chat --format gguf --quant q4_k_m
+
+# Remove all GGUF versions
+python -m tools.model_manager remove llama-2-7b-chat --format gguf
+```
+
+### Show Model Information
+
+```bash
+python -m tools.model_manager info [model_id]
+
+# List all available models
+python -m tools.model_manager info
+
+# Show details for specific model
+python -m tools.model_manager info llama-2-7b-chat
+```
+
+## Model Storage Structure
+
+Models are cached in the `models/` directory with this structure:
+
+```
+models/
+├── llama-2-7b-chat/
+│   └── gguf/
+│       ├── q4_k_m/
+│       │   └── llama-2-7b-chat.Q4_K_M.gguf
+│       └── q5_k_m/
+│           └── llama-2-7b-chat.Q5_K_M.gguf
+├── mistral-7b-instruct-v0.2/
+│   └── gguf/
+│       └── q4_k_m/
+│           └── mistral-7b-instruct-v0.2.Q4_K_M.gguf
+└── tinyllama-1.1b-chat/
+    └── gguf/
+        └── q4_k_m/
+            └── tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf
+```
+
+## Adding Custom Models
+
+To add a custom model to the registry:
+
+1. Edit `configs/models.yaml`
+2. Add your model following this template:
+
+```yaml
+models:
+  your-model-name:
+    huggingface_id: "author/model-repo-name"
+    description: "Description of the model"
+    size_class: "medium"  # tiny, small, medium, large, xlarge
+    formats:
+      gguf:
+        q4_k_m:
+          file: "model-filename.Q4_K_M.gguf"
+          sha256: null  # Optional: add SHA256 for verification
+        q5_k_m:
+          file: "model-filename.Q5_K_M.gguf"
+          sha256: null
+```
+
+3.
Download the model: + +```bash +python -m tools.model_manager download your-model-name --quant q4_k_m +``` + +## Storage Requirements + +Plan your disk space based on models you'll use: + +| Model Class | Q4_K_M Size | Q5_K_M Size | Q8_0 Size | +|-------------|-------------|-------------|-----------| +| Tiny (1B) | ~600 MB | ~750 MB | ~1.2 GB | +| Small (2-3B) | ~1.5 GB | ~2 GB | ~3 GB | +| Medium (7-8B) | ~3.8 GB | ~5 GB | ~7 GB | +| Large (13B) | ~7 GB | ~9 GB | ~13 GB | +| XLarge (47B) | ~26 GB | ~33 GB | ~47 GB | + +**Recommendation:** Start with Q4_K_M quantization for best size/quality balance. + +## Performance Characteristics + +### Speed vs Quality Tradeoff + +- **Q4_K_M**: Fastest, good quality, smallest size +- **Q5_K_M**: Slightly slower, better quality, medium size +- **Q8_0**: Slowest, best quality, largest size + +### GPU Acceleration + +For GPU acceleration with llama.cpp: + +```yaml +# configs/backend/llama_cpp.yaml +backend: + type: llama_cpp + model_path: "models/mistral-7b-instruct-v0.2/gguf/q4_k_m/mistral-7b-instruct-v0.2.Q4_K_M.gguf" + n_ctx: 4096 + n_gpu_layers: -1 # -1 = offload all layers to GPU +``` + +**GPU Memory Requirements:** +- 7B Q4_K_M with full GPU offload: ~4-5 GB VRAM +- 7B Q5_K_M with full GPU offload: ~5-6 GB VRAM +- 13B Q4_K_M with full GPU offload: ~8-9 GB VRAM + +## Troubleshooting + +### Download Errors + +**Problem:** `HTTP 401 Unauthorized` +**Solution:** Some models require authentication. Set your Hugging Face token: + +```bash +# Windows +set HF_TOKEN=your_token_here + +# Linux/Mac +export HF_TOKEN=your_token_here +``` + +**Problem:** Download interrupted +**Solution:** The tool supports resume. Just re-run the download command. + +**Problem:** "Model not found in registry" +**Solution:** Check available models with `python -m tools.model_manager info` + +### Checksum Verification Failed + +**Problem:** Checksum mismatch after download +**Solution:** +1. Remove the corrupted model: `python -m tools.model_manager remove ` +2. Re-download: `python -m tools.model_manager download --force` + +### Out of Disk Space + +**Problem:** Insufficient disk space +**Solution:** +1. Check current usage: `python -m tools.model_manager list` +2. Remove unused models: `python -m tools.model_manager remove ` +3. Use smaller quantization (Q4_K_M instead of Q8_0) + +### Model Loading Errors + +**Problem:** Backend fails to load model +**Solution:** +1. Verify model exists: `python -m tools.model_manager list` +2. Check path in config matches actual file path +3. Verify model integrity: `python -m tools.model_manager verify ` + +## Python API + +You can also use the ModelManager programmatically: + +```python +from pathlib import Path +from tools.model_manager import ModelManager + +# Initialize +manager = ModelManager( + models_dir=Path("models"), + config_path=Path("configs/models.yaml") +) + +# Download a model +model_path = manager.download_model( + model_id="mistral-7b-instruct-v0.2", + format="gguf", + quantization="q4_k_m" +) + +if model_path: + print(f"Model downloaded to: {model_path}") + +# List cached models +models = manager.list_models() +for model in models: + print(f"{model['model']}: {model['size_gb']:.2f} GB") + +# Get path to existing model +model_path = manager.get_model_path( + model_id="llama-2-7b-chat", + format="gguf", + quantization="q4_k_m" +) + +# Verify model +is_valid = manager.verify_model( + model_path, + expected_sha256="abc123..." # Optional +) + +# Remove a model +manager.remove_model("old-model") +``` + +## Best Practices + +1. 
**Start Small**: Begin with `tinyllama-1.1b-chat` for testing +2. **Monitor Storage**: Regularly check disk usage with `list` command +3. **Clean Up**: Remove unused models to free space +4. **Use Q4_K_M**: Best balance for most use cases +5. **Verify Downloads**: Run `verify` after downloading large models +6. **Plan GPU Usage**: Check VRAM requirements before downloading large models + +## See Also + +- [Backend Configuration](../configs/backend/) +- [LLM Backend Guide](llm_backends.md) +- [Hugging Face Hub](https://huggingface.co/models) +- [GGUF Format Info](https://github.com/ggerganov/llama.cpp#gguf) diff --git a/pyproject.toml b/pyproject.toml index a3c6c40..6c35322 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ llm = [ "openai>=1.0.0", "torch>=2.0.0", "transformers>=4.35.0", + "huggingface-hub>=0.20.0", ] vector = [ diff --git a/python/backends/__init__.py b/python/backends/__init__.py index b4c4f30..27e5368 100644 --- a/python/backends/__init__.py +++ b/python/backends/__init__.py @@ -2,7 +2,7 @@ LLM Backend Adapters for Agent Arena """ -from .base import BaseBackend, BackendConfig +from .base import BackendConfig, BaseBackend from .llama_cpp_backend import LlamaCppBackend from .vllm_backend import VLLMBackend, VLLMBackendConfig diff --git a/python/backends/llama_cpp_backend.py b/python/backends/llama_cpp_backend.py index 3437e3a..1352999 100644 --- a/python/backends/llama_cpp_backend.py +++ b/python/backends/llama_cpp_backend.py @@ -28,7 +28,7 @@ def _load_model(self) -> None: logger.info(f"Loading model from {self.config.model_path}") # Use GPU layers from config - n_gpu_layers = getattr(self.config, 'n_gpu_layers', 0) + n_gpu_layers = getattr(self.config, "n_gpu_layers", 0) if n_gpu_layers > 0: logger.info(f"Offloading {n_gpu_layers} layers to GPU") diff --git a/python/backends/vllm_backend.py b/python/backends/vllm_backend.py index c543318..10ea735 100644 --- a/python/backends/vllm_backend.py +++ b/python/backends/vllm_backend.py @@ -194,14 +194,16 @@ def generate_with_tools( # Convert tool schemas to OpenAI format openai_tools = [] for tool in tools: - openai_tools.append({ - "type": "function", - "function": { - "name": tool["name"], - "description": tool["description"], - "parameters": tool.get("parameters", {}), + openai_tools.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool["description"], + "parameters": tool.get("parameters", {}), + }, } - }) + ) # Use chat completions API for function calling response = self.client.chat.completions.create( diff --git a/python/requirements-ci.txt b/python/requirements-ci.txt new file mode 100644 index 0000000..a3acaf6 --- /dev/null +++ b/python/requirements-ci.txt @@ -0,0 +1,36 @@ +# CI-specific requirements (excludes dependencies that require compilation) +# Core dependencies +numpy>=1.24.0 +pydantic>=2.0.0 +msgpack>=1.0.5 +hydra-core>=1.3.0 +omegaconf>=2.3.0 + +# IPC Server +fastapi>=0.104.0 +uvicorn>=0.24.0 + +# LLM backends (lightweight only - exclude llama-cpp-python and torch) +openai>=1.0.0 # For OpenAI-compatible APIs +requests>=2.31.0 +huggingface-hub>=0.20.0 # For downloading models from HuggingFace Hub + +# Testing and development +pytest>=7.4.0 +pytest-asyncio>=0.21.0 +pytest-cov>=4.1.0 +black>=23.0.0 +ruff>=0.1.0 +mypy>=1.5.0 + +# Utilities +python-dotenv>=1.0.0 +tenacity>=8.2.0 # Retry logic +aiohttp>=3.9.0 # Async HTTP +tqdm>=4.66.0 # Progress bars +rich>=13.7.0 # Pretty console output +pyyaml>=6.0.0 # YAML configuration files + +# Logging and monitoring 
+structlog>=23.2.0 +python-json-logger>=2.0.7 diff --git a/python/requirements.txt b/python/requirements.txt index 2f9a6df..523b275 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -22,6 +22,7 @@ chromadb>=0.4.0 # Alternative lightweight vector store # ML/AI (optional) torch>=2.0.0 # PyTorch for model training transformers>=4.35.0 # HuggingFace transformers +huggingface-hub>=0.20.0 # For downloading models from HuggingFace Hub # Testing and development pytest>=7.4.0 @@ -37,6 +38,7 @@ tenacity>=8.2.0 # Retry logic aiohttp>=3.9.0 # Async HTTP tqdm>=4.66.0 # Progress bars rich>=13.7.0 # Pretty console output +pyyaml>=6.0.0 # YAML configuration files # Logging and monitoring structlog>=23.2.0 diff --git a/python/run_ipc_server_with_gpu.py b/python/run_ipc_server_with_gpu.py index 821bd02..6804f92 100644 --- a/python/run_ipc_server_with_gpu.py +++ b/python/run_ipc_server_with_gpu.py @@ -9,12 +9,12 @@ import logging import sys -from agent_runtime.runtime import AgentRuntime from agent_runtime.agent import Agent +from agent_runtime.runtime import AgentRuntime from agent_runtime.tool_dispatcher import ToolDispatcher -from backends import LlamaCppBackend, BackendConfig +from backends import BackendConfig, LlamaCppBackend from ipc.server import create_server -from tools import register_movement_tools, register_inventory_tools, register_world_query_tools +from tools import register_inventory_tools, register_movement_tools, register_world_query_tools # Configure logging logging.basicConfig( @@ -56,25 +56,25 @@ def main(): "--model", type=str, default="../models/llama-2-7b-chat.Q4_K_M.gguf", - help="Path to GGUF model file (default: ../models/llama-2-7b-chat.Q4_K_M.gguf)" + help="Path to GGUF model file (default: ../models/llama-2-7b-chat.Q4_K_M.gguf)", ) parser.add_argument( "--gpu-layers", type=int, default=-1, - help="Number of layers to offload to GPU: -1=all, 0=CPU only (default: -1)" + help="Number of layers to offload to GPU: -1=all, 0=CPU only (default: -1)", ) parser.add_argument( "--temperature", type=float, default=0.7, - help="LLM temperature for decision making (default: 0.7)" + help="LLM temperature for decision making (default: 0.7)", ) parser.add_argument( "--max-tokens", type=int, default=256, - help="Maximum tokens to generate per decision (default: 256)" + help="Maximum tokens to generate per decision (default: 256)", ) args = parser.parse_args() @@ -89,7 +89,9 @@ def main(): logger.info(f"Port: {args.port}") logger.info(f"Max Workers: {args.workers}") logger.info(f"Model: {args.model}") - logger.info(f"GPU Layers: {args.gpu_layers} ({'all' if args.gpu_layers == -1 else 'CPU only' if args.gpu_layers == 0 else args.gpu_layers})") + logger.info( + f"GPU Layers: {args.gpu_layers} ({'all' if args.gpu_layers == -1 else 'CPU only' if args.gpu_layers == 0 else args.gpu_layers})" + ) logger.info(f"Temperature: {args.temperature}") logger.info(f"Max Tokens: {args.max_tokens}") logger.info("=" * 60) @@ -100,7 +102,7 @@ def main(): model_path=args.model, temperature=args.temperature, max_tokens=args.max_tokens, - n_gpu_layers=args.gpu_layers + n_gpu_layers=args.gpu_layers, ) logger.info("Loading GPU-accelerated LLM backend...") @@ -122,11 +124,13 @@ def main(): agent_id="gpu_agent_001", backend=backend, tools=list(tool_dispatcher.tools.keys()), - goals=["explore the world", "collect resources", "survive"] + goals=["explore the world", "collect resources", "survive"], ) runtime.register_agent(test_agent) - logger.info(f"✓ Registered agent '{test_agent.state.agent_id}' with 
GPU backend and {len(test_agent.available_tools)} tools") + logger.info( + f"✓ Registered agent '{test_agent.state.agent_id}' with GPU backend and {len(test_agent.available_tools)} tools" + ) logger.info("=" * 60) logger.info("Server ready! You can now:") @@ -142,13 +146,13 @@ def main(): except KeyboardInterrupt: logger.info("\nShutting down gracefully...") - if 'backend' in locals(): + if "backend" in locals(): logger.info("Unloading LLM backend...") backend.unload() sys.exit(0) except Exception as e: logger.error(f"Fatal error: {e}", exc_info=True) - if 'backend' in locals(): + if "backend" in locals(): backend.unload() sys.exit(1) diff --git a/python/run_vllm_server.py b/python/run_vllm_server.py index 3c8fa57..b759fba 100644 --- a/python/run_vllm_server.py +++ b/python/run_vllm_server.py @@ -85,6 +85,7 @@ def main(): # Check if vLLM is installed try: import vllm + logger.info(f"vLLM version: {vllm.__version__}") except ImportError: logger.error( @@ -94,7 +95,6 @@ def main(): sys.exit(1) # Import vLLM server - from vllm.entrypoints.openai.api_server import run_server logger.info(f"Starting vLLM server for model: {args.model}") logger.info(f"Server will be available at: http://{args.host}:{args.port}") @@ -105,13 +105,20 @@ def main(): # Build command-line arguments for vLLM vllm_args = [ - "--model", args.model, - "--host", args.host, - "--port", str(args.port), - "--tensor-parallel-size", str(args.tensor_parallel_size), - "--gpu-memory-utilization", str(args.gpu_memory), - "--max-model-len", str(args.max_model_len), - "--dtype", args.dtype, + "--model", + args.model, + "--host", + args.host, + "--port", + str(args.port), + "--tensor-parallel-size", + str(args.tensor_parallel_size), + "--gpu-memory-utilization", + str(args.gpu_memory), + "--max-model-len", + str(args.max_model_len), + "--dtype", + args.dtype, ] if args.trust_remote_code: @@ -129,9 +136,9 @@ def main(): # Or start directly if vLLM supports it import subprocess + subprocess.run( - ["python", "-m", "vllm.entrypoints.openai.api_server"] + vllm_args, - check=True + ["python", "-m", "vllm.entrypoints.openai.api_server"] + vllm_args, check=True ) except KeyboardInterrupt: diff --git a/python/test_agent_gpu.py b/python/test_agent_gpu.py index bbb5ce1..1c53d9d 100644 --- a/python/test_agent_gpu.py +++ b/python/test_agent_gpu.py @@ -11,14 +11,14 @@ import json import logging -from backends import LlamaCppBackend, BackendConfig -from agent_runtime.agent import Agent, Action + +from agent_runtime.agent import Action, Agent from agent_runtime.tool_dispatcher import ToolDispatcher +from backends import BackendConfig, LlamaCppBackend # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -27,6 +27,7 @@ # Step 1: Create Sample Tools # ============================================================ + def create_tool_dispatcher() -> ToolDispatcher: """Create a ToolDispatcher with sample game tools.""" dispatcher = ToolDispatcher() @@ -40,7 +41,7 @@ def move_to(target_x: float, target_y: float, speed: float = 1.0) -> dict: "success": True, "message": f"Moving to ({target_x}, {target_y}) at speed {speed}", "estimated_time": time_estimate, - "distance": distance + "distance": distance, } dispatcher.register_tool( @@ -50,21 +51,15 @@ def move_to(target_x: float, target_y: float, speed: float = 1.0) -> dict: parameters={ "type": "object", 
"properties": { - "target_x": { - "type": "number", - "description": "Target X coordinate" - }, - "target_y": { - "type": "number", - "description": "Target Y coordinate" - }, + "target_x": {"type": "number", "description": "Target X coordinate"}, + "target_y": {"type": "number", "description": "Target Y coordinate"}, "speed": { "type": "number", "description": "Movement speed (default 1.0)", - "default": 1.0 - } + "default": 1.0, + }, }, - "required": ["target_x", "target_y"] + "required": ["target_x", "target_y"], }, returns={ "type": "object", @@ -72,9 +67,9 @@ def move_to(target_x: float, target_y: float, speed: float = 1.0) -> dict: "success": {"type": "boolean"}, "message": {"type": "string"}, "estimated_time": {"type": "number"}, - "distance": {"type": "number"} - } - } + "distance": {"type": "number"}, + }, + }, ) # Tool 2: Collect resource @@ -86,13 +81,13 @@ def collect_resource(resource_name: str) -> dict: "success": True, "message": f"Collected {resource_name}", "resource": resource_name, - "quantity": 1 + "quantity": 1, } else: return { "success": False, "message": f"Unknown resource: {resource_name}", - "error": "Invalid resource type" + "error": "Invalid resource type", } dispatcher.register_tool( @@ -105,10 +100,10 @@ def collect_resource(resource_name: str) -> dict: "resource_name": { "type": "string", "description": "Name of resource to collect (wood, stone, or food)", - "enum": ["wood", "stone", "food"] + "enum": ["wood", "stone", "food"], } }, - "required": ["resource_name"] + "required": ["resource_name"], }, returns={ "type": "object", @@ -116,40 +111,29 @@ def collect_resource(resource_name: str) -> dict: "success": {"type": "boolean"}, "message": {"type": "string"}, "resource": {"type": "string"}, - "quantity": {"type": "number"} - } - } + "quantity": {"type": "number"}, + }, + }, ) # Tool 3: Check inventory def check_inventory() -> dict: """Check current inventory (mock data for demo).""" - return { - "success": True, - "inventory": { - "wood": 5, - "stone": 3, - "food": 2 - }, - "total_items": 10 - } + return {"success": True, "inventory": {"wood": 5, "stone": 3, "food": 2}, "total_items": 10} dispatcher.register_tool( name="check_inventory", function=check_inventory, description="Check the agent's current inventory", - parameters={ - "type": "object", - "properties": {} - }, + parameters={"type": "object", "properties": {}}, returns={ "type": "object", "properties": { "success": {"type": "boolean"}, "inventory": {"type": "object"}, - "total_items": {"type": "number"} - } - } + "total_items": {"type": "number"}, + }, + }, ) logger.info(f"Created ToolDispatcher with {len(dispatcher.tools)} tools") @@ -160,6 +144,7 @@ def check_inventory() -> dict: # Step 2: Enhanced Agent with Backend Integration # ============================================================ + class EnhancedAgent(Agent): """ Enhanced Agent that properly integrates with LLM backend and tools. @@ -167,16 +152,17 @@ class EnhancedAgent(Agent): This extends the base Agent class to implement actual backend communication. 
""" - def __init__(self, agent_id: str, backend, tool_dispatcher: ToolDispatcher, goals: list[str] | None = None): + def __init__( + self, + agent_id: str, + backend, + tool_dispatcher: ToolDispatcher, + goals: list[str] | None = None, + ): # Get available tool names from dispatcher available_tools = list(tool_dispatcher.tools.keys()) - super().__init__( - agent_id=agent_id, - backend=backend, - tools=available_tools, - goals=goals - ) + super().__init__(agent_id=agent_id, backend=backend, tools=available_tools, goals=goals) self.tool_dispatcher = tool_dispatcher @@ -209,7 +195,7 @@ def _query_llm(self, context: str) -> str: result = self.backend.generate( prompt=prompt, temperature=0.3, # Lower temperature for more consistent JSON - max_tokens=150 + max_tokens=150, ) # Extract JSON from response @@ -224,11 +210,11 @@ def _extract_json(self, text: str) -> str: import re # Try to find JSON object in the response - start = text.find('{') - end = text.rfind('}') + start = text.find("{") + end = text.rfind("}") if start != -1 and end != -1: - json_str = text[start:end+1] + json_str = text[start : end + 1] # Try to validate and return if valid try: @@ -239,7 +225,9 @@ def _extract_json(self, text: str) -> str: # Common issue: missing comma between fields # Pattern: "value"\n"field" should be "value",\n"field" - fixed_json = re.sub(r'([\d"])\s*\n\s*("(?:reasoning|tool|params))', r'\1,\n\2', json_str) + fixed_json = re.sub( + r'([\d"])\s*\n\s*("(?:reasoning|tool|params))', r"\1,\n\2", json_str + ) try: json.loads(fixed_json) @@ -262,15 +250,12 @@ def _extract_json(self, text: str) -> str: elif tool == "move_to" and target_x_match and target_y_match: params = { "target_x": float(target_x_match.group(1)), - "target_y": float(target_y_match.group(1)) + "target_y": float(target_y_match.group(1)), } elif tool == "check_inventory": params = {} - reconstructed = { - "tool": tool, - "params": params - } + reconstructed = {"tool": tool, "params": params} logger.debug(f"Reconstructed JSON from pattern matching: {reconstructed}") return json.dumps(reconstructed) @@ -294,13 +279,14 @@ def execute_action(self, action: Action) -> dict: # Step 3: Test Scenarios # ============================================================ + def test_scenario_1_resource_collection(): """ Scenario: Agent sees wood nearby and should collect it. 
""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("SCENARIO 1: Resource Collection") - print("="*60) + print("=" * 60) # Setup dispatcher = create_tool_dispatcher() @@ -309,7 +295,7 @@ def test_scenario_1_resource_collection(): model_path="../models/llama-2-7b-chat.Q4_K_M.gguf", temperature=0.3, max_tokens=150, - n_gpu_layers=-1 # Full GPU acceleration + n_gpu_layers=-1, # Full GPU acceleration ) backend = LlamaCppBackend(config) @@ -317,19 +303,22 @@ def test_scenario_1_resource_collection(): agent_id="forager_001", backend=backend, tool_dispatcher=dispatcher, - goals=["collect resources for crafting"] + goals=["collect resources for crafting"], ) # Simulate observations print("\n[Simulation] Agent observes environment...") - agent.perceive({ - "position": {"x": 0, "y": 0}, - "visible_objects": [ - {"type": "wood", "distance": 2.5, "position": {"x": 2, "y": 1}}, - {"type": "tree", "distance": 5.0} - ], - "inventory_count": 10 - }, source="vision") + agent.perceive( + { + "position": {"x": 0, "y": 0}, + "visible_objects": [ + {"type": "wood", "distance": 2.5, "position": {"x": 2, "y": 1}}, + {"type": "tree", "distance": 5.0}, + ], + "inventory_count": 10, + }, + source="vision", + ) # Agent decides action print("\n[Agent] Deciding action based on observations and goals...") @@ -349,16 +338,16 @@ def test_scenario_1_resource_collection(): print("\n[Agent] Failed to decide action") backend.unload() - print("\n" + "="*60) + print("\n" + "=" * 60) def test_scenario_2_navigation(): """ Scenario: Agent needs to move to a target location. """ - print("\n" + "="*60) + print("\n" + "=" * 60) print("SCENARIO 2: Navigation") - print("="*60) + print("=" * 60) # Setup dispatcher = create_tool_dispatcher() @@ -367,7 +356,7 @@ def test_scenario_2_navigation(): model_path="../models/llama-2-7b-chat.Q4_K_M.gguf", temperature=0.3, max_tokens=150, - n_gpu_layers=-1 + n_gpu_layers=-1, ) backend = LlamaCppBackend(config) @@ -375,16 +364,15 @@ def test_scenario_2_navigation(): agent_id="explorer_001", backend=backend, tool_dispatcher=dispatcher, - goals=["explore the map", "find the tower at (10, 15)"] + goals=["explore the map", "find the tower at (10, 15)"], ) # Simulate observations print("\n[Simulation] Agent receives navigation task...") - agent.perceive({ - "position": {"x": 0, "y": 0}, - "target_location": {"x": 10, "y": 15}, - "obstacles": [] - }, source="navigation") + agent.perceive( + {"position": {"x": 0, "y": 0}, "target_location": {"x": 10, "y": 15}, "obstacles": []}, + source="navigation", + ) # Agent decides action print("\n[Agent] Deciding navigation action...") @@ -404,16 +392,16 @@ def test_scenario_2_navigation(): print("\n[Agent] Failed to decide action") backend.unload() - print("\n" + "="*60) + print("\n" + "=" * 60) def test_scenario_3_inventory_check(): """ Scenario: Agent checks inventory before crafting. 
""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("SCENARIO 3: Inventory Management") - print("="*60) + print("=" * 60) # Setup dispatcher = create_tool_dispatcher() @@ -422,7 +410,7 @@ def test_scenario_3_inventory_check(): model_path="../models/llama-2-7b-chat.Q4_K_M.gguf", temperature=0.3, max_tokens=150, - n_gpu_layers=-1 + n_gpu_layers=-1, ) backend = LlamaCppBackend(config) @@ -430,16 +418,19 @@ def test_scenario_3_inventory_check(): agent_id="crafter_001", backend=backend, tool_dispatcher=dispatcher, - goals=["craft a wooden tool", "check if we have enough materials"] + goals=["craft a wooden tool", "check if we have enough materials"], ) # Simulate observations print("\n[Simulation] Agent wants to craft something...") - agent.perceive({ - "crafting_station": "workbench", - "recipe_requires": {"wood": 3, "stone": 1}, - "action": "prepare_crafting" - }, source="crafting") + agent.perceive( + { + "crafting_station": "workbench", + "recipe_requires": {"wood": 3, "stone": 1}, + "action": "prepare_crafting", + }, + source="crafting", + ) # Agent decides action print("\n[Agent] Deciding what to do before crafting...") @@ -459,7 +450,7 @@ def test_scenario_3_inventory_check(): print("\n[Agent] Failed to decide action") backend.unload() - print("\n" + "="*60) + print("\n" + "=" * 60) # ============================================================ @@ -467,9 +458,9 @@ def test_scenario_3_inventory_check(): # ============================================================ if __name__ == "__main__": - print("\n" + "="*60) + print("\n" + "=" * 60) print("Agent + GPU Backend + Tools Integration Test") - print("="*60) + print("=" * 60) print("\nThis test demonstrates an autonomous agent using:") print(" - GPU-accelerated Llama-2-7B backend (113 tok/s)") print(" - ToolDispatcher with 3 sample tools") @@ -482,9 +473,9 @@ def test_scenario_3_inventory_check(): test_scenario_2_navigation() test_scenario_3_inventory_check() - print("\n" + "="*60) + print("\n" + "=" * 60) print("All scenarios completed!") - print("="*60) + print("=" * 60) except Exception as e: logger.error(f"Test failed: {e}", exc_info=True) diff --git a/python/test_llama_backend.py b/python/test_llama_backend.py index e96b86c..c8a84e8 100644 --- a/python/test_llama_backend.py +++ b/python/test_llama_backend.py @@ -6,12 +6,12 @@ """ import logging -from backends import LlamaCppBackend, BackendConfig + +from backends import BackendConfig, LlamaCppBackend # Set up logging to see what's happening logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -43,9 +43,9 @@ def main(): logger.info("Backend loaded successfully!") # Test 1: Basic text generation - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Test 1: Basic Text Generation") - logger.info("="*60) + logger.info("=" * 60) prompt = "Hello! 
My name is" logger.info(f"Prompt: '{prompt}'") @@ -57,9 +57,9 @@ def main(): logger.info(f"Finish reason: {result.finish_reason}") # Test 2: Tool calling - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Test 2: Tool Calling (Function Calling)") - logger.info("="*60) + logger.info("=" * 60) tools = [ { @@ -90,7 +90,7 @@ def main(): }, "required": ["item_name"], }, - } + }, ] prompt = "I need to pick up the sword and then move to coordinates (10, 20, 5)" @@ -107,9 +107,9 @@ def main(): logger.warning("Failed to parse tool call from response") # Test 3: Different temperatures - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Test 3: Temperature Comparison") - logger.info("="*60) + logger.info("=" * 60) prompt = "The capital of France is" @@ -119,9 +119,9 @@ def main(): logger.info(f"Result: {result.text.strip()}") # Test 4: Conversation context - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("Test 4: Multi-turn Conversation") - logger.info("="*60) + logger.info("=" * 60) conversation = """[INST] You are a helpful AI assistant. [/INST] I understand. I'm here to help! [INST] What is the weather like today? [/INST]""" @@ -129,9 +129,9 @@ def main(): result = backend.generate(conversation, max_tokens=100) logger.info(f"Assistant: {result.text}") - logger.info("\n" + "="*60) + logger.info("\n" + "=" * 60) logger.info("All tests completed successfully!") - logger.info("="*60) + logger.info("=" * 60) # Clean up backend.unload() diff --git a/python/test_llama_gpu.py b/python/test_llama_gpu.py index f346920..4d71622 100644 --- a/python/test_llama_gpu.py +++ b/python/test_llama_gpu.py @@ -6,9 +6,10 @@ import logging import time -from backends import LlamaCppBackend, BackendConfig -logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') +from backends import BackendConfig, LlamaCppBackend + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ def test_inference(config: BackendConfig, test_name: str): """Test inference with given config.""" print(f"\n{'='*60}") print(f"{test_name}") - print('='*60) + print("=" * 60) start_time = time.time() backend = LlamaCppBackend(config) @@ -44,9 +45,9 @@ def test_inference(config: BackendConfig, test_name: str): def main(): - print("\n" + "="*60) + print("\n" + "=" * 60) print("GPU Acceleration Test for llama.cpp") - print("="*60) + print("=" * 60) model_path = "../models/llama-2-7b-chat.Q4_K_M.gguf" @@ -81,13 +82,13 @@ def main(): full_speed = test_inference(full_gpu_config, "Test 3: Full GPU (all layers)") # Summary - print("\n" + "="*60) + print("\n" + "=" * 60) print("Performance Summary") - print("="*60) + print("=" * 60) print(f"CPU only: {cpu_speed:.2f} tokens/sec (baseline)") print(f"Partial GPU: {partial_speed:.2f} tokens/sec ({partial_speed/cpu_speed:.2f}x speedup)") print(f"Full GPU: {full_speed:.2f} tokens/sec ({full_speed/cpu_speed:.2f}x speedup)") - print("="*60) + print("=" * 60) if full_speed > cpu_speed * 2: print("\n✓ GPU acceleration is working! 
Significant speedup achieved.") diff --git a/python/test_llama_simple.py b/python/test_llama_simple.py index 6e67b7b..c0fe0c2 100644 --- a/python/test_llama_simple.py +++ b/python/test_llama_simple.py @@ -5,16 +5,17 @@ """ import logging -from backends import LlamaCppBackend, BackendConfig -logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') +from backends import BackendConfig, LlamaCppBackend + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logger = logging.getLogger(__name__) def main(): - print("\n" + "="*60) + print("\n" + "=" * 60) print("Llama.cpp Backend - Simple Test") - print("="*60 + "\n") + print("=" * 60 + "\n") # Initialize backend config = BackendConfig( @@ -49,7 +50,7 @@ def main(): # Llama-2 chat format: [INST] question [/INST] prompt = "[INST] What is the capital of France? Answer in one word. [/INST]" - print(f"Question: What is the capital of France?") + print("Question: What is the capital of France?") result = backend.generate(prompt, temperature=0.1, max_tokens=10) print(f"Answer: {result.text.strip()}\n") @@ -86,15 +87,17 @@ def main(): print("Test 5: Temperature Comparison") print("-" * 40) - base_prompt = "[INST] Complete this sentence in a creative way: The robot opened the door and saw [/INST]" + base_prompt = ( + "[INST] Complete this sentence in a creative way: The robot opened the door and saw [/INST]" + ) for temp in [0.1, 0.5, 1.0]: result = backend.generate(base_prompt, temperature=temp, max_tokens=30) print(f"Temp {temp}: {result.text.strip()}") - print("\n" + "="*60) + print("\n" + "=" * 60) print("All tests completed!") - print("="*60) + print("=" * 60) backend.unload() diff --git a/python/test_quick_gpu.py b/python/test_quick_gpu.py index 3222e0f..96a6969 100644 --- a/python/test_quick_gpu.py +++ b/python/test_quick_gpu.py @@ -3,11 +3,12 @@ """ import time -from backends import LlamaCppBackend, BackendConfig -print("\n" + "="*60) +from backends import BackendConfig, LlamaCppBackend + +print("\n" + "=" * 60) print("Quick GPU Test") -print("="*60 + "\n") +print("=" * 60 + "\n") # Test with GPU config_gpu = BackendConfig( @@ -46,4 +47,4 @@ print("\nWARNING: Speed seems slow - GPU may not be fully utilized.") backend.unload() -print("\n" + "="*60) +print("\n" + "=" * 60) diff --git a/python/tools/__init__.py b/python/tools/__init__.py index 663a051..abd8db7 100644 --- a/python/tools/__init__.py +++ b/python/tools/__init__.py @@ -1,7 +1,7 @@ """ Agent Arena - Tool Library -Standard tools for agent world interaction. +Standard tools for agent world interaction and model management. """ from .inventory import register_inventory_tools @@ -13,3 +13,6 @@ "register_movement_tools", "register_inventory_tools", ] + +# ModelManager is available via: from tools.model_manager import ModelManager +# Not imported here to avoid circular import warnings when running as module diff --git a/python/tools/__main__.py b/python/tools/__main__.py new file mode 100644 index 0000000..d2018cb --- /dev/null +++ b/python/tools/__main__.py @@ -0,0 +1,11 @@ +""" +CLI entry point for tools.model_manager module. + +This allows running the model manager as: + python -m tools.model_manager [args] +""" + +from .model_manager import main + +if __name__ == "__main__": + main() diff --git a/python/tools/model_manager.py b/python/tools/model_manager.py new file mode 100644 index 0000000..afb71e2 --- /dev/null +++ b/python/tools/model_manager.py @@ -0,0 +1,547 @@ +""" +Model Download and Management Tool for Agent Arena. 
+ +This module provides functionality to download, verify, and manage LLM models +from Hugging Face Hub for use with different backends (llama.cpp, vLLM, etc.). +""" + +import hashlib +import logging +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import yaml # type: ignore[import-untyped] +from huggingface_hub import hf_hub_download # type: ignore[import-untyped] + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelInfo: + """Information about a model.""" + + name: str + huggingface_id: str + format: str + quantization: str | None = None + filename: str | None = None + sha256: str | None = None + size_bytes: int | None = None + + +class ModelManager: + """Manages model downloads, verification, and caching.""" + + def __init__(self, models_dir: Path | None = None, config_path: Path | None = None): + """ + Initialize the ModelManager. + + Args: + models_dir: Directory where models are cached. Defaults to ./models/ + config_path: Path to models.yaml config. Defaults to ./configs/models.yaml + """ + # Find project root (directory containing configs/) + if config_path is None: + project_root = self._find_project_root() + self.config_path = project_root / "configs" / "models.yaml" + else: + self.config_path = Path(config_path) + + if models_dir is None: + project_root = self._find_project_root() + self.models_dir = project_root / "models" + else: + self.models_dir = Path(models_dir) + + self.models_dir.mkdir(parents=True, exist_ok=True) + + # Load model registry + self.registry = self._load_registry() + + def _find_project_root(self) -> Path: + """Find the project root directory by looking for configs/ directory.""" + current = Path.cwd() + + # Check current directory and parents + for parent in [current] + list(current.parents): + if (parent / "configs" / "models.yaml").exists(): + return parent + + # If not found, use current directory + return current + + def _load_registry(self) -> dict[str, Any]: + """Load the model registry from YAML config.""" + if not self.config_path.exists(): + logger.warning(f"Model registry not found at {self.config_path}") + return {"models": {}} + + try: + with open(self.config_path) as f: + config = yaml.safe_load(f) + return config or {"models": {}} + except Exception as e: + logger.error(f"Failed to load model registry: {e}") + return {"models": {}} + + def _get_model_dir(self, model_id: str, format: str, quantization: str | None = None) -> Path: + """Get the directory path for a specific model.""" + model_dir = self.models_dir / model_id / format + if quantization: + model_dir = model_dir / quantization + return model_dir + + def _calculate_sha256(self, file_path: Path, chunk_size: int = 8192) -> str: + """Calculate SHA256 checksum of a file.""" + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + sha256_hash.update(chunk) + return sha256_hash.hexdigest() + + def verify_model(self, model_path: Path, expected_sha256: str | None = None) -> bool: + """ + Verify model file integrity. 
+ + Args: + model_path: Path to the model file + expected_sha256: Expected SHA256 checksum (if None, just checks file exists) + + Returns: + True if verification passes, False otherwise + """ + if not model_path.exists(): + logger.error(f"Model file not found: {model_path}") + return False + + if not expected_sha256: + logger.info("No checksum provided, skipping verification") + return True + + logger.info("Calculating checksum...") + actual_sha256 = self._calculate_sha256(model_path) + + if actual_sha256 != expected_sha256: + logger.error( + f"Checksum mismatch!\nExpected: {expected_sha256}\nActual: {actual_sha256}" + ) + return False + + logger.info("Checksum verified successfully") + return True + + def download_model( + self, + model_id: str, + format: str = "gguf", + quantization: str | None = None, + force: bool = False, + ) -> Path | None: + """ + Download a model from Hugging Face Hub. + + Args: + model_id: Model identifier (e.g., "llama-2-7b-chat") + format: Model format ("gguf", "pytorch", etc.) + quantization: Quantization type (e.g., "q4_k_m", "q5_k_m") + force: Force re-download even if model exists + + Returns: + Path to the downloaded model file, or None if download failed + """ + # Look up model in registry + if model_id not in self.registry.get("models", {}): + logger.error(f"Model '{model_id}' not found in registry") + logger.info(f"Available models: {', '.join(self.registry.get('models', {}).keys())}") + return None + + model_config = self.registry["models"][model_id] + huggingface_id = model_config.get("huggingface_id") + if not huggingface_id: + logger.error(f"No huggingface_id specified for model '{model_id}'") + return None + + # Get format-specific config + formats = model_config.get("formats", {}) + if format not in formats: + logger.error(f"Format '{format}' not available for model '{model_id}'") + logger.info(f"Available formats: {', '.join(formats.keys())}") + return None + + format_config = formats[format] + + # Get quantization-specific config + if quantization: + if quantization not in format_config: + logger.error(f"Quantization '{quantization}' not available for {model_id}/{format}") + logger.info(f"Available quantizations: {', '.join(format_config.keys())}") + return None + quant_config = format_config[quantization] + else: + # If no quantization specified, use the format config directly + quant_config = format_config + + filename = quant_config.get("file") + if not filename: + logger.error(f"No filename specified for {model_id}/{format}/{quantization}") + return None + + expected_sha256 = quant_config.get("sha256") + + # Determine download path + model_dir = self._get_model_dir(model_id, format, quantization) + model_path = model_dir / filename + + # Check if model already exists and is valid + if model_path.exists() and not force: + logger.info(f"Model already exists at {model_path}") + if self.verify_model(model_path, expected_sha256): + logger.info("Existing model is valid, skipping download") + return Path(model_path) + else: + logger.warning("Existing model is invalid, re-downloading...") + + # Create directory + model_dir.mkdir(parents=True, exist_ok=True) + + # Download from Hugging Face Hub + logger.info(f"Downloading {model_id} ({format}/{quantization or 'default'})") + logger.info(f"Source: {huggingface_id}") + logger.info(f"File: {filename}") + + try: + downloaded_path_str: str = hf_hub_download( + repo_id=huggingface_id, + filename=filename, + local_dir=model_dir, + local_dir_use_symlinks=False, + resume_download=True, + ) + + # Move file to 
expected location if needed + downloaded_path = Path(downloaded_path_str) + if downloaded_path != model_path: + shutil.move(str(downloaded_path), str(model_path)) + + logger.info(f"Download complete: {model_path}") + + # Verify checksum + if expected_sha256: + if not self.verify_model(model_path, expected_sha256): + logger.error("Checksum verification failed!") + return None + + return Path(model_path) + + except Exception as e: + logger.error(f"Download failed: {e}") + return None + + def list_models(self, format_filter: str | None = None) -> list[dict[str, Any]]: + """ + List all cached models. + + Args: + format_filter: Optional filter by format (e.g., "gguf") + + Returns: + List of model information dictionaries + """ + cached_models: list[dict[str, Any]] = [] + + if not self.models_dir.exists(): + return cached_models + + # Walk through models directory + for model_dir in self.models_dir.iterdir(): + if not model_dir.is_dir(): + continue + + model_name = model_dir.name + + for format_dir in model_dir.iterdir(): + if not format_dir.is_dir(): + continue + + format_name = format_dir.name + + if format_filter and format_name != format_filter: + continue + + # Check for quantization subdirectories + for item in format_dir.iterdir(): + if item.is_dir(): + # Quantization directory + quant_name = item.name + for model_file in item.iterdir(): + if model_file.is_file(): + size_bytes = model_file.stat().st_size + cached_models.append( + { + "model": model_name, + "format": format_name, + "quantization": quant_name, + "file": model_file.name, + "path": str(model_file), + "size_bytes": size_bytes, + "size_gb": size_bytes / (1024**3), + } + ) + elif item.is_file(): + # Model file directly in format directory + size_bytes = item.stat().st_size + cached_models.append( + { + "model": model_name, + "format": format_name, + "quantization": None, + "file": item.name, + "path": str(item), + "size_bytes": size_bytes, + "size_gb": size_bytes / (1024**3), + } + ) + + return cached_models + + def get_model_path( + self, model_id: str, format: str = "gguf", quantization: str | None = None + ) -> Path | None: + """ + Get the local path to a cached model. + + Args: + model_id: Model identifier + format: Model format + quantization: Quantization type (optional) + + Returns: + Path to the model file if it exists, None otherwise + """ + model_dir = self._get_model_dir(model_id, format, quantization) + + if not model_dir.exists(): + return None + + # Find the first model file in the directory + for item in model_dir.iterdir(): + if item.is_file() and not item.name.startswith("."): + return item + + return None + + def remove_model( + self, model_id: str, format: str | None = None, quantization: str | None = None + ) -> bool: + """ + Remove a cached model. 
+ + Args: + model_id: Model identifier + format: Model format (if None, removes all formats) + quantization: Quantization type (if None, removes all quantizations) + + Returns: + True if model was removed, False otherwise + """ + model_base_dir = self.models_dir / model_id + + if not model_base_dir.exists(): + logger.warning(f"Model '{model_id}' not found in cache") + return False + + if format is None: + # Remove entire model directory + shutil.rmtree(model_base_dir) + logger.info(f"Removed all versions of model '{model_id}'") + return True + + format_dir = model_base_dir / format + + if not format_dir.exists(): + logger.warning(f"Format '{format}' not found for model '{model_id}'") + return False + + if quantization is None: + # Remove entire format directory + shutil.rmtree(format_dir) + logger.info(f"Removed {model_id}/{format}") + + # Clean up empty model directory + if not any(model_base_dir.iterdir()): + model_base_dir.rmdir() + + return True + + quant_dir = format_dir / quantization + + if not quant_dir.exists(): + logger.warning(f"Quantization '{quantization}' not found for {model_id}/{format}") + return False + + # Remove quantization directory + shutil.rmtree(quant_dir) + logger.info(f"Removed {model_id}/{format}/{quantization}") + + # Clean up empty directories + if not any(format_dir.iterdir()): + format_dir.rmdir() + if not any(model_base_dir.iterdir()): + model_base_dir.rmdir() + + return True + + +def main(): + """CLI entry point.""" + import argparse + + parser = argparse.ArgumentParser( + description="Model Download and Management Tool for Agent Arena" + ) + subparsers = parser.add_subparsers(dest="command", help="Command to execute") + + # Download command + download_parser = subparsers.add_parser("download", help="Download a model") + download_parser.add_argument("model_id", help="Model identifier") + download_parser.add_argument("--format", default="gguf", help="Model format (default: gguf)") + download_parser.add_argument( + "--quant", "--quantization", dest="quantization", help="Quantization type (e.g., q4_k_m)" + ) + download_parser.add_argument( + "--force", action="store_true", help="Force re-download even if model exists" + ) + + # List command + list_parser = subparsers.add_parser("list", help="List cached models") + list_parser.add_argument("--format", help="Filter by format") + + # Verify command + verify_parser = subparsers.add_parser("verify", help="Verify a model") + verify_parser.add_argument("model_id", help="Model identifier") + verify_parser.add_argument("--format", default="gguf", help="Model format") + verify_parser.add_argument("--quant", dest="quantization", help="Quantization type") + + # Remove command + remove_parser = subparsers.add_parser("remove", help="Remove a cached model") + remove_parser.add_argument("model_id", help="Model identifier") + remove_parser.add_argument("--format", help="Model format") + remove_parser.add_argument("--quant", dest="quantization", help="Quantization type") + + # Info command + info_parser = subparsers.add_parser("info", help="Show information about available models") + info_parser.add_argument("model_id", nargs="?", help="Specific model to show info for") + + args = parser.parse_args() + + # Setup logging + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + manager = ModelManager() + + if args.command == "download": + model_path = manager.download_model( + args.model_id, + format=args.format, + quantization=args.quantization, + force=args.force, + ) + 
if model_path: + print("\n[SUCCESS] Model downloaded successfully!") + print(f"Path: {model_path}") + else: + print("\n[FAILED] Download failed") + exit(1) + + elif args.command == "list": + models = manager.list_models(format_filter=args.format) + + if not models: + print("No cached models found") + return + + print("\nCached Models:") + print("=" * 80) + + total_size = 0 + for model in models: + quant = f"/{model['quantization']}" if model["quantization"] else "" + size_gb = model["size_gb"] + total_size += model["size_bytes"] + + print(f"{model['model']:<25} {model['format']:<10}{quant:<15} {size_gb:>6.2f} GB") + + print("=" * 80) + print(f"Total storage: {total_size / (1024**3):.2f} GB") + + elif args.command == "verify": + model_path = manager.get_model_path( + args.model_id, format=args.format, quantization=args.quantization + ) + + if not model_path: + print(f"[FAILED] Model not found: {args.model_id}") + exit(1) + + # Get expected SHA256 from registry + expected_sha256 = None + if args.model_id in manager.registry.get("models", {}): + model_config = manager.registry["models"][args.model_id] + formats = model_config.get("formats", {}) + if args.format in formats: + format_config = formats[args.format] + if args.quantization and args.quantization in format_config: + expected_sha256 = format_config[args.quantization].get("sha256") + else: + expected_sha256 = format_config.get("sha256") + + if manager.verify_model(model_path, expected_sha256): + print(f"[SUCCESS] Model verified successfully: {model_path}") + else: + print(f"[FAILED] Verification failed: {model_path}") + exit(1) + + elif args.command == "remove": + if manager.remove_model(args.model_id, format=args.format, quantization=args.quantization): + print("[SUCCESS] Model removed successfully") + else: + print("[FAILED] Failed to remove model") + exit(1) + + elif args.command == "info": + if args.model_id: + # Show info for specific model + if args.model_id not in manager.registry.get("models", {}): + print(f"Model '{args.model_id}' not found in registry") + exit(1) + + model_config = manager.registry["models"][args.model_id] + print(f"\nModel: {args.model_id}") + print(f"Hugging Face ID: {model_config.get('huggingface_id')}") + print("\nAvailable formats:") + for fmt, fmt_config in model_config.get("formats", {}).items(): + print(f" {fmt}:") + if isinstance(fmt_config, dict): + for quant, quant_config in fmt_config.items(): + if isinstance(quant_config, dict): + print(f" {quant}: {quant_config.get('file', 'N/A')}") + else: + # List all available models + print("\nAvailable models in registry:") + print("=" * 80) + for model_id in manager.registry.get("models", {}).keys(): + model_config = manager.registry["models"][model_id] + hf_id = model_config.get("huggingface_id", "N/A") + formats = ", ".join(model_config.get("formats", {}).keys()) + print(f"{model_id:<25} {hf_id:<40} [{formats}]") + + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py new file mode 100644 index 0000000..3077758 --- /dev/null +++ b/tests/test_model_manager.py @@ -0,0 +1,287 @@ +""" +Unit tests for the Model Manager. 
+""" + +import hashlib +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import yaml + +from tools.model_manager import ModelManager + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def mock_config(temp_dir): + """Create a mock model configuration.""" + config = { + "models": { + "test-model": { + "huggingface_id": "test/test-model-gguf", + "formats": { + "gguf": { + "q4_k_m": { + "file": "test-model.Q4_K_M.gguf", + "sha256": "abc123", + } + } + }, + } + } + } + + config_path = temp_dir / "models.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f) + + return config_path + + +@pytest.fixture +def model_manager(temp_dir, mock_config): + """Create a ModelManager instance for testing.""" + return ModelManager(models_dir=temp_dir / "models", config_path=mock_config) + + +class TestModelManager: + """Tests for ModelManager class.""" + + def test_init(self, model_manager, temp_dir): + """Test ModelManager initialization.""" + assert model_manager.models_dir == temp_dir / "models" + assert model_manager.models_dir.exists() + assert "models" in model_manager.registry + assert "test-model" in model_manager.registry["models"] + + def test_get_model_dir(self, model_manager, temp_dir): + """Test _get_model_dir method.""" + model_dir = model_manager._get_model_dir("test-model", "gguf", "q4_k_m") + expected = temp_dir / "models" / "test-model" / "gguf" / "q4_k_m" + assert model_dir == expected + + def test_calculate_sha256(self, model_manager, temp_dir): + """Test SHA256 calculation.""" + test_file = temp_dir / "test.txt" + test_content = b"Hello, World!" + test_file.write_bytes(test_content) + + expected_hash = hashlib.sha256(test_content).hexdigest() + actual_hash = model_manager._calculate_sha256(test_file) + + assert actual_hash == expected_hash + + def test_verify_model_file_not_found(self, model_manager, temp_dir): + """Test verify_model with non-existent file.""" + nonexistent = temp_dir / "nonexistent.gguf" + assert not model_manager.verify_model(nonexistent) + + def test_verify_model_no_checksum(self, model_manager, temp_dir): + """Test verify_model without checksum (just file existence).""" + test_file = temp_dir / "test.gguf" + test_file.write_bytes(b"test data") + + assert model_manager.verify_model(test_file, expected_sha256=None) + + def test_verify_model_checksum_match(self, model_manager, temp_dir): + """Test verify_model with matching checksum.""" + test_file = temp_dir / "test.gguf" + test_content = b"test data" + test_file.write_bytes(test_content) + + expected_hash = hashlib.sha256(test_content).hexdigest() + assert model_manager.verify_model(test_file, expected_sha256=expected_hash) + + def test_verify_model_checksum_mismatch(self, model_manager, temp_dir): + """Test verify_model with mismatched checksum.""" + test_file = temp_dir / "test.gguf" + test_file.write_bytes(b"test data") + + wrong_hash = "0" * 64 + assert not model_manager.verify_model(test_file, expected_sha256=wrong_hash) + + def test_list_models_empty(self, model_manager): + """Test list_models with no cached models.""" + models = model_manager.list_models() + assert models == [] + + def test_list_models_with_content(self, model_manager, temp_dir): + """Test list_models with cached models.""" + # Create a fake model structure + model_dir = temp_dir / "models" / "test-model" / "gguf" / "q4_k_m" + model_dir.mkdir(parents=True, 
exist_ok=True) + + model_file = model_dir / "test-model.Q4_K_M.gguf" + model_file.write_bytes(b"fake model data") + + models = model_manager.list_models() + assert len(models) == 1 + assert models[0]["model"] == "test-model" + assert models[0]["format"] == "gguf" + assert models[0]["quantization"] == "q4_k_m" + assert models[0]["file"] == "test-model.Q4_K_M.gguf" + + def test_list_models_with_format_filter(self, model_manager, temp_dir): + """Test list_models with format filter.""" + # Create models with different formats + gguf_dir = temp_dir / "models" / "test-model" / "gguf" + gguf_dir.mkdir(parents=True, exist_ok=True) + (gguf_dir / "model.gguf").write_bytes(b"data") + + pytorch_dir = temp_dir / "models" / "test-model" / "pytorch" + pytorch_dir.mkdir(parents=True, exist_ok=True) + (pytorch_dir / "model.pt").write_bytes(b"data") + + # Filter by gguf + models = model_manager.list_models(format_filter="gguf") + assert len(models) == 1 + assert models[0]["format"] == "gguf" + + # Filter by pytorch + models = model_manager.list_models(format_filter="pytorch") + assert len(models) == 1 + assert models[0]["format"] == "pytorch" + + def test_get_model_path_exists(self, model_manager, temp_dir): + """Test get_model_path when model exists.""" + model_dir = temp_dir / "models" / "test-model" / "gguf" / "q4_k_m" + model_dir.mkdir(parents=True, exist_ok=True) + + model_file = model_dir / "test-model.Q4_K_M.gguf" + model_file.write_bytes(b"fake model data") + + path = model_manager.get_model_path("test-model", "gguf", "q4_k_m") + assert path == model_file + + def test_get_model_path_not_exists(self, model_manager): + """Test get_model_path when model doesn't exist.""" + path = model_manager.get_model_path("nonexistent", "gguf", "q4_k_m") + assert path is None + + def test_remove_model_not_found(self, model_manager): + """Test remove_model with non-existent model.""" + assert not model_manager.remove_model("nonexistent") + + def test_remove_model_entire_model(self, model_manager, temp_dir): + """Test removing entire model (all formats).""" + # Create model structure + model_dir = temp_dir / "models" / "test-model" / "gguf" / "q4_k_m" + model_dir.mkdir(parents=True, exist_ok=True) + (model_dir / "model.gguf").write_bytes(b"data") + + # Remove entire model + assert model_manager.remove_model("test-model") + assert not (temp_dir / "models" / "test-model").exists() + + def test_remove_model_specific_format(self, model_manager, temp_dir): + """Test removing specific format.""" + # Create multiple formats + gguf_dir = temp_dir / "models" / "test-model" / "gguf" + gguf_dir.mkdir(parents=True, exist_ok=True) + (gguf_dir / "model.gguf").write_bytes(b"data") + + pytorch_dir = temp_dir / "models" / "test-model" / "pytorch" + pytorch_dir.mkdir(parents=True, exist_ok=True) + (pytorch_dir / "model.pt").write_bytes(b"data") + + # Remove only gguf format + assert model_manager.remove_model("test-model", format="gguf") + assert not gguf_dir.exists() + assert pytorch_dir.exists() + + def test_remove_model_specific_quantization(self, model_manager, temp_dir): + """Test removing specific quantization.""" + # Create multiple quantizations + q4_dir = temp_dir / "models" / "test-model" / "gguf" / "q4_k_m" + q4_dir.mkdir(parents=True, exist_ok=True) + (q4_dir / "model.gguf").write_bytes(b"data") + + q5_dir = temp_dir / "models" / "test-model" / "gguf" / "q5_k_m" + q5_dir.mkdir(parents=True, exist_ok=True) + (q5_dir / "model.gguf").write_bytes(b"data") + + # Remove only q4 + assert model_manager.remove_model("test-model", 
format="gguf", quantization="q4_k_m") + assert not q4_dir.exists() + assert q5_dir.exists() + + @patch("tools.model_manager.hf_hub_download") + def test_download_model_success(self, mock_download, model_manager, temp_dir): + """Test successful model download.""" + # Setup mock + model_dir = temp_dir / "models" / "test-model" / "gguf" / "q4_k_m" + model_dir.mkdir(parents=True, exist_ok=True) + model_file = model_dir / "test-model.Q4_K_M.gguf" + + # Create fake file content + test_content = b"fake model content" + + # Mock download to create the file + def mock_download_func(*args, **kwargs): + model_file.write_bytes(test_content) + return str(model_file) + + mock_download.side_effect = mock_download_func + + # Update config to have correct checksum + expected_hash = hashlib.sha256(test_content).hexdigest() + model_manager.registry["models"]["test-model"]["formats"]["gguf"]["q4_k_m"][ + "sha256" + ] = expected_hash + + # Download model + result = model_manager.download_model("test-model", "gguf", "q4_k_m") + + assert result == model_file + assert model_file.exists() + mock_download.assert_called_once() + + def test_download_model_not_in_registry(self, model_manager): + """Test download_model with model not in registry.""" + result = model_manager.download_model("nonexistent", "gguf") + assert result is None + + def test_download_model_format_not_available(self, model_manager): + """Test download_model with unavailable format.""" + result = model_manager.download_model("test-model", "pytorch") + assert result is None + + def test_download_model_quantization_not_available(self, model_manager): + """Test download_model with unavailable quantization.""" + result = model_manager.download_model("test-model", "gguf", "q8_0") + assert result is None + + @patch("tools.model_manager.hf_hub_download") + def test_download_model_skip_existing(self, mock_download, model_manager, temp_dir): + """Test that existing valid model is not re-downloaded.""" + # Create existing model + model_dir = temp_dir / "models" / "test-model" / "gguf" / "q4_k_m" + model_dir.mkdir(parents=True, exist_ok=True) + model_file = model_dir / "test-model.Q4_K_M.gguf" + + test_content = b"existing model" + model_file.write_bytes(test_content) + + # Update config with correct checksum + expected_hash = hashlib.sha256(test_content).hexdigest() + model_manager.registry["models"]["test-model"]["formats"]["gguf"]["q4_k_m"][ + "sha256" + ] = expected_hash + + # Try to download (should skip) + result = model_manager.download_model("test-model", "gguf", "q4_k_m", force=False) + + assert result == model_file + mock_download.assert_not_called() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])