From ab4cd8684f59a8fd9607cf83188f42524876f995 Mon Sep 17 00:00:00 2001 From: xprilion Date: Tue, 28 Apr 2026 14:41:41 +0530 Subject: [PATCH] Add huggingface tools and improve execution modes --- backend/configs/prompts/system_prompt.yaml | 50 +- backend/openmlr/agent/loop.py | 11 +- backend/openmlr/agent/session.py | 3 + backend/openmlr/models.py | 6 +- backend/openmlr/routes/settings.py | 12 + backend/openmlr/tasks/agent_tasks.py | 2 +- backend/openmlr/tools/ask_user.py | 4 +- backend/openmlr/tools/huggingface.py | 533 ++++++++++++++++++++ backend/openmlr/tools/registry.py | 22 +- backend/openmlr/tools/research.py | 19 + backend/tests/test_agent_loop.py | 42 +- backend/tests/test_models.py | 22 +- backend/tests/test_routes_settings.py | 34 ++ backend/tests/test_session.py | 21 + backend/tests/test_tool_registry.py | 47 +- backend/tests/test_tools_huggingface.py | 553 +++++++++++++++++++++ backend/tests/test_tools_research.py | 39 ++ frontend/src/App.tsx | 30 +- frontend/src/__tests__/InputArea.test.tsx | 32 +- frontend/src/components/InputArea.tsx | 63 ++- frontend/src/components/QuestionDrawer.tsx | 40 +- 21 files changed, 1500 insertions(+), 85 deletions(-) create mode 100644 backend/openmlr/tools/huggingface.py create mode 100644 backend/tests/test_tools_huggingface.py diff --git a/backend/configs/prompts/system_prompt.yaml b/backend/configs/prompts/system_prompt.yaml index ae4c1d2..7647b59 100644 --- a/backend/configs/prompts/system_prompt.yaml +++ b/backend/configs/prompts/system_prompt.yaml @@ -11,35 +11,39 @@ prompt: | {% if mode == "plan" %} ## CURRENT MODE: PLAN - + You are in **Plan mode**. Your job is to understand the task, ask clarifying questions, gather context, and produce a comprehensive plan. - - **Available tools**: ask_user, plan_tool, read_file, list_dir, glob_files, - grep_search, web_search, papers, github_search, github_read_file, github_read_repo, - github_search_repos, github_get_readme - - **NOT available**: writing, research sub-agent, sandbox/code execution tools. - Calls to unavailable tools will be rejected. - + + **Available tools**: ask_user, plan_tool, read (local files), web_search, + papers, github_read_file, github_find_examples, github_search_repos, + github_get_readme, github_list_repos, hf_search_models, hf_model_info, + hf_search_datasets, hf_dataset_info, hf_read_file, compute_list, + compute_plan, compute_probe, workspace + + **NOT available**: bash, write, edit, writing, research sub-agent, + sandbox/code execution, compute_select, compute_sync_up, compute_sync_down. + Calls to unavailable tools will be **rejected by the system**. + **Rules**: 1. Ask clarifying questions using `ask_user` before making assumptions - 2. Search the web, papers, and code repos to gather context - 3. Create a structured plan using `plan_tool` with clear, actionable tasks - 4. The plan is auto-saved as PLAN.md in resources — the user can see it - 5. Do NOT execute any work — plan only - 6. Do NOT write content, run code, or make changes - 7. Be thorough in your plan — it will be the blueprint for Execute mode - + 2. Search the web, papers, GitHub, and Hugging Face to gather context + 3. Read local project files with `read` to understand the codebase + 4. Create a structured plan using `plan_tool` with clear, actionable tasks + 5. The plan is auto-saved as PLAN.md in resources — the user can see it + 6. Do NOT execute any work — plan only + 7. Do NOT write content, run code, or make changes + 8. Be thorough in your plan — it will be the blueprint for Execute mode + {% elif mode == "execute" %} ## CURRENT MODE: EXECUTE - + You are in **Execute mode**. Your job is to follow the plan and do the work. Do NOT ask questions — just execute. - + **Available tools**: ALL tools EXCEPT ask_user. - Calls to ask_user will be rejected. - + Calls to ask_user will be **rejected by the system**. + **Rules**: 1. Follow the task plan — check it with `plan_tool get` if unsure 2. Work through tasks one at a time, marking them in_progress then completed @@ -49,10 +53,10 @@ prompt: | decision and document it in your completion report 6. Keep pushing through the task list until done or interrupted 7. Generate completion reports for each task - + {% else %} - ## CURRENT MODE: EXECUTE (default) - All tools available except ask_user. Execute the work. + ## CURRENT MODE: PLAN (default) + Plan only — ask questions, gather context, create plan. No execution. {% endif %} # Task Management diff --git a/backend/openmlr/agent/loop.py b/backend/openmlr/agent/loop.py index 5128f56..133f631 100644 --- a/backend/openmlr/agent/loop.py +++ b/backend/openmlr/agent/loop.py @@ -47,8 +47,15 @@ async def _run_agent(session: Session, tool_router, user_message: str, mode: str if session.pending_approval: session.pending_approval = None - # Set the mode on the tool router for strict enforcement - effective_mode = mode if mode in ("plan", "execute") else "execute" + # Set the mode on the tool router for strict enforcement. + # Default to plan (safe) if mode is missing or invalid. + # If mode is explicitly provided, use it and persist on session. + # If not provided (e.g. approval continuation), fall back to session's stored mode. + if mode in ("plan", "execute"): + effective_mode = mode + session.current_mode = mode + else: + effective_mode = session.current_mode # preserved from the last explicit mode tool_router.set_mode(effective_mode) # Inject per-message mode hint (short reinforcement of system prompt rules) diff --git a/backend/openmlr/agent/session.py b/backend/openmlr/agent/session.py index 4c7672b..8051f29 100644 --- a/backend/openmlr/agent/session.py +++ b/backend/openmlr/agent/session.py @@ -25,6 +25,9 @@ class Session: # Cancellation _cancelled: asyncio.Event = field(default_factory=asyncio.Event) + # Mode tracking (plan/execute) — persists across approval continuations + current_mode: str = "plan" + # Approval flow pending_approval: dict | None = None diff --git a/backend/openmlr/models.py b/backend/openmlr/models.py index ac21a15..5e339a0 100644 --- a/backend/openmlr/models.py +++ b/backend/openmlr/models.py @@ -1,7 +1,7 @@ """Pydantic models for API requests and responses.""" from datetime import datetime -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, Field @@ -71,7 +71,9 @@ class ConversationDetail(BaseModel): class MessageSend(BaseModel): message: str - mode: str | None = None # plan, research, write — per-message mode override + mode: Literal["plan", "execute"] | None = ( + None # per-message mode; only plan or execute accepted + ) class ApprovalRequest(BaseModel): diff --git a/backend/openmlr/routes/settings.py b/backend/openmlr/routes/settings.py index a0a813c..ce30d3d 100644 --- a/backend/openmlr/routes/settings.py +++ b/backend/openmlr/routes/settings.py @@ -64,6 +64,7 @@ async def update_setting( "openalex_api_key": "OPENALEX_API_KEY", "modal_token_id": "MODAL_TOKEN_ID", "modal_token_secret": "MODAL_TOKEN_SECRET", + "hf_token": "HF_TOKEN", } env_key = env_key_map.get(key) if env_key and isinstance(value, str): @@ -127,6 +128,7 @@ def _is_provider_configured(provider_id: str, provider_settings: dict) -> bool: "semantic_scholar": "SEMANTIC_SCHOLAR_API_KEY", "openalex": "OPENALEX_API_KEY", "modal": "MODAL_TOKEN_ID", + "huggingface": "HF_TOKEN", } env_key = env_map.get(provider_id) if env_key and os.environ.get(env_key): @@ -143,6 +145,7 @@ def _is_provider_configured(provider_id: str, provider_settings: dict) -> bool: "semantic_scholar": "semantic_scholar_api_key", "openalex": "openalex_api_key", "modal": "modal_token_id", + "huggingface": "hf_token", }.get(provider_id) if setting_key and provider_settings.get(setting_key): return True @@ -258,6 +261,14 @@ async def list_providers( "categories": ["compute"], "docs_url": "https://modal.com/docs", }, + { + "id": "huggingface", + "name": "Hugging Face", + "key_env": "HF_TOKEN", + "configured": _is_provider_configured("huggingface", provider_settings), + "categories": ["models", "papers"], + "docs_url": "https://huggingface.co/docs/hub/security-tokens", + }, ] # Add custom providers @@ -759,6 +770,7 @@ async def save_config( "OPENALEX_API_KEY", "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET", + "HF_TOKEN", } body = await request.json() diff --git a/backend/openmlr/tasks/agent_tasks.py b/backend/openmlr/tasks/agent_tasks.py index 9da7fa9..3e9dd2b 100644 --- a/backend/openmlr/tasks/agent_tasks.py +++ b/backend/openmlr/tasks/agent_tasks.py @@ -116,7 +116,7 @@ async def _async_process_message( # Build and set system prompt session.context_manager.system_prompt = build_system_prompt( tool_specs=tool_router.get_raw_specs(), - mode=mode or "general", + mode=mode if mode in ("plan", "execute") else "plan", username="user", ) diff --git a/backend/openmlr/tools/ask_user.py b/backend/openmlr/tools/ask_user.py index 995c099..1f82b91 100644 --- a/backend/openmlr/tools/ask_user.py +++ b/backend/openmlr/tools/ask_user.py @@ -66,8 +66,8 @@ def create_ask_user_tool() -> ToolSpec: }, "suggest_mode": { "type": "string", - "description": "If confident, suggest the user switch to this mode after answering (e.g. 'research', 'write')", - "enum": ["research", "write"], + "description": "If the plan is ready and the user should start executing, set this to 'execute' to suggest switching to Execute mode after answering.", + "enum": ["execute"], }, }, "required": ["questions"], diff --git a/backend/openmlr/tools/huggingface.py b/backend/openmlr/tools/huggingface.py new file mode 100644 index 0000000..381768f --- /dev/null +++ b/backend/openmlr/tools/huggingface.py @@ -0,0 +1,533 @@ +"""Hugging Face Hub tools — model search, dataset search, model/dataset cards, file reading.""" + +import logging +import os + +from ..agent.types import ToolSpec +from .http_utils import RateLimitError, fetch_with_retry + +log = logging.getLogger(__name__) + +HF_API = "https://huggingface.co" + + +def _headers() -> dict: + token = os.environ.get("HF_TOKEN") + h: dict[str, str] = {} + if token: + h["Authorization"] = f"Bearer {token}" + return h + + +# --------------------------------------------------------------------------- +# Tool factory +# --------------------------------------------------------------------------- + + +def create_huggingface_tools() -> list[ToolSpec]: + return [ + ToolSpec( + name="hf_search_models", + description=( + "Search Hugging Face Hub for models by keyword, pipeline task, or library. " + "Useful for finding pre-trained models, fine-tuned checkpoints, and SOTA architectures." + ), + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query (e.g. 'llama 3 instruct', 'stable diffusion xl')", + }, + "pipeline_tag": { + "type": "string", + "description": ( + "Filter by pipeline task (e.g. text-generation, " + "image-classification, text-to-image, automatic-speech-recognition)" + ), + }, + "library": { + "type": "string", + "description": "Filter by library (e.g. transformers, diffusers, gguf, safetensors)", + }, + "sort": { + "type": "string", + "description": "Sort by: downloads, likes, trending, created, modified (default: trending)", + "enum": ["downloads", "likes", "trending", "created", "modified"], + }, + "limit": { + "type": "integer", + "description": "Max results to return (default 15)", + }, + }, + "required": ["query"], + }, + handler=_handle_search_models, + ), + ToolSpec( + name="hf_model_info", + description=( + "Get detailed information about a Hugging Face model, including its model card " + "(README), metadata, download count, pipeline tag, library, and tags. " + "Provide the full repo ID like 'deepseek-ai/DeepSeek-V3' or 'meta-llama/Llama-3-8B'." + ), + parameters={ + "type": "object", + "properties": { + "repo_id": { + "type": "string", + "description": "Model repo ID (e.g. 'meta-llama/Llama-3-8B', 'mistralai/Mistral-7B-v0.1')", + }, + "include_readme": { + "type": "boolean", + "description": "Whether to fetch the full model card / README (default: true)", + }, + }, + "required": ["repo_id"], + }, + handler=_handle_model_info, + ), + ToolSpec( + name="hf_search_datasets", + description=( + "Search Hugging Face Hub for datasets by keyword or task. " + "Useful for finding training data, benchmarks, and evaluation datasets." + ), + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query (e.g. 'code instruction', 'medical qa')", + }, + "task": { + "type": "string", + "description": ( + "Filter by task category (e.g. text-classification, " + "question-answering, summarization, text-generation)" + ), + }, + "sort": { + "type": "string", + "description": "Sort by: downloads, likes, trending, created, modified (default: trending)", + "enum": ["downloads", "likes", "trending", "created", "modified"], + }, + "limit": { + "type": "integer", + "description": "Max results to return (default 15)", + }, + }, + "required": ["query"], + }, + handler=_handle_search_datasets, + ), + ToolSpec( + name="hf_dataset_info", + description=( + "Get detailed information about a Hugging Face dataset, including its dataset card " + "(README), metadata, download count, and tags. " + "Provide the full repo ID like 'tatsu-lab/alpaca' or 'Open-Orca/OpenOrca'." + ), + parameters={ + "type": "object", + "properties": { + "repo_id": { + "type": "string", + "description": "Dataset repo ID (e.g. 'tatsu-lab/alpaca', 'Open-Orca/OpenOrca')", + }, + "include_readme": { + "type": "boolean", + "description": "Whether to fetch the full dataset card / README (default: true)", + }, + }, + "required": ["repo_id"], + }, + handler=_handle_dataset_info, + ), + ToolSpec( + name="hf_read_file", + description=( + "Read a file from a Hugging Face repository (model, dataset, or space). " + "Useful for reading config.json, tokenizer configs, training scripts, etc. " + "For model cards / READMEs, prefer hf_model_info or hf_dataset_info instead." + ), + parameters={ + "type": "object", + "properties": { + "repo_id": { + "type": "string", + "description": "Repo ID (e.g. 'meta-llama/Llama-3-8B')", + }, + "path": { + "type": "string", + "description": "File path within the repo (e.g. 'config.json', 'tokenizer_config.json')", + }, + "repo_type": { + "type": "string", + "description": "Repository type: model, dataset, space (default: model)", + "enum": ["model", "dataset", "space"], + }, + "revision": { + "type": "string", + "description": "Branch, tag, or commit hash (default: main)", + }, + }, + "required": ["repo_id", "path"], + }, + handler=_handle_read_file, + ), + ] + + +# --------------------------------------------------------------------------- +# Handlers +# --------------------------------------------------------------------------- + + +async def _handle_search_models( + query: str, + pipeline_tag: str | None = None, + library: str | None = None, + sort: str = "trending", + limit: int = 15, + **kwargs, +) -> tuple[str, bool]: + """Search Hugging Face Hub for models.""" + url = f"{HF_API}/api/models" + + params: dict[str, str | int] = { + "search": query, + "limit": min(limit, 50), + "sort": sort, + "direction": "-1", + } + if pipeline_tag: + params["pipeline_tag"] = pipeline_tag + if library: + params["library"] = library + + try: + resp = await fetch_with_retry( + url, headers=_headers(), params=params, timeout=30, max_retries=3 + ) + except RateLimitError: + return "Hugging Face rate limit reached. Try again later or add HF_TOKEN.", False + except Exception as e: + log.warning(f"HF search models error: {e}") + return f"Hugging Face API error: {str(e)[:200]}", False + + if resp.status_code != 200: + return f"Hugging Face API error {resp.status_code}: {resp.text[:500]}", False + + models = resp.json() + if not models: + return f"No models found for: {query}", True + + lines = [f"Found {len(models)} models for '{query}':\n"] + for m in models[:limit]: + model_id = m.get("modelId") or m.get("id", "?") + downloads = m.get("downloads", 0) + likes = m.get("likes", 0) + pipeline = m.get("pipeline_tag", "") + tags = m.get("tags", [])[:5] + lib_tags = [ + t + for t in tags + if t + in ( + "transformers", + "diffusers", + "gguf", + "safetensors", + "pytorch", + "tensorflow", + "jax", + "onnx", + "openvino", + ) + ] + + lines.append(f"**{model_id}** ({downloads:,} downloads, {likes} likes)") + meta_parts = [] + if pipeline: + meta_parts.append(f"Task: {pipeline}") + if lib_tags: + meta_parts.append(f"Libraries: {', '.join(lib_tags)}") + if meta_parts: + lines.append(f" {' | '.join(meta_parts)}") + lines.append(f" https://huggingface.co/{model_id}\n") + + return "\n".join(lines), True + + +async def _handle_model_info( + repo_id: str, include_readme: bool = True, **kwargs +) -> tuple[str, bool]: + """Get detailed info + model card for a Hugging Face model.""" + url = f"{HF_API}/api/models/{repo_id}" + + try: + resp = await fetch_with_retry(url, headers=_headers(), timeout=30, max_retries=3) + except RateLimitError: + return "Hugging Face rate limit reached. Try again later or add HF_TOKEN.", False + except Exception as e: + log.warning(f"HF model info error: {e}") + return f"Hugging Face API error: {str(e)[:200]}", False + + if resp.status_code == 404: + return f"Model not found: {repo_id}", False + if resp.status_code != 200: + return f"Hugging Face API error {resp.status_code}: {resp.text[:500]}", False + + data = resp.json() + + # Build metadata summary + lines = [f"# Model: {repo_id}\n"] + lines.append(f"- **URL**: https://huggingface.co/{repo_id}") + if data.get("pipeline_tag"): + lines.append(f"- **Pipeline**: {data['pipeline_tag']}") + if data.get("library_name"): + lines.append(f"- **Library**: {data['library_name']}") + lines.append(f"- **Downloads**: {data.get('downloads', 0):,}") + lines.append(f"- **Likes**: {data.get('likes', 0)}") + if data.get("author"): + lines.append(f"- **Author**: {data['author']}") + if data.get("lastModified"): + lines.append(f"- **Last modified**: {data['lastModified']}") + + tags = data.get("tags", []) + if tags: + lines.append(f"- **Tags**: {', '.join(tags[:15])}") + + siblings = data.get("siblings", []) + if siblings: + file_names = [s.get("rfilename", "") for s in siblings[:30]] + lines.append(f"\n**Files** ({len(siblings)} total): {', '.join(file_names)}") + if len(siblings) > 30: + lines.append(f" ... and {len(siblings) - 30} more") + + # Fetch README / model card + if include_readme: + readme_content = await _fetch_readme(repo_id, "model") + if readme_content: + lines.append(f"\n---\n\n## Model Card\n\n{readme_content}") + + output = "\n".join(lines) + if len(output) > 50000: + output = output[:50000] + "\n\n...[truncated]" + return output, True + + +async def _handle_search_datasets( + query: str, + task: str | None = None, + sort: str = "trending", + limit: int = 15, + **kwargs, +) -> tuple[str, bool]: + """Search Hugging Face Hub for datasets.""" + url = f"{HF_API}/api/datasets" + + params: dict[str, str | int] = { + "search": query, + "limit": min(limit, 50), + "sort": sort, + "direction": "-1", + } + if task: + params["task_categories"] = task + + try: + resp = await fetch_with_retry( + url, headers=_headers(), params=params, timeout=30, max_retries=3 + ) + except RateLimitError: + return "Hugging Face rate limit reached. Try again later or add HF_TOKEN.", False + except Exception as e: + log.warning(f"HF search datasets error: {e}") + return f"Hugging Face API error: {str(e)[:200]}", False + + if resp.status_code != 200: + return f"Hugging Face API error {resp.status_code}: {resp.text[:500]}", False + + datasets = resp.json() + if not datasets: + return f"No datasets found for: {query}", True + + lines = [f"Found {len(datasets)} datasets for '{query}':\n"] + for d in datasets[:limit]: + ds_id = d.get("id", "?") + downloads = d.get("downloads", 0) + likes = d.get("likes", 0) + description = (d.get("description") or "")[:120] + tags = d.get("tags", [])[:5] + + lines.append(f"**{ds_id}** ({downloads:,} downloads, {likes} likes)") + if description: + lines.append(f" {description}") + if tags: + lines.append(f" Tags: {', '.join(tags)}") + lines.append(f" https://huggingface.co/datasets/{ds_id}\n") + + return "\n".join(lines), True + + +async def _handle_dataset_info( + repo_id: str, include_readme: bool = True, **kwargs +) -> tuple[str, bool]: + """Get detailed info + dataset card for a Hugging Face dataset.""" + url = f"{HF_API}/api/datasets/{repo_id}" + + try: + resp = await fetch_with_retry(url, headers=_headers(), timeout=30, max_retries=3) + except RateLimitError: + return "Hugging Face rate limit reached. Try again later or add HF_TOKEN.", False + except Exception as e: + log.warning(f"HF dataset info error: {e}") + return f"Hugging Face API error: {str(e)[:200]}", False + + if resp.status_code == 404: + return f"Dataset not found: {repo_id}", False + if resp.status_code != 200: + return f"Hugging Face API error {resp.status_code}: {resp.text[:500]}", False + + data = resp.json() + + lines = [f"# Dataset: {repo_id}\n"] + lines.append(f"- **URL**: https://huggingface.co/datasets/{repo_id}") + lines.append(f"- **Downloads**: {data.get('downloads', 0):,}") + lines.append(f"- **Likes**: {data.get('likes', 0)}") + if data.get("author"): + lines.append(f"- **Author**: {data['author']}") + if data.get("lastModified"): + lines.append(f"- **Last modified**: {data['lastModified']}") + + tags = data.get("tags", []) + if tags: + lines.append(f"- **Tags**: {', '.join(tags[:15])}") + + description = data.get("description", "") + if description: + lines.append(f"- **Description**: {description[:300]}") + + siblings = data.get("siblings", []) + if siblings: + file_names = [s.get("rfilename", "") for s in siblings[:30]] + lines.append(f"\n**Files** ({len(siblings)} total): {', '.join(file_names)}") + if len(siblings) > 30: + lines.append(f" ... and {len(siblings) - 30} more") + + # Fetch README / dataset card + if include_readme: + readme_content = await _fetch_readme(repo_id, "dataset") + if readme_content: + lines.append(f"\n---\n\n## Dataset Card\n\n{readme_content}") + + output = "\n".join(lines) + if len(output) > 50000: + output = output[:50000] + "\n\n...[truncated]" + return output, True + + +async def _handle_read_file( + repo_id: str, + path: str, + repo_type: str = "model", + revision: str = "main", + **kwargs, +) -> tuple[str, bool]: + """Read a file from a Hugging Face repository.""" + # Build the resolve URL based on repo type + if repo_type == "dataset": + url = f"{HF_API}/datasets/{repo_id}/resolve/{revision}/{path}" + elif repo_type == "space": + url = f"{HF_API}/spaces/{repo_id}/resolve/{revision}/{path}" + else: + url = f"{HF_API}/{repo_id}/resolve/{revision}/{path}" + + try: + resp = await fetch_with_retry(url, headers=_headers(), timeout=30, max_retries=3) + except RateLimitError: + return "Hugging Face rate limit reached. Try again later or add HF_TOKEN.", False + except Exception as e: + log.warning(f"HF read file error: {e}") + return f"Hugging Face API error: {str(e)[:200]}", False + + if resp.status_code == 404: + return f"File not found: {repo_id}/{path} (revision: {revision})", False + if resp.status_code == 401: + return ( + f"Access denied for {repo_id}/{path}. " + "This may be a gated model — add HF_TOKEN with accepted access." + ), False + if resp.status_code == 403: + return ( + f"Forbidden: {repo_id}/{path}. " + "You may need to accept the model's license on huggingface.co first." + ), False + if resp.status_code != 200: + return f"Hugging Face API error {resp.status_code}: {resp.text[:500]}", False + + content_type = resp.headers.get("content-type", "") + + # Binary files — just report metadata + if "application/octet-stream" in content_type or path.endswith( + (".bin", ".safetensors", ".gguf", ".pt", ".pth", ".h5", ".onnx", ".msgpack") + ): + size = len(resp.content) + return ( + f"Binary file: {repo_id}/{path} ({size:,} bytes). " + "Cannot display binary content. Use hf_model_info to see the file list." + ), True + + # Text content + try: + text = resp.text + except Exception: + return f"Could not decode file: {repo_id}/{path}", False + + # Add line numbers for code-like files + if path.endswith((".py", ".js", ".ts", ".yaml", ".yml", ".toml", ".sh", ".md", ".rst", ".txt")): + lines = text.split("\n") + numbered = [f"{i + 1}: {line}" for i, line in enumerate(lines)] + output = "\n".join(numbered) + else: + output = text + + if len(output) > 50000: + output = output[:50000] + "\n\n...[truncated]" + + return f"# {repo_id}/{path}\n\n{output}", True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _fetch_readme(repo_id: str, repo_type: str = "model") -> str | None: + """Fetch the README.md from a Hugging Face repo. Returns None on failure.""" + if repo_type == "dataset": + url = f"{HF_API}/datasets/{repo_id}/resolve/main/README.md" + else: + url = f"{HF_API}/{repo_id}/resolve/main/README.md" + + try: + resp = await fetch_with_retry(url, headers=_headers(), timeout=30, max_retries=2) + except Exception: + return None + + if resp.status_code != 200: + return None + + content = resp.text + # Strip YAML frontmatter (common in HF READMEs) + if content.startswith("---"): + end = content.find("---", 3) + if end != -1: + content = content[end + 3 :].strip() + + if len(content) > 30000: + content = content[:30000] + "\n\n...[truncated]" + + return content diff --git a/backend/openmlr/tools/registry.py b/backend/openmlr/tools/registry.py index 1d8bc28..168142e 100644 --- a/backend/openmlr/tools/registry.py +++ b/backend/openmlr/tools/registry.py @@ -8,24 +8,28 @@ # Tools not listed are allowed in all modes MODE_TOOL_RESTRICTIONS = { "plan": { - # Plan mode: ask questions, create plans, read context — NO execution tools + # Plan mode: ask questions, create plans, read context — NO execution tools. + # Tool names here must EXACTLY match the registered ToolSpec.name values. "allowed": { "ask_user", "plan_tool", - # Read-only tools for gathering context - "read_file", - "list_dir", - "glob_files", - "grep_search", + # Read-only local filesystem access for gathering context + "read", + # Web / academic search "web_search", "papers", - "github_search", + # GitHub (read-only) "github_read_file", - "github_read_repo", "github_find_examples", "github_search_repos", "github_get_readme", "github_list_repos", + # Hugging Face (read-only model/dataset discovery) + "hf_search_models", + "hf_model_info", + "hf_search_datasets", + "hf_dataset_info", + "hf_read_file", # Compute planning (read-only / advisory) "compute_list", "compute_plan", @@ -255,6 +259,7 @@ def create_tool_router(sandbox_manager=None) -> ToolRouter: # Import and register all built-in tools from .ask_user import create_ask_user_tool from .github import create_github_tools + from .huggingface import create_huggingface_tools from .local import create_local_tools from .papers import create_papers_tool from .plan import create_plan_tool @@ -264,6 +269,7 @@ def create_tool_router(sandbox_manager=None) -> ToolRouter: router.register_many(create_local_tools()) router.register_many(create_github_tools()) + router.register_many(create_huggingface_tools()) router.register_many(create_search_tools()) router.register(create_research_tool()) router.register(create_plan_tool()) diff --git a/backend/openmlr/tools/research.py b/backend/openmlr/tools/research.py index 46cd98d..ac6603d 100644 --- a/backend/openmlr/tools/research.py +++ b/backend/openmlr/tools/research.py @@ -230,6 +230,7 @@ async def _handle_research( def _get_research_tool_specs() -> list[dict]: """Get the read-only tool subset for research.""" from .github import create_github_tools + from .huggingface import create_huggingface_tools from .papers import create_papers_tool from .search import create_search_tools @@ -271,12 +272,28 @@ def _get_research_tool_specs() -> list[dict]: } ) + # Hugging Face tools for ML research (model discovery + file reading) + for spec in create_huggingface_tools(): + if spec.name in ("hf_search_models", "hf_read_file"): + tools.append( + { + "type": "function", + "function": { + "name": spec.name, + "description": spec.description, + "parameters": spec.parameters, + }, + } + ) + return tools async def _execute_research_tool(tc: ToolCall) -> tuple[str, bool]: """Execute a tool call for the research sub-agent.""" from .github import _handle_find_examples, _handle_read_file + from .huggingface import _handle_read_file as _hf_handle_read_file + from .huggingface import _handle_search_models as _hf_handle_search_models from .papers import _handle_papers from .search import _handle_web_search @@ -285,6 +302,8 @@ async def _execute_research_tool(tc: ToolCall) -> tuple[str, bool]: "papers": _handle_papers, "github_read_file": _handle_read_file, "github_find_examples": _handle_find_examples, + "hf_search_models": _hf_handle_search_models, + "hf_read_file": _hf_handle_read_file, } handler = handlers.get(tc.name) diff --git a/backend/tests/test_agent_loop.py b/backend/tests/test_agent_loop.py index 92380ff..d093103 100644 --- a/backend/tests/test_agent_loop.py +++ b/backend/tests/test_agent_loop.py @@ -50,6 +50,7 @@ def mock_session(config): session.is_cancelled = MagicMock(return_value=False) session.pending_approval = None session.pending_answers = None + session.current_mode = "plan" session.turn_count = 0 session.on_event = MagicMock() session.update_model = MagicMock() @@ -234,7 +235,9 @@ async def test_delegates_to_run_agent(self, mock_session, mock_router): mock_router.set_mode.assert_called_with("plan") - async def test_default_mode_is_execute(self, mock_session, mock_router): + async def test_unknown_mode_falls_back_to_session_mode(self, mock_session, mock_router): + """When mode is invalid, fall back to session.current_mode (defaults to plan).""" + mock_session.current_mode = "plan" mock_session.context_manager.get_messages.return_value = [] mock_session.context_manager.needs_compaction.return_value = False mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0} @@ -248,7 +251,44 @@ async def test_default_mode_is_execute(self, mock_session, mock_router): ) await run_agent_turn(mock_session, mock_router, "test", mode="unknown") + mock_router.set_mode.assert_called_with("plan") + + async def test_null_mode_falls_back_to_session_mode(self, mock_session, mock_router): + """When mode is None, fall back to session.current_mode.""" + mock_session.current_mode = "execute" + mock_session.context_manager.get_messages.return_value = [] + mock_session.context_manager.needs_compaction.return_value = False + mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0} + mock_session.config.stream = False + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = LLMResult( + content="Ok", + tool_calls=[], + finish_reason="stop", + ) + await run_agent_turn(mock_session, mock_router, "test", mode=None) + + mock_router.set_mode.assert_called_with("execute") + + async def test_explicit_mode_updates_session(self, mock_session, mock_router): + """When a valid mode is passed, it should be stored on session.current_mode.""" + mock_session.current_mode = "plan" + mock_session.context_manager.get_messages.return_value = [] + mock_session.context_manager.needs_compaction.return_value = False + mock_session.context_manager.get_token_usage.return_value = {"ratio": 0.0} + mock_session.config.stream = False + + with patch("openmlr.agent.loop.LLMProvider.generate") as mock_gen: + mock_gen.return_value = LLMResult( + content="Ok", + tool_calls=[], + finish_reason="stop", + ) + await run_agent_turn(mock_session, mock_router, "test", mode="execute") + mock_router.set_mode.assert_called_with("execute") + assert mock_session.current_mode == "execute" # ── Submissions ──────────────────────────────────────────── diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py index 73f5e80..3efc436 100644 --- a/backend/tests/test_models.py +++ b/backend/tests/test_models.py @@ -146,9 +146,25 @@ def test_basic(self): assert m.message == "Hello world" assert m.mode is None - def test_with_mode(self): - m = MessageSend(message="Research this", mode="research") - assert m.mode == "research" + def test_with_plan_mode(self): + m = MessageSend(message="Plan this", mode="plan") + assert m.mode == "plan" + + def test_with_execute_mode(self): + m = MessageSend(message="Do this", mode="execute") + assert m.mode == "execute" + + def test_rejects_invalid_mode(self): + with pytest.raises(ValidationError): + MessageSend(message="test", mode="research") + + def test_rejects_arbitrary_mode(self): + with pytest.raises(ValidationError): + MessageSend(message="test", mode="anything_else") + + def test_allows_null_mode(self): + m = MessageSend(message="test", mode=None) + assert m.mode is None class TestApprovalRequest: diff --git a/backend/tests/test_routes_settings.py b/backend/tests/test_routes_settings.py index 530065e..537394c 100644 --- a/backend/tests/test_routes_settings.py +++ b/backend/tests/test_routes_settings.py @@ -108,6 +108,40 @@ async def test_provider_has_required_fields(self, auth_client: AsyncClient): assert "key_env" in p assert "configured" in p + async def test_huggingface_provider_listed(self, auth_client: AsyncClient): + resp = await auth_client.get("/api/providers") + data = resp.json() + provider_ids = [p["id"] for p in data["providers"]] + assert "huggingface" in provider_ids + + async def test_huggingface_provider_fields(self, auth_client: AsyncClient): + resp = await auth_client.get("/api/providers") + data = resp.json() + hf = [p for p in data["providers"] if p["id"] == "huggingface"][0] + assert hf["name"] == "Hugging Face" + assert hf["key_env"] == "HF_TOKEN" + assert "models" in hf["categories"] + assert "papers" in hf["categories"] + assert "docs_url" in hf + + async def test_huggingface_token_sets_env( + self, auth_client: AsyncClient, db_session, test_user + ): + resp = await auth_client.put( + "/api/settings/providers/hf_token", + json={"value": "hf_test_token_123"}, + ) + assert resp.status_code == 200 + assert os.environ.get("HF_TOKEN") == "hf_test_token_123" + + async def test_huggingface_token_in_config_allowlist(self, auth_client: AsyncClient): + resp = await auth_client.post( + "/api/config", + json={"HF_TOKEN": "hf_from_config"}, + ) + assert resp.status_code == 200 + assert os.environ.get("HF_TOKEN") == "hf_from_config" + class TestAppStatus: async def test_get_status(self, auth_client: AsyncClient): diff --git a/backend/tests/test_session.py b/backend/tests/test_session.py index a57feb7..227b762 100644 --- a/backend/tests/test_session.py +++ b/backend/tests/test_session.py @@ -180,3 +180,24 @@ def test_update_model_reflected_in_context_manager(self, session: Session): """ContextManager shares the config object, so the change propagates.""" session.update_model("openai/gpt-4o") assert session.context_manager.config.model_name == "openai/gpt-4o" + + +# --------------------------------------------------------------------------- +# current_mode +# --------------------------------------------------------------------------- + + +class TestCurrentMode: + def test_default_mode_is_plan(self, session: Session): + """Session defaults to plan mode (safe default).""" + assert session.current_mode == "plan" + + def test_mode_can_be_set(self, session: Session): + session.current_mode = "execute" + assert session.current_mode == "execute" + + def test_mode_persists_across_reads(self, session: Session): + session.current_mode = "plan" + assert session.current_mode == "plan" + session.current_mode = "execute" + assert session.current_mode == "execute" diff --git a/backend/tests/test_tool_registry.py b/backend/tests/test_tool_registry.py index 2c9a61a..13b9162 100644 --- a/backend/tests/test_tool_registry.py +++ b/backend/tests/test_tool_registry.py @@ -171,9 +171,9 @@ async def test_plan_mode_allows_ask_user(self, router): allowed, msg = router.is_tool_allowed("ask_user") assert allowed is True - async def test_plan_mode_allows_read_file(self, router): + async def test_plan_mode_allows_read(self, router): router.set_mode("plan") - allowed, msg = router.is_tool_allowed("read_file") + allowed, msg = router.is_tool_allowed("read") assert allowed is True async def test_execute_mode_blocks_ask_user(self, router): @@ -229,8 +229,47 @@ async def test_execute_has_blocked_list(self): async def test_plan_allowed_includes_ask_user(self): assert "ask_user" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] - async def test_plan_allowed_includes_read_file(self): - assert "read_file" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + async def test_plan_allowed_includes_read_tool(self): + assert "read" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] async def test_execute_blocked_includes_ask_user(self): assert "ask_user" in MODE_TOOL_RESTRICTIONS["execute"]["blocked"] + + async def test_plan_allowed_includes_hf_tools(self): + plan_allowed = MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + assert "hf_search_models" in plan_allowed + assert "hf_model_info" in plan_allowed + assert "hf_search_datasets" in plan_allowed + assert "hf_dataset_info" in plan_allowed + assert "hf_read_file" in plan_allowed + + async def test_plan_allowed_includes_read(self): + """The local 'read' tool should be allowed in plan mode for context gathering.""" + assert "read" in MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + + async def test_plan_blocks_execution_tools(self): + """Execution tools must NOT be in the plan allowlist.""" + plan_allowed = MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + assert "bash" not in plan_allowed + assert "write" not in plan_allowed + assert "edit" not in plan_allowed + assert "writing" not in plan_allowed + assert "research" not in plan_allowed + assert "sandbox_exec" not in plan_allowed + assert "sandbox_create" not in plan_allowed + assert "compute_select" not in plan_allowed + assert "compute_sync_up" not in plan_allowed + assert "compute_sync_down" not in plan_allowed + + async def test_plan_allowlist_has_no_phantom_entries(self): + """Every entry in the plan allowlist must match a real registered tool.""" + from openmlr.tools.registry import create_tool_router + + router = create_tool_router() + plan_allowed = MODE_TOOL_RESTRICTIONS["plan"]["allowed"] + registered_names = set(router.tools.keys()) + for tool_name in plan_allowed: + assert tool_name in registered_names, ( + f"Plan allowlist contains phantom tool '{tool_name}' " + f"that is not registered in the ToolRouter" + ) diff --git a/backend/tests/test_tools_huggingface.py b/backend/tests/test_tools_huggingface.py new file mode 100644 index 0000000..5ca1a59 --- /dev/null +++ b/backend/tests/test_tools_huggingface.py @@ -0,0 +1,553 @@ +"""Tests for Hugging Face Hub tools — tool specs, headers, and handler logic.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from openmlr.tools.huggingface import ( + _fetch_readme, + _handle_dataset_info, + _handle_model_info, + _handle_read_file, + _handle_search_datasets, + _handle_search_models, + _headers, + create_huggingface_tools, +) + +pytestmark = pytest.mark.asyncio + + +# --------------------------------------------------------------------------- +# Tool factory +# --------------------------------------------------------------------------- + + +class TestCreateHuggingfaceTools: + async def test_creates_all_tools(self): + tools = create_huggingface_tools() + names = [t.name for t in tools] + assert "hf_search_models" in names + assert "hf_model_info" in names + assert "hf_search_datasets" in names + assert "hf_dataset_info" in names + assert "hf_read_file" in names + assert len(tools) == 5 + + async def test_all_tools_have_handlers(self): + tools = create_huggingface_tools() + for tool in tools: + assert tool.handler is not None, f"{tool.name} has no handler" + + async def test_all_tools_have_descriptions(self): + tools = create_huggingface_tools() + for tool in tools: + assert len(tool.description) > 10, f"{tool.name} description too short" + + async def test_all_tools_have_valid_parameters(self): + tools = create_huggingface_tools() + for tool in tools: + assert tool.parameters["type"] == "object" + assert "properties" in tool.parameters + assert "required" in tool.parameters + + +class TestSearchModelsSpec: + async def test_required_params(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_search_models"][0] + assert "query" in tool.parameters["required"] + + async def test_optional_params(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_search_models"][0] + props = tool.parameters["properties"] + assert "pipeline_tag" in props + assert "library" in props + assert "sort" in props + assert "limit" in props + + async def test_sort_enum(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_search_models"][0] + sort_prop = tool.parameters["properties"]["sort"] + assert "enum" in sort_prop + assert "downloads" in sort_prop["enum"] + assert "likes" in sort_prop["enum"] + assert "trending" in sort_prop["enum"] + + +class TestModelInfoSpec: + async def test_required_params(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_model_info"][0] + assert "repo_id" in tool.parameters["required"] + + async def test_include_readme_param(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_model_info"][0] + assert "include_readme" in tool.parameters["properties"] + assert tool.parameters["properties"]["include_readme"]["type"] == "boolean" + + +class TestSearchDatasetsSpec: + async def test_required_params(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_search_datasets"][0] + assert "query" in tool.parameters["required"] + + async def test_task_filter_param(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_search_datasets"][0] + assert "task" in tool.parameters["properties"] + + +class TestDatasetInfoSpec: + async def test_required_params(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_dataset_info"][0] + assert "repo_id" in tool.parameters["required"] + + +class TestReadFileSpec: + async def test_required_params(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_read_file"][0] + required = tool.parameters["required"] + assert "repo_id" in required + assert "path" in required + + async def test_repo_type_enum(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_read_file"][0] + repo_type = tool.parameters["properties"]["repo_type"] + assert "enum" in repo_type + assert "model" in repo_type["enum"] + assert "dataset" in repo_type["enum"] + assert "space" in repo_type["enum"] + + async def test_revision_param(self): + tools = create_huggingface_tools() + tool = [t for t in tools if t.name == "hf_read_file"][0] + assert "revision" in tool.parameters["properties"] + + +# --------------------------------------------------------------------------- +# Headers +# --------------------------------------------------------------------------- + + +class TestHeaders: + async def test_headers_without_token(self, monkeypatch): + monkeypatch.delenv("HF_TOKEN", raising=False) + h = _headers() + assert "Authorization" not in h + + async def test_headers_with_token(self, monkeypatch): + monkeypatch.setenv("HF_TOKEN", "hf_test123") + h = _headers() + assert h["Authorization"] == "Bearer hf_test123" + + async def test_headers_returns_dict(self, monkeypatch): + monkeypatch.delenv("HF_TOKEN", raising=False) + h = _headers() + assert isinstance(h, dict) + + +# --------------------------------------------------------------------------- +# Handler tests (mocked HTTP) +# --------------------------------------------------------------------------- + + +def _mock_response(status_code=200, json_data=None, text="", headers=None): + """Create a mock httpx.Response.""" + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = json_data or {} + resp.text = text + resp.headers = headers or {} + resp.content = text.encode() if isinstance(text, str) else b"" + return resp + + +class TestHandleSearchModels: + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_successful_search(self, mock_fetch): + mock_fetch.return_value = _mock_response( + json_data=[ + { + "modelId": "meta-llama/Llama-3-8B", + "downloads": 1000000, + "likes": 500, + "pipeline_tag": "text-generation", + "tags": ["transformers", "safetensors", "pytorch"], + }, + { + "modelId": "mistralai/Mistral-7B-v0.1", + "downloads": 500000, + "likes": 300, + "pipeline_tag": "text-generation", + "tags": ["transformers"], + }, + ] + ) + result, success = await _handle_search_models("llama") + assert success is True + assert "meta-llama/Llama-3-8B" in result + assert "mistralai/Mistral-7B-v0.1" in result + assert "1,000,000 downloads" in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_empty_results(self, mock_fetch): + mock_fetch.return_value = _mock_response(json_data=[]) + result, success = await _handle_search_models("nonexistent_model_xyz") + assert success is True + assert "No models found" in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_api_error(self, mock_fetch): + mock_fetch.return_value = _mock_response(status_code=500, text="Internal Server Error") + result, success = await _handle_search_models("test") + assert success is False + assert "error" in result.lower() + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_rate_limit(self, mock_fetch): + from openmlr.tools.http_utils import RateLimitError + + mock_fetch.side_effect = RateLimitError() + result, success = await _handle_search_models("test") + assert success is False + assert "rate limit" in result.lower() + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_network_error(self, mock_fetch): + mock_fetch.side_effect = Exception("Connection refused") + result, success = await _handle_search_models("test") + assert success is False + assert "error" in result.lower() + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_passes_pipeline_tag(self, mock_fetch): + mock_fetch.return_value = _mock_response(json_data=[]) + await _handle_search_models("llama", pipeline_tag="text-generation") + call_kwargs = mock_fetch.call_args + assert call_kwargs.kwargs["params"]["pipeline_tag"] == "text-generation" + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_passes_library_filter(self, mock_fetch): + mock_fetch.return_value = _mock_response(json_data=[]) + await _handle_search_models("llama", library="transformers") + call_kwargs = mock_fetch.call_args + assert call_kwargs.kwargs["params"]["library"] == "transformers" + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_respects_limit(self, mock_fetch): + mock_fetch.return_value = _mock_response(json_data=[]) + await _handle_search_models("llama", limit=5) + call_kwargs = mock_fetch.call_args + assert call_kwargs.kwargs["params"]["limit"] == 5 + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_limit_capped_at_50(self, mock_fetch): + mock_fetch.return_value = _mock_response(json_data=[]) + await _handle_search_models("llama", limit=100) + call_kwargs = mock_fetch.call_args + assert call_kwargs.kwargs["params"]["limit"] == 50 + + +class TestHandleModelInfo: + @patch("openmlr.tools.huggingface._fetch_readme") + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_successful_info(self, mock_fetch, mock_readme): + mock_fetch.return_value = _mock_response( + json_data={ + "modelId": "meta-llama/Llama-3-8B", + "pipeline_tag": "text-generation", + "library_name": "transformers", + "downloads": 1000000, + "likes": 500, + "author": "meta-llama", + "lastModified": "2024-01-15T00:00:00Z", + "tags": ["transformers", "safetensors", "text-generation"], + "siblings": [ + {"rfilename": "config.json"}, + {"rfilename": "model.safetensors"}, + ], + } + ) + mock_readme.return_value = "This is a model card." + result, success = await _handle_model_info("meta-llama/Llama-3-8B") + assert success is True + assert "meta-llama/Llama-3-8B" in result + assert "text-generation" in result + assert "transformers" in result + assert "1,000,000" in result + assert "config.json" in result + assert "This is a model card." in result + + @patch("openmlr.tools.huggingface._fetch_readme") + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_not_found(self, mock_fetch, mock_readme): + mock_fetch.return_value = _mock_response(status_code=404) + result, success = await _handle_model_info("nonexistent/model") + assert success is False + assert "not found" in result.lower() + + @patch("openmlr.tools.huggingface._fetch_readme") + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_skip_readme(self, mock_fetch, mock_readme): + mock_fetch.return_value = _mock_response( + json_data={"modelId": "test/model", "downloads": 0, "likes": 0, "tags": []} + ) + await _handle_model_info("test/model", include_readme=False) + mock_readme.assert_not_called() + + @patch("openmlr.tools.huggingface._fetch_readme") + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_rate_limit(self, mock_fetch, mock_readme): + from openmlr.tools.http_utils import RateLimitError + + mock_fetch.side_effect = RateLimitError() + result, success = await _handle_model_info("test/model") + assert success is False + assert "rate limit" in result.lower() + + @patch("openmlr.tools.huggingface._fetch_readme") + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_truncates_long_output(self, mock_fetch, mock_readme): + mock_fetch.return_value = _mock_response( + json_data={"modelId": "test/model", "downloads": 0, "likes": 0, "tags": []} + ) + mock_readme.return_value = "x" * 60000 + result, success = await _handle_model_info("test/model") + assert success is True + assert len(result) <= 50100 # 50000 + "[truncated]" overhead + assert "truncated" in result.lower() + + +class TestHandleSearchDatasets: + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_successful_search(self, mock_fetch): + mock_fetch.return_value = _mock_response( + json_data=[ + { + "id": "tatsu-lab/alpaca", + "downloads": 200000, + "likes": 150, + "description": "A dataset for instruction tuning", + "tags": ["text-generation", "instruction-following"], + }, + ] + ) + result, success = await _handle_search_datasets("alpaca") + assert success is True + assert "tatsu-lab/alpaca" in result + assert "200,000 downloads" in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_empty_results(self, mock_fetch): + mock_fetch.return_value = _mock_response(json_data=[]) + result, success = await _handle_search_datasets("nonexistent_dataset") + assert success is True + assert "No datasets found" in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_passes_task_filter(self, mock_fetch): + mock_fetch.return_value = _mock_response(json_data=[]) + await _handle_search_datasets("test", task="text-classification") + call_kwargs = mock_fetch.call_args + assert call_kwargs.kwargs["params"]["task_categories"] == "text-classification" + + +class TestHandleDatasetInfo: + @patch("openmlr.tools.huggingface._fetch_readme") + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_successful_info(self, mock_fetch, mock_readme): + mock_fetch.return_value = _mock_response( + json_data={ + "id": "tatsu-lab/alpaca", + "downloads": 200000, + "likes": 150, + "author": "tatsu-lab", + "tags": ["text-generation"], + "description": "Instruction tuning dataset", + "siblings": [{"rfilename": "data.json"}], + } + ) + mock_readme.return_value = "Dataset card content." + result, success = await _handle_dataset_info("tatsu-lab/alpaca") + assert success is True + assert "tatsu-lab/alpaca" in result + assert "200,000" in result + assert "Dataset card content." in result + + @patch("openmlr.tools.huggingface._fetch_readme") + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_not_found(self, mock_fetch, mock_readme): + mock_fetch.return_value = _mock_response(status_code=404) + result, success = await _handle_dataset_info("nonexistent/dataset") + assert success is False + assert "not found" in result.lower() + + +class TestHandleReadFile: + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_read_json_file(self, mock_fetch): + mock_fetch.return_value = _mock_response( + text='{"hidden_size": 4096}', + headers={"content-type": "application/json"}, + ) + result, success = await _handle_read_file("meta-llama/Llama-3-8B", "config.json") + assert success is True + assert "hidden_size" in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_read_python_file_with_line_numbers(self, mock_fetch): + mock_fetch.return_value = _mock_response( + text="import torch\nprint('hello')", + headers={"content-type": "text/plain"}, + ) + result, success = await _handle_read_file("test/repo", "train.py") + assert success is True + assert "1: import torch" in result + assert "2: print('hello')" in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_binary_file_detection(self, mock_fetch): + mock_fetch.return_value = _mock_response( + text="", + headers={"content-type": "application/octet-stream"}, + ) + result, success = await _handle_read_file("test/repo", "model.safetensors") + assert success is True + assert "Binary file" in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_binary_extension_detection(self, mock_fetch): + mock_fetch.return_value = _mock_response( + text="", + headers={"content-type": "text/plain"}, + ) + result, success = await _handle_read_file("test/repo", "weights.bin") + assert success is True + assert "Binary file" in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_file_not_found(self, mock_fetch): + mock_fetch.return_value = _mock_response(status_code=404) + result, success = await _handle_read_file("test/repo", "missing.txt") + assert success is False + assert "not found" in result.lower() + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_gated_model_401(self, mock_fetch): + mock_fetch.return_value = _mock_response(status_code=401) + result, success = await _handle_read_file("gated/model", "config.json") + assert success is False + assert "Access denied" in result or "gated" in result.lower() + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_forbidden_403(self, mock_fetch): + mock_fetch.return_value = _mock_response(status_code=403) + result, success = await _handle_read_file("gated/model", "config.json") + assert success is False + assert "Forbidden" in result or "license" in result.lower() + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_model_url_format(self, mock_fetch): + mock_fetch.return_value = _mock_response(text="{}", headers={}) + await _handle_read_file("org/model", "config.json", repo_type="model") + url = mock_fetch.call_args.args[0] + assert url == "https://huggingface.co/org/model/resolve/main/config.json" + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_dataset_url_format(self, mock_fetch): + mock_fetch.return_value = _mock_response(text="{}", headers={}) + await _handle_read_file("org/dataset", "data.json", repo_type="dataset") + url = mock_fetch.call_args.args[0] + assert url == "https://huggingface.co/datasets/org/dataset/resolve/main/data.json" + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_space_url_format(self, mock_fetch): + mock_fetch.return_value = _mock_response(text="{}", headers={}) + await _handle_read_file("org/space", "app.py", repo_type="space") + url = mock_fetch.call_args.args[0] + assert url == "https://huggingface.co/spaces/org/space/resolve/main/app.py" + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_custom_revision(self, mock_fetch): + mock_fetch.return_value = _mock_response(text="{}", headers={}) + await _handle_read_file("org/model", "config.json", revision="v2.0") + url = mock_fetch.call_args.args[0] + assert "/resolve/v2.0/" in url + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_truncates_long_content(self, mock_fetch): + mock_fetch.return_value = _mock_response( + text="x" * 60000, + headers={"content-type": "text/plain"}, + ) + result, success = await _handle_read_file("test/repo", "big.json") + assert success is True + assert len(result) <= 50100 + assert "truncated" in result.lower() + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_rate_limit(self, mock_fetch): + from openmlr.tools.http_utils import RateLimitError + + mock_fetch.side_effect = RateLimitError() + result, success = await _handle_read_file("test/repo", "file.txt") + assert success is False + assert "rate limit" in result.lower() + + +# --------------------------------------------------------------------------- +# Helper: _fetch_readme +# --------------------------------------------------------------------------- + + +class TestFetchReadme: + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_fetches_model_readme(self, mock_fetch): + mock_fetch.return_value = _mock_response(text="# My Model\n\nThis is a model.") + result = await _fetch_readme("org/model", "model") + assert result == "# My Model\n\nThis is a model." + url = mock_fetch.call_args.args[0] + assert url == "https://huggingface.co/org/model/resolve/main/README.md" + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_fetches_dataset_readme(self, mock_fetch): + mock_fetch.return_value = _mock_response(text="# My Dataset") + result = await _fetch_readme("org/dataset", "dataset") + url = mock_fetch.call_args.args[0] + assert url == "https://huggingface.co/datasets/org/dataset/resolve/main/README.md" + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_strips_yaml_frontmatter(self, mock_fetch): + mock_fetch.return_value = _mock_response( + text="---\nlanguage: en\ntags:\n- llm\n---\n# Model Card\n\nContent here." + ) + result = await _fetch_readme("org/model", "model") + assert result.startswith("# Model Card") + assert "language: en" not in result + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_returns_none_on_404(self, mock_fetch): + mock_fetch.return_value = _mock_response(status_code=404) + result = await _fetch_readme("org/model", "model") + assert result is None + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_returns_none_on_exception(self, mock_fetch): + mock_fetch.side_effect = Exception("Network error") + result = await _fetch_readme("org/model", "model") + assert result is None + + @patch("openmlr.tools.huggingface.fetch_with_retry") + async def test_truncates_long_readme(self, mock_fetch): + mock_fetch.return_value = _mock_response(text="x" * 40000) + result = await _fetch_readme("org/model", "model") + assert len(result) <= 30100 + assert "truncated" in result.lower() diff --git a/backend/tests/test_tools_research.py b/backend/tests/test_tools_research.py index 35eddd4..f7e698c 100644 --- a/backend/tests/test_tools_research.py +++ b/backend/tests/test_tools_research.py @@ -49,6 +49,18 @@ def test_github_read_included(self): names = [s["function"]["name"] for s in specs] assert "github_read_file" in names or "github_find_examples" in names + def test_hf_tools_included(self): + specs = _get_research_tool_specs() + names = [s["function"]["name"] for s in specs] + assert "hf_search_models" in names + assert "hf_read_file" in names + + def test_hf_tools_limited_to_subset(self): + """Only hf_search_models and hf_read_file should be in research, not all 5.""" + specs = _get_research_tool_specs() + hf_names = [s["function"]["name"] for s in specs if s["function"]["name"].startswith("hf_")] + assert set(hf_names) == {"hf_search_models", "hf_read_file"} + class TestExecuteResearchTool: @pytest.mark.asyncio @@ -58,6 +70,33 @@ async def test_unknown_tool(self): assert success is False assert "not available" in result + @pytest.mark.asyncio + async def test_hf_search_models_dispatches(self): + """Verify hf_search_models is a recognized tool in the research dispatcher.""" + from unittest.mock import AsyncMock, patch + + tc = ToolCall(id="tc2", name="hf_search_models", arguments={"query": "test"}) + # Patch the source handler (imported inside _execute_research_tool at call time) + with patch( + "openmlr.tools.huggingface._handle_search_models", + new_callable=AsyncMock, + return_value=("mocked result", True), + ): + result, success = await _execute_research_tool(tc) + assert success is True + assert result == "mocked result" + + @pytest.mark.asyncio + async def test_hf_read_file_dispatches(self): + """Verify hf_read_file is a recognized tool in the research dispatcher.""" + tc = ToolCall( + id="tc3", name="hf_read_file", arguments={"repo_id": "test/repo", "path": "config.json"} + ) + # This will try to actually call the handler which will fail network-wise, + # but it should NOT return "not available" + result, success = await _execute_research_tool(tc) + assert "not available" not in result + def test_system_prompt_not_empty(self): assert len(RESEARCH_SYSTEM_PROMPT) > 0 assert "research sub-agent" in RESEARCH_SYSTEM_PROMPT.lower() diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 2a71f8e..334579b 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -116,6 +116,11 @@ function ChatUI({ const currentConvUuidRef = useRef(currentConvUuid); currentConvUuidRef.current = currentConvUuid; + // Sequence counter to discard stale switchConv responses after rapid switching + const switchSeqRef = useRef(0); + // Timer ref so pending reload timeouts can be cleared on conversation switch + const reloadTimerRef = useRef | null>(null); + // ── Derived per-conversation processing state ───────── const currentStatus = currentConvUuid ? (convStatuses[currentConvUuid] || 'idle') : 'idle'; const isProcessing = currentStatus === 'processing'; @@ -206,9 +211,15 @@ function ChatUI({ }, []); const switchConv = async (uuid: string) => { + // Increment sequence counter; any in-flight switch with an older seq is stale + const seq = ++switchSeqRef.current; + // Cancel any pending reload timer from a previous conversation's job_complete + if (reloadTimerRef.current) { clearTimeout(reloadTimerRef.current); reloadTimerRef.current = null; } try { await api.switchConversation(uuid); + if (seq !== switchSeqRef.current) return; // stale — a newer switch started const data = await api.getConversation(uuid); + if (seq !== switchSeqRef.current) return; // stale — a newer switch started setCurrentConvUuid(uuid); // Only update model if conversation has one explicitly set; don't overwrite the user's sticky model if (data.conversation?.model) setModel(data.conversation.model); @@ -238,13 +249,14 @@ function ChatUI({ }) || []); // Load active compute for this conversation - await loadActiveCompute(uuid); + if (seq === switchSeqRef.current) await loadActiveCompute(uuid); } catch { /* */ } }; - const handleSwitchConversation = async (uuid: string) => { + const handleSwitchConversation = (uuid: string) => { + // Only navigate; the routeUuid useEffect will trigger switchConv, + // avoiding the previous double-call race condition navigate(`/${uuid}`, { replace: true }); - await switchConv(uuid); }; const handleNewConversation = async () => { @@ -292,6 +304,8 @@ function ChatUI({ const reloadConversationMessages = useCallback(async (uuid: string) => { try { const data = await api.getConversation(uuid); + // Guard: only apply if this is still the active conversation + if (uuid !== currentConvUuidRef.current) return; if (data.messages) { setMessages(data.messages.map((m: any) => { if (m.role === 'tool') { @@ -509,7 +523,7 @@ function ChatUI({ setMessages((prev) => [...prev, { id: nextId(), role: 'error', content: `Job failed: ${error}` }]); } if (status === 'completed' && uuid) { - setTimeout(() => reloadConversationMessages(uuid), 500); + reloadTimerRef.current = setTimeout(() => { reloadTimerRef.current = null; reloadConversationMessages(uuid); }, 500); } } loadConversations(); @@ -635,10 +649,11 @@ function ChatUI({ {approvalEvent && setApprovalEvent(null)} />} - {questionsPayload && { + {questionsPayload && { setQuestionsPayload(null); setCurrentConvStatus('processing'); - setMessages((prev) => [...prev, { id: nextId(), role: 'user', content: `Answered:\n${summary}` }]); + setMessages((prev) => [...prev, { id: nextId(), role: 'user', content: `Answered:\n${summary}` }]); + if (switchToExecute) setInputMode('execute'); }} onClose={() => setQuestionsPayload(null)} />} }> - } /> - } /> + } /> }> } /> } /> diff --git a/frontend/src/__tests__/InputArea.test.tsx b/frontend/src/__tests__/InputArea.test.tsx index df46043..eb098b2 100644 --- a/frontend/src/__tests__/InputArea.test.tsx +++ b/frontend/src/__tests__/InputArea.test.tsx @@ -48,8 +48,8 @@ describe('InputArea', () => { {...defaultProps({ text: 'hello', onSend, onTextChange })} />, ); - // Send button now uses Lucide icon and title attribute - const sendBtn = screen.getByTitle('Send message'); + // Send button has mode-specific title + const sendBtn = screen.getByTitle('Send in Plan mode (Enter)'); fireEvent.click(sendBtn); expect(onSend).toHaveBeenCalledWith('hello', 'plan'); expect(onTextChange).toHaveBeenCalledWith(''); @@ -88,10 +88,36 @@ describe('InputArea', () => { it('empty text disables send button', () => { render(); - const sendBtn = screen.getByTitle('Send message'); + const sendBtn = screen.getByTitle('Send in Plan mode (Enter)'); expect(sendBtn).toBeDisabled(); }); + it('shows Send & Execute button in plan mode', () => { + render(); + const execBtn = screen.getByTitle('Send & switch to Execute mode (Cmd+Enter)'); + expect(execBtn).toBeInTheDocument(); + }); + + it('Send & Execute button not shown in execute mode', () => { + render(); + const execBtn = screen.queryByTitle('Send & switch to Execute mode (Cmd+Enter)'); + expect(execBtn).toBeNull(); + }); + + it('Send & Execute calls onSend with execute mode and switches', () => { + const onSend = vi.fn(); + const onModeChange = vi.fn(); + const onTextChange = vi.fn(); + render( + + ); + const execBtn = screen.getByTitle('Send & switch to Execute mode (Cmd+Enter)'); + fireEvent.click(execBtn); + expect(onModeChange).toHaveBeenCalledWith('execute'); + expect(onSend).toHaveBeenCalledWith('do it', 'execute'); + expect(onTextChange).toHaveBeenCalledWith(''); + }); + it('keyboard shortcut Cmd+M toggles from execute to plan', () => { const onModeChange = vi.fn(); render(); diff --git a/frontend/src/components/InputArea.tsx b/frontend/src/components/InputArea.tsx index 95e8860..d3f693d 100644 --- a/frontend/src/components/InputArea.tsx +++ b/frontend/src/components/InputArea.tsx @@ -1,5 +1,5 @@ import { useRef, useEffect, useCallback, useState } from 'react'; -import { ArrowUp, Square } from 'lucide-react'; +import { ArrowUp, Square, Play } from 'lucide-react'; export type Mode = 'plan' | 'execute'; @@ -34,22 +34,37 @@ export function InputArea({ disabled, showStop, mode, onModeChange, onSend, onSt if (textareaRef.current) textareaRef.current.style.height = 'auto'; }, [text, disabled, onSend, mode, onTextChange]); + /** Send message AND switch to execute mode in one action. */ + const submitAndExecute = useCallback(() => { + const trimmed = text.trim(); + if (!trimmed || disabled) return; + onTextChange(''); + onModeChange('execute'); + onSend(trimmed, 'execute'); + if (textareaRef.current) textareaRef.current.style.height = 'auto'; + }, [text, disabled, onSend, onModeChange, onTextChange]); + const toggleMode = useCallback(() => { onModeChange(mode === 'plan' ? 'execute' : 'plan'); }, [mode, onModeChange]); - // Keyboard shortcut: Cmd+M (or Ctrl+M) toggles between Plan and Execute + // Keyboard shortcuts useEffect(() => { const handler = (e: KeyboardEvent) => { if (!(e.metaKey || e.ctrlKey)) return; if (e.key === 'm' || e.key === 'M') { + // Cmd+M: toggle mode e.preventDefault(); onModeChange(mode === 'plan' ? 'execute' : 'plan'); + } else if (e.key === 'Enter' && mode === 'plan') { + // Cmd+Enter in plan mode: send & execute + e.preventDefault(); + submitAndExecute(); } }; window.addEventListener('keydown', handler); return () => window.removeEventListener('keydown', handler); - }, [mode, onModeChange]); + }, [mode, onModeChange, submitAndExecute]); const isPlan = mode === 'plan'; @@ -117,20 +132,36 @@ export function InputArea({ disabled, showStop, mode, onModeChange, onSend, onSt )} - {/* Send button - same height as mode toggle */} + {/* Send buttons */} {!disabled && ( - +
+ {/* Primary send: uses current mode */} + + + {/* Send & Execute: visible only in Plan mode */} + {isPlan && ( + + )} +
)} diff --git a/frontend/src/components/QuestionDrawer.tsx b/frontend/src/components/QuestionDrawer.tsx index 3823487..7b1b487 100644 --- a/frontend/src/components/QuestionDrawer.tsx +++ b/frontend/src/components/QuestionDrawer.tsx @@ -1,11 +1,11 @@ import { useState } from 'react'; -import { X, ChevronLeft, ChevronRight, Send } from 'lucide-react'; +import { X, ChevronLeft, ChevronRight, Send, Play } from 'lucide-react'; import { api } from '../api'; import type { QuestionsPayload } from '../types'; interface Props { payload: QuestionsPayload; - onDone: (summary: string, suggestedMode?: string) => void; + onDone: (summary: string, switchToExecute?: boolean) => void; onClose: () => void; } @@ -45,12 +45,12 @@ export function QuestionDrawer({ payload, onDone, onClose }: Props) { } }; - const submit = async () => { + const doSubmit = async (switchToExecute: boolean) => { setSubmitting(true); try { await api.submitAnswers(answers); const lines = questions.map((q) => `${q.question} → ${answers[q.id]}`); - onDone(lines.join('\n'), suggest_mode || undefined); + onDone(lines.join('\n'), switchToExecute); } catch { /* */ } finally { setSubmitting(false); } }; @@ -147,14 +147,30 @@ export function QuestionDrawer({ payload, onDone, onClose }: Props) { )} {allAnswered && ( - +
+ {/* Submit: stays in current mode */} + + {/* Submit & Execute: submits answers AND switches to execute mode */} + +
)}