Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ description = "noot library"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"anthropic>=0.76.0",
"litellm",
"mitmproxy",
"rich",
]
Expand Down
35 changes: 32 additions & 3 deletions src/noot/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def cmd_init(args):
from rich.console import Console
from rich.prompt import Prompt

from noot.init import init_project
from noot.init import LLM_PROVIDERS, init_project

console = Console()

Expand All @@ -41,8 +41,29 @@ def cmd_init(args):
"Must start with a letter.[/red]"
)

# LLM provider selection
if args.provider:
provider_key = args.provider
if provider_key not in LLM_PROVIDERS:
console.print(f"[red]Unknown provider: {provider_key}[/red]")
console.print(f"[dim]Available: {', '.join(LLM_PROVIDERS.keys())}[/dim]")
sys.exit(1)
else:
console.print("\n[bold]Select your LLM provider:[/bold]")
for i, (key, config) in enumerate(LLM_PROVIDERS.items(), 1):
console.print(f" {i}. {config['name']}")

choice = Prompt.ask(
"[green]Provider[/green]",
choices=[str(i) for i in range(1, len(LLM_PROVIDERS) + 1)],
default="1",
)
provider_key = list(LLM_PROVIDERS.keys())[int(choice) - 1]

provider = LLM_PROVIDERS[provider_key]

try:
init_project(Path.cwd(), project_name)
init_project(Path.cwd(), project_name, model=provider["model"])
msg = f"[bold green]Project '{project_name}' initialized![/bold green]"
console.print(f"\n{msg}")
console.print("\n[dim]Created:[/dim]")
Expand All @@ -56,7 +77,11 @@ def cmd_init(args):
console.print("\n[dim]Next steps:[/dim]")
console.print(f" 1. Edit cli/{project_name}.py with your CLI")
console.print(f" 2. Edit tests/test_{project_name}.py with your tests")
console.print(" 3. Set ANTHROPIC_API_KEY environment variable")
if provider["env_var"]:
console.print(f" 3. Set {provider['env_var']} environment variable")
console.print(f" [dim]{provider['env_hint']}[/dim]")
else:
console.print(f" 3. {provider['env_hint']}")
except Exception as e:
console.print(f"[red]Error: {e}[/red]")
sys.exit(1)
Expand All @@ -78,6 +103,10 @@ def main():
"--name", "-n",
help="Project name (skips interactive prompt)"
)
init_parser.add_argument(
"--provider", "-p",
help="LLM provider (anthropic, openai, gemini, groq, ollama)"
)
init_parser.set_defaults(func=cmd_init)

args = parser.parse_args()
Expand Down
56 changes: 56 additions & 0 deletions src/noot/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Configuration loading for noot projects."""

from __future__ import annotations

import tomllib
from pathlib import Path

# Fallback model (LiteLLM "provider/model" format) used whenever no
# [tool.noot] model is configured in pyproject.toml.
DEFAULT_MODEL = "anthropic/claude-sonnet-4-20250514"


def find_pyproject_toml(start_dir: Path | None = None) -> Path | None:
"""
Find pyproject.toml by searching upward from start_dir.

Args:
start_dir: Directory to start searching from. Defaults to cwd.

Returns:
Path to pyproject.toml or None if not found.
"""
if start_dir is None:
start_dir = Path.cwd()

current = start_dir.resolve()
while current != current.parent:
pyproject = current / "pyproject.toml"
if pyproject.exists():
return pyproject
current = current.parent

return None


def get_project_model(start_dir: Path | None = None) -> str:
    """
    Get the configured model from pyproject.toml.

    Looks for a [tool.noot] section with a 'model' key. Falls back to
    DEFAULT_MODEL if the file is missing, unreadable, malformed, or the
    key is absent or not a string.

    Args:
        start_dir: Directory to start searching for pyproject.toml.

    Returns:
        Model string (e.g., "anthropic/claude-sonnet-4-20250514").
    """
    pyproject_path = find_pyproject_toml(start_dir)
    if pyproject_path is None:
        return DEFAULT_MODEL

    try:
        with open(pyproject_path, "rb") as f:
            data = tomllib.load(f)
    except (OSError, tomllib.TOMLDecodeError):
        # Only I/O and parse failures fall back to the default; genuine
        # programming errors are no longer silently swallowed.
        return DEFAULT_MODEL

    tool_cfg = data.get("tool", {})
    noot_cfg = tool_cfg.get("noot", {}) if isinstance(tool_cfg, dict) else {}
    model = noot_cfg.get("model") if isinstance(noot_cfg, dict) else None
    # Ignore non-string values (e.g. model = 42) rather than handing a
    # bad type to the LLM client downstream.
    return model if isinstance(model, str) else DEFAULT_MODEL
14 changes: 9 additions & 5 deletions src/noot/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from noot.assertions import execute_assertion
from noot.cache import Cache
from noot.config import get_project_model
from noot.llm import LLM
from noot.mitmproxy_manager import MitmproxyConfig, MitmproxyManager
from noot.step import StepResult, execute_step
Expand All @@ -28,7 +29,7 @@ class Flow:
def __init__(
self,
command: str | None = None,
model: str = "claude-sonnet-4-20250514",
model: str | None = None,
pane_width: int = 120,
pane_height: int = 40,
stability_timeout: float = 5.0,
Expand All @@ -41,7 +42,8 @@ def __init__(

Args:
command: Initial command to run (e.g., "spx init")
model: Anthropic model to use for step interpretation
model: LLM model string to override project config.
If None, reads from pyproject.toml [tool.noot] section.
pane_width: Terminal width
pane_height: Terminal height
stability_timeout: Default timeout for waiting for terminal stability
Expand All @@ -55,7 +57,8 @@ def __init__(
self._terminal = Terminal(pane_width=pane_width, pane_height=pane_height)
cassette_path = Path(cassette) if cassette else None
self._cache = Cache.from_env(cassette_path)
self._llm = LLM(model=model, cache=self._cache)
resolved_model = model if model is not None else get_project_model()
self._llm = LLM(model=resolved_model, cache=self._cache)
self._stability_timeout = stability_timeout
self._steps: list[StepResult] = []

Expand All @@ -76,7 +79,7 @@ def __init__(
def spawn(
cls,
command: str,
model: str = "claude-sonnet-4-20250514",
model: str | None = None,
pane_width: int = 120,
pane_height: int = 40,
stability_timeout: float = 5.0,
Expand All @@ -89,7 +92,8 @@ def spawn(

Args:
command: Command to run (e.g., "spx init")
model: Anthropic model for step interpretation
model: LLM model string to override project config.
If None, reads from pyproject.toml [tool.noot] section.
pane_width: Terminal width
pane_height: Terminal height
stability_timeout: Default timeout for stability
Expand Down
52 changes: 48 additions & 4 deletions src/noot/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,48 @@
import subprocess
from pathlib import Path


def init_project(target_dir: Path, project_name: str) -> None:
from noot.config import DEFAULT_MODEL

# LLM provider configurations offered by `noot init`.
# Each entry maps a CLI-facing provider key to:
#   name:     human-readable label shown in the interactive picker
#   model:    default LiteLLM model string ("provider/model" format)
#   env_var:  API-key environment variable, or None when no key is needed
#   env_hint: guidance printed to the user after project creation
LLM_PROVIDERS = {
    "anthropic": {
        "name": "Anthropic (Claude)",
        "model": "anthropic/claude-sonnet-4-20250514",
        "env_var": "ANTHROPIC_API_KEY",
        "env_hint": "Get your key at https://console.anthropic.com/",
    },
    "openai": {
        "name": "OpenAI (GPT)",
        "model": "openai/gpt-4o",
        "env_var": "OPENAI_API_KEY",
        "env_hint": "Get your key at https://platform.openai.com/api-keys",
    },
    "gemini": {
        "name": "Google (Gemini)",
        "model": "gemini/gemini-1.5-pro",
        "env_var": "GEMINI_API_KEY",
        "env_hint": "Get your key at https://aistudio.google.com/apikey",
    },
    "groq": {
        "name": "Groq",
        "model": "groq/llama-3.3-70b-versatile",
        "env_var": "GROQ_API_KEY",
        "env_hint": "Get your key at https://console.groq.com/keys",
    },
    "ollama": {
        "name": "Ollama (Local)",
        "model": "ollama/llama3.2",
        "env_var": None,  # local server; no API key required
        "env_hint": "Make sure Ollama is running locally",
    },
}


def init_project(
target_dir: Path,
project_name: str,
model: str = DEFAULT_MODEL,
) -> None:
"""
Initialize a noot project in the target directory.

Expand All @@ -20,6 +60,7 @@ def init_project(target_dir: Path, project_name: str) -> None:
Args:
target_dir: Directory to initialize (usually cwd)
project_name: Name of the project (lowercase, underscores allowed)
model: LLM model string in provider/model format

Raises:
FileExistsError: If files would be overwritten
Expand Down Expand Up @@ -64,9 +105,12 @@ def init_project(target_dir: Path, project_name: str) -> None:
(target_dir / ".cassettes" / "cli").mkdir(parents=True, exist_ok=True)
(target_dir / ".cassettes" / "http").mkdir(parents=True, exist_ok=True)

# Read templates and substitute project_name
# Read templates and substitute project_name and model
def sub(template: Path) -> str:
return template.read_text().replace("{{project_name}}", project_name)
content = template.read_text()
content = content.replace("{{project_name}}", project_name)
content = content.replace("{{model}}", model)
return content

pyproject_content = sub(pyproject_template)
cli_content = sub(cli_template)
Expand Down
69 changes: 41 additions & 28 deletions src/noot/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Simple Anthropic API wrapper for step execution."""
"""Simple LLM wrapper for step execution using LiteLLM."""

from __future__ import annotations

Expand All @@ -10,13 +10,13 @@
from pathlib import Path
from typing import TYPE_CHECKING

import anthropic
from anthropic.types import TextBlock
import litellm

if TYPE_CHECKING:
from noot.cache import Cache

from noot.cache import CacheMissError
from noot.config import DEFAULT_MODEL

STEP_SYSTEM_PROMPT = """\
You are controlling an interactive CLI application. Given the current terminal screen
Expand Down Expand Up @@ -90,24 +90,33 @@


class LLM:
"""Simple Anthropic Claude wrapper with optional caching."""
"""LiteLLM wrapper with optional caching for any LLM provider."""

def __init__(
self, model: str = "claude-sonnet-4-20250514", cache: Cache | None = None
self, model: str = DEFAULT_MODEL, cache: Cache | None = None
):
self._model = model
self._cache = cache
self._logger = _get_logger()
self._client: anthropic.Anthropic | None = None

def _get_client(self) -> anthropic.Anthropic:
"""Lazily initialize the Anthropic client."""
if self._client is None:
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("ANTHROPIC_API_KEY environment variable required")
self._client = anthropic.Anthropic(api_key=api_key)
return self._client
def _validate_api_key(self) -> None:
"""Validate that the required API key is set for the model's provider."""
provider = self._model.split("/")[0] if "/" in self._model else "anthropic"

env_var_map = {
"anthropic": "ANTHROPIC_API_KEY",
"openai": "OPENAI_API_KEY",
"gemini": "GEMINI_API_KEY",
"google": "GEMINI_API_KEY",
"mistral": "MISTRAL_API_KEY",
"groq": "GROQ_API_KEY",
"ollama": None, # Ollama doesn't need an API key
}

env_var = env_var_map.get(provider)
if env_var and not os.environ.get(env_var):
msg = f"{env_var} environment variable required for {provider} models"
raise ValueError(msg)

def complete(self, screen: str, instruction: str) -> str:
"""Send screen + instruction to LLM and get response."""
Expand Down Expand Up @@ -144,17 +153,19 @@ def complete(self, screen: str, instruction: str) -> str:
}
)
)
response = self._get_client().messages.create(

self._validate_api_key()
response = litellm.completion(
model=self._model,
max_tokens=256,
system=STEP_SYSTEM_PROMPT,
messages=[{"role": "user", "content": user_message}],
messages=[
{"role": "system", "content": STEP_SYSTEM_PROMPT},
{"role": "user", "content": user_message},
],
)

content_block = response.content[0]
if not isinstance(content_block, TextBlock):
raise TypeError("Expected TextBlock response")
response_text = content_block.text
content = response.choices[0].message.content # type: ignore[union-attr]
response_text: str = content if content is not None else ""
self._logger.info(
json.dumps(
{
Expand Down Expand Up @@ -217,17 +228,19 @@ def generate_assertion(self, screen: str, expected_state: str) -> str:
}
)
)
response = self._get_client().messages.create(

self._validate_api_key()
response = litellm.completion(
model=self._model,
max_tokens=512,
system=ASSERTION_SYSTEM_PROMPT,
messages=[{"role": "user", "content": user_message}],
messages=[
{"role": "system", "content": ASSERTION_SYSTEM_PROMPT},
{"role": "user", "content": user_message},
],
)

content_block = response.content[0]
if not isinstance(content_block, TextBlock):
raise TypeError("Expected TextBlock response")
assertion_code = content_block.text.strip()
content = response.choices[0].message.content # type: ignore[union-attr]
assertion_code: str = (content if content is not None else "").strip()

# Clean up code block formatting if present
if assertion_code.startswith("```"):
Expand Down
3 changes: 3 additions & 0 deletions src/noot/templates/pyproject.toml.template
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ dev = ["noot", "pytest"]

[tool.pytest.ini_options]
testpaths = ["tests"]

[tool.noot]
model = "{{model}}"
Loading