Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ repos:
- pydantic-settings>=2.0
- httpx>=0.27
- typing-extensions>=4.0
# Optional-extra deps required so the now-strictly-checked modules
# under ``locus.a2a``, ``locus.integrations``, ``locus.server`` etc.
# resolve their third-party imports.
- fastapi>=0.110
- fastmcp>=3.2.0
- mcp>=1.0
pass_filenames: false
entry: mypy src/locus

Expand Down
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -525,16 +525,11 @@ ignore_missing_imports = true
module = [
"locus.rag.*",
"locus.memory.*",
"locus.integrations.*",
"locus.streaming.*",
"locus.loop.*",
"locus.playbooks.*",
"locus.server.*",
"locus.multiagent.graph",
"locus.hooks.builtin.*",
"locus.reasoning.*",
"locus.skills.*",
"locus.a2a.*",
]
ignore_errors = true

Expand Down
3 changes: 2 additions & 1 deletion src/locus/a2a/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import logging
import os
import uuid
from collections.abc import AsyncIterator
from typing import Any

from pydantic import BaseModel, Field
Expand Down Expand Up @@ -238,7 +239,7 @@ async def stream(
user_msgs = [m for m in request.messages if m.role == "user"]
prompt = user_msgs[-1].content if user_msgs else ""

async def event_generator():
async def event_generator() -> AsyncIterator[str]:
try:
async for event in agent.run(prompt):
if isinstance(event, ThinkEvent):
Expand Down
31 changes: 22 additions & 9 deletions src/locus/integrations/fastmcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import json
import re
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseModel, ConfigDict, Field, create_model

Expand Down Expand Up @@ -68,7 +68,10 @@ def _json_schema_type_to_python(prop: dict[str, Any]) -> type[Any]:
"boolean": bool,
}

return mapping.get(schema_type, Any)
# ``Any`` is a typing-special-form, not a runtime type, so mypy can't
# accept it as the dict default. Pydantic accepts it as a field type
# at runtime, so we tag the line.
return mapping.get(schema_type, Any) # type: ignore[arg-type]


def build_args_model(tool_name: str, schema: dict[str, Any] | None) -> type[BaseModel] | None:
Expand Down Expand Up @@ -315,7 +318,7 @@ def _create_mcp(self) -> FastMCP:
async def run_agent(prompt: str) -> str:
"""Run the Locus agent with a prompt and return the response."""
result = agent.run_sync(prompt)
return result.message
return str(result.message)

# Register a streaming version
@mcp.tool()
Expand All @@ -327,17 +330,17 @@ async def run_agent_stream(prompt: str) -> str:
# Return the final message from the last event
for event in reversed(events):
if hasattr(event, "final_message") and event.final_message:
return event.final_message
return str(event.final_message)
return "Agent completed without response"

return mcp

def run(self, transport: str = "stdio") -> None:
def run(self, transport: Literal["stdio", "http", "sse", "streamable-http"] = "stdio") -> None:
"""
Run the MCP server.

Args:
transport: Transport type ("stdio" or "sse")
transport: Transport type ("stdio", "http", "sse", or "streamable-http").
"""
if self._mcp is None:
self._mcp = self._create_mcp()
Expand Down Expand Up @@ -491,6 +494,9 @@ async def connect(self) -> None:

async def _connect_http(self) -> None:
"""Connect via HTTP/SSE transport."""
if self.base_url is None:
msg = "_connect_http called without base_url"
raise RuntimeError(msg)
try:
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
Expand All @@ -517,7 +523,9 @@ class BearerAuth(httpx.Auth):
def __init__(self, token: str):
self.token = token

def auth_flow(self, request):
def auth_flow( # type: ignore[no-untyped-def]
self, request
): # httpx.Auth.auth_flow signature varies across SDK versions
request.headers["Authorization"] = f"Bearer {self.token}"
yield request

Expand Down Expand Up @@ -647,12 +655,17 @@ async def close(self) -> None:
pass
self._client_context = None

async def __aenter__(self):
async def __aenter__(self) -> MCPClient:
"""Async context manager entry."""
await self.connect()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> None:
"""Async context manager exit."""
await self.close()

Expand Down
16 changes: 12 additions & 4 deletions src/locus/playbooks/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ def enforcer(self) -> PlaybookEnforcer:

async def on_before_tool_call(self, event: BeforeToolCallEvent) -> None:
"""Validate the call against the current step; cancel on violation."""
result = self._enforcer.validate_tool_call(event.tool_name)
# ``ProtectedEvent`` (the base class for hook events) sets fields
# via ``self._init(name, value)`` rather than class-level annotations,
# so mypy can't see ``.tool_name`` / ``.error`` statically. They
# exist at runtime; the ignore is the standard pattern for this
# protocol.
result = self._enforcer.validate_tool_call(event.tool_name) # type: ignore[attr-defined]
if result.allowed:
return
# Build a useful cancel message that the agent loop will turn into
Expand All @@ -120,13 +125,16 @@ async def on_after_tool_call(self, event: AfterToolCallEvent) -> None:
before-hook cancelled the call, so anything reaching this method
actually executed.
"""
if event.error:
# ``ProtectedEvent`` sets ``.error`` / ``.tool_name`` via
# ``self._init(...)`` not class-level fields — see note in
# ``on_before_tool_call``.
if event.error: # type: ignore[attr-defined]
# Failed calls don't advance the step (the model will likely
# retry); they're still recorded for the violation log.
self._enforcer.record_tool_call(event.tool_name)
self._enforcer.record_tool_call(event.tool_name) # type: ignore[attr-defined]
return

self._enforcer.record_tool_call(event.tool_name)
self._enforcer.record_tool_call(event.tool_name) # type: ignore[attr-defined]

step = self._enforcer.current_step
if step is None:
Expand Down
2 changes: 1 addition & 1 deletion src/locus/playbooks/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def load_yaml_string(self, yaml_string: str) -> Playbook:
PlaybookLoadError: If YAML is invalid or playbook validation fails
"""
try:
import yaml
import yaml # type: ignore[import-untyped] # PyYAML ships no inline types
except ImportError as e:
raise PlaybookLoadError(
"PyYAML is required for YAML support. Install with: pip install pyyaml"
Expand Down
2 changes: 1 addition & 1 deletion src/locus/skills/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pathlib import Path
from typing import Any

import yaml
import yaml # type: ignore[import-untyped] # PyYAML ships no inline types


# AgentSkills.io name validation: kebab-case, 1-64 chars, no consecutive hyphens
Expand Down
3 changes: 2 additions & 1 deletion src/locus/streaming/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def _serialize_event(self, event: LocusEvent) -> dict[str, Any]:
Dictionary representation of the event
"""
if self.custom_serializer:
return self.custom_serializer(event)
# User-supplied callable; mypy can't narrow its return.
return self.custom_serializer(event) # type: ignore[no-any-return]

# Use Pydantic's model_dump
data = event.model_dump()
Expand Down
Loading