Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 108 additions & 102 deletions backend/benchmark_small_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,113 +61,119 @@ def benchmark_model(model_info: dict) -> dict:
}

try:
# Clear memory
gc.collect()
if torch.backends.mps.is_available():
torch.mps.empty_cache()

initial_memory = get_memory_usage()

# Load model
start_time = time.time()
print(" Loading tokenizer and model...")

tokenizer = AutoTokenizer.from_pretrained(model_info["model_id"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
model_info["model_id"],
torch_dtype=torch.float16 if torch.backends.mps.is_available() else torch.float32,
device_map="auto" if torch.backends.mps.is_available() else None,
trust_remote_code=True,
)

# Move to MPS if available
if torch.backends.mps.is_available():
model = model.to("mps")

load_time = time.time() - start_time
results["load_time"] = load_time
results["model_size_mb"] = get_model_size(model)
results["memory_usage_gb"] = get_memory_usage() - initial_memory

print(f" ✅ Loaded in {load_time:.2f}s")
print(f" 📦 Model size: {results['model_size_mb']:.1f} MB")
print(f" 🧠 Memory usage: {results['memory_usage_gb']:.2f} GB")

# Test inference
print(" 🚀 Running inference tests...")

for i, prompt in enumerate(TEST_PROMPTS):
try:
# Tokenize
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
if torch.backends.mps.is_available():
inputs = {k: v.to("mps") for k, v in inputs.items()}

# Generate
start_time = time.time()

with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
)

inference_time = time.time() - start_time

# Decode output
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Calculate tokens/second
new_tokens = len(outputs[0]) - len(inputs["input_ids"][0])
tokens_per_sec = new_tokens / inference_time if inference_time > 0 else 0

results["inference_times"].append(inference_time)
results["tokens_per_second"].append(tokens_per_sec)
results["outputs"][f"prompt_{i}"] = {
"prompt": prompt,
"output": generated_text,
"inference_time": inference_time,
"tokens_per_sec": tokens_per_sec,
}

print(f" Prompt {i + 1}: {tokens_per_sec:.1f} tokens/sec")

except Exception as e:
error_msg = f"Error on prompt {i}: {str(e)}"
results["errors"].append(error_msg)
print(f" ❌ {error_msg}")

# Calculate averages
if results["inference_times"]:
results["avg_inference_time"] = sum(results["inference_times"]) / len(
results["inference_times"]
)
results["avg_tokens_per_second"] = sum(results["tokens_per_second"]) / len(
results["tokens_per_second"]
)

print(f" 📊 Average: {results.get('avg_tokens_per_second', 0):.1f} tokens/sec")

results = _run_benchmark(model_info, results)
except Exception as e:
error_msg = f"Failed to load {model_info['name']}: {str(e)}"
results["errors"].append(error_msg)
print(f" ❌ {error_msg}")

finally:
# Cleanup
if "model" in locals():
del model
if "tokenizer" in locals():
del tokenizer
gc.collect()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
return results


def _run_benchmark(model_info: dict, results: dict) -> dict:
"""Run the actual benchmark logic."""
gc.collect()
if torch.backends.mps.is_available():
torch.mps.empty_cache()

initial_memory = get_memory_usage()

# Load model
start_time = time.time()
print(" Loading tokenizer and model...")

tokenizer = AutoTokenizer.from_pretrained(model_info["model_id"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
model_info["model_id"],
torch_dtype=torch.float16 if torch.backends.mps.is_available() else torch.float32,
device_map="auto" if torch.backends.mps.is_available() else None,
trust_remote_code=True,
)

if torch.backends.mps.is_available():
model = model.to("mps")

load_time = time.time() - start_time
results["load_time"] = load_time
results["model_size_mb"] = get_model_size(model)
results["memory_usage_gb"] = get_memory_usage() - initial_memory

print(f" ✅ Loaded in {load_time:.2f}s")
print(f" 📦 Model size: {results['model_size_mb']:.1f} MB")
print(f" 🧠 Memory usage: {results['memory_usage_gb']:.2f} GB")

# Test inference
print(" 🚀 Running inference tests...")

results = _run_inference_tests(model, tokenizer, results)

# Calculate averages
if results["inference_times"]:
results["avg_inference_time"] = sum(results["inference_times"]) / len(
results["inference_times"]
)
results["avg_tokens_per_second"] = sum(results["tokens_per_second"]) / len(
results["tokens_per_second"]
)

print(f" 📊 Average: {results.get('avg_tokens_per_second', 0):.1f} tokens/sec")

# Cleanup
if "model" in locals():
del model
if "tokenizer" in locals():
del tokenizer
gc.collect()
if torch.backends.mps.is_available():
torch.mps.empty_cache()

return results


def _run_inference_tests(model, tokenizer, results: dict) -> dict:
"""Run inference tests on the model."""
for i, prompt in enumerate(TEST_PROMPTS):
try:
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
if torch.backends.mps.is_available():
inputs = {k: v.to("mps") for k, v in inputs.items()}

start_time = time.time()

with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=50,
do_sample=True,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
)

inference_time = time.time() - start_time

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

new_tokens = len(outputs[0]) - len(inputs["input_ids"][0])
tokens_per_sec = new_tokens / inference_time if inference_time > 0 else 0

results["inference_times"].append(inference_time)
results["tokens_per_second"].append(tokens_per_sec)
results["outputs"][f"prompt_{i}"] = {
"prompt": prompt,
"output": generated_text,
"inference_time": inference_time,
"tokens_per_sec": tokens_per_sec,
}

print(f" Prompt {i + 1}: {tokens_per_sec:.1f} tokens/sec")

except Exception as e:
error_msg = f"Error on prompt {i}: {str(e)}"
results["errors"].append(error_msg)
print(f" ❌ {error_msg}")

return results

Expand Down
6 changes: 4 additions & 2 deletions backend/openmlr/agent/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ async def submission_loop(session: Session, tool_router) -> None:


async def run_agent_turn(
session: Session, tool_router, user_message: str, mode: str = None
session: Session, tool_router, user_message: str, mode: str | None = None
) -> None:
"""Direct entry point: run one agent turn."""
await _run_agent(session, tool_router, user_message, mode)


async def _run_agent(session: Session, tool_router, user_message: str, mode: str = None) -> None:
async def _run_agent(
session: Session, tool_router, user_message: str, mode: str | None = None
) -> None:
"""Execute the agentic loop for a user message."""
session.clear_cancel()

Expand Down
9 changes: 5 additions & 4 deletions backend/openmlr/routes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
Expand Down Expand Up @@ -607,7 +608,7 @@ async def submit_approval(
@router.post("/todo-approval")
async def submit_todo_approval(
request: Request,
user: User = Depends(get_current_user),
user: Annotated[User, Depends(get_current_user)],
):
"""Submit approval/rejection for proposed TODO list changes."""
body = await request.json()
Expand All @@ -622,10 +623,10 @@ async def submit_todo_approval(
active
and hasattr(active.session, "pending_todo_approval")
and active.session.pending_todo_approval
and not active.session.pending_todo_approval.done()
):
if not active.session.pending_todo_approval.done():
active.session.pending_todo_approval.set_result(result)
return {"ok": True}
active.session.pending_todo_approval.set_result(result)
return {"ok": True}

# Publish to Redis for background job workers
try:
Expand Down
48 changes: 24 additions & 24 deletions backend/openmlr/services/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,36 +275,36 @@ async def _broadcast(event: AgentEvent):
self.sessions[conversation_id] = active
return active

async def remove_session(self, conversation_id: int) -> None:
active = self.sessions.pop(conversation_id, None)
if active:
# Cancel any running agent turn
active.session.cancel()
# Resolve any pending question/approval futures to unblock the loop
if hasattr(active.session, "pending_answers") and active.session.pending_answers:
try:
if not active.session.pending_answers.done():
active.session.pending_answers.cancel()
except Exception:
pass
if (
hasattr(active.session, "pending_todo_approval")
and active.session.pending_todo_approval
):
try:
if not active.session.pending_todo_approval.done():
active.session.pending_todo_approval.cancel()
except Exception:
pass
async def _cleanup_session(self, active) -> None:
active.session.cancel()
if hasattr(active.session, "pending_answers") and active.session.pending_answers:
try:
await active.sandbox_manager.destroy()
if not active.session.pending_answers.done():
active.session.pending_answers.cancel()
except Exception:
pass
# Disconnect MCP servers
if (
hasattr(active.session, "pending_todo_approval")
and active.session.pending_todo_approval
):
try:
await active.mcp_manager.disconnect_all()
if not active.session.pending_todo_approval.done():
active.session.pending_todo_approval.cancel()
except Exception:
pass
try:
await active.sandbox_manager.destroy()
except Exception:
pass
try:
await active.mcp_manager.disconnect_all()
except Exception:
pass

async def remove_session(self, conversation_id: int) -> None:
active = self.sessions.pop(conversation_id, None)
if active:
await self._cleanup_session(active)
if self.current_conversation_id == conversation_id:
self.current_conversation_id = None

Expand Down
Loading
Loading