diff --git a/.gitignore b/.gitignore index d30afca..3154eea 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,9 @@ workspace/tmp/ tests/js/node_modules/ tests/js/package-lock.json tests/js/*.log + +# User direction submissions (runtime data, keep out of the repo) +research_agendas/inbox/*.yaml +research_agendas/inbox/*.yml +research_agendas/inbox/processed/ +research_agendas/inbox/failed/ diff --git a/agents/agenda_budget.py b/agents/agenda_budget.py new file mode 100644 index 0000000..b070ee5 --- /dev/null +++ b/agents/agenda_budget.py @@ -0,0 +1,242 @@ +"""Per-agenda LLM token accounting and budget enforcement. + +Each research agenda carries a token budget (research_agendas.token_budget, +falling back to config.AGENDA_TOKEN_BUDGET_DEFAULT). A budget of 0, NULL +(with default 0) or a negative value means "no cap": check_budget always +passes and the agenda is never flipped to 'paused_budget', but record_usage +still writes the agenda_token_ledger row and bumps token_spent, so accounting +stays complete either way. Work performed on behalf of an agenda runs inside +`agenda_scope(agenda_id, operation)`; the LLM client then: + +1. calls `check_budget(agenda_id)` BEFORE issuing the provider request — if the + budget is exhausted the call fails with AgendaBudgetExceededError before any + tokens are spent, so the caller can stop cleanly without partial writes; +2. calls `record_usage(...)` after a successful response — one ledger row per + call plus an atomic increment of research_agendas.token_spent. + +When token_spent crosses the budget the agenda is flipped to status +'paused_budget'. Raising the budget (token_spent < budget again) or calling +`resume_agenda` re-enables it. + +Note: the check-then-record pair is not a distributed lock; concurrent calls +may overshoot the budget by roughly one call's worth of tokens. That is an +accepted tolerance for a soft cost cap. 
+""" + +from __future__ import annotations + +import contextvars +from contextlib import contextmanager +from typing import Any, Iterator + +from config import AGENDA_TOKEN_BUDGET_DEFAULT +from db import database as db + +AGENDA_STATUS_PAUSED_BUDGET = "paused_budget" +AGENDA_STATUS_ACTIVE = "active" + +# (agenda_id, operation) for the work currently running in this context. +_scope_var: contextvars.ContextVar[tuple[int, str] | None] = contextvars.ContextVar( + "agenda_budget_scope", default=None +) + + +class AgendaBudgetExceededError(RuntimeError): + """Raised before an LLM call when the agenda's token budget is spent.""" + + def __init__(self, agenda_id: int, token_spent: int, token_budget: int): + self.agenda_id = int(agenda_id) + self.token_spent = int(token_spent) + self.token_budget = int(token_budget) + super().__init__( + f"agenda {agenda_id} token budget exhausted " + f"({token_spent}/{token_budget} tokens); raise token_budget or call " + f"POST /api/research_agenda/{agenda_id}/resume to continue" + ) + + +@contextmanager +def agenda_scope(agenda_id: int, operation: str = "llm_call") -> Iterator[None]: + """Attribute all LLM calls inside the block to the given agenda.""" + token = _scope_var.set((int(agenda_id), str(operation))) + try: + yield + finally: + _scope_var.reset(token) + + +def current_scope() -> tuple[int, str] | None: + """Return (agenda_id, operation) for the current context, if any.""" + return _scope_var.get() + + +def effective_budget(token_budget: Any) -> int: + """Resolve a row's token_budget to the enforced value (NULL -> default). + + The resolved value may be <= 0, which means no cap is enforced (usage is + still recorded by record_usage). 
+ """ + if token_budget is None: + return int(AGENDA_TOKEN_BUDGET_DEFAULT) + return int(token_budget) + + +def get_budget_state(agenda_id: int) -> dict[str, Any] | None: + row = db.fetchone( + "SELECT id, token_budget, token_spent, status FROM research_agendas WHERE id=?", + (int(agenda_id),), + ) + if not row: + return None + budget = effective_budget(row.get("token_budget")) + spent = int(row.get("token_spent") or 0) + return { + "agenda_id": int(row["id"]), + "token_budget": budget, + "token_budget_raw": row.get("token_budget"), + "token_spent": spent, + "status": str(row.get("status") or AGENDA_STATUS_ACTIVE), + "exhausted": budget > 0 and spent >= budget, + } + + +def check_budget(agenda_id: int) -> None: + """Raise AgendaBudgetExceededError if the agenda may not spend more tokens. + + A budget <= 0 (including the NULL -> default fallback) disables the cap: + the check always passes and the agenda is never paused for budget reasons, + while record_usage keeps writing the ledger and token_spent. + + A 'paused_budget' agenda whose budget was raised in the meantime + (token_spent < budget again) is automatically reactivated, so increasing + the budget alone is enough to continue. + """ + state = get_budget_state(agenda_id) + if state is None: + # Unknown agenda: nothing to enforce. Scoped callers validate + # existence elsewhere; do not block the call on accounting state. + return + budget = state["token_budget"] + spent = state["token_spent"] + if budget <= 0: # explicit 0/negative budget disables the cap + return + if spent >= budget: + if state["status"] != AGENDA_STATUS_PAUSED_BUDGET: + _set_status(agenda_id, AGENDA_STATUS_PAUSED_BUDGET) + raise AgendaBudgetExceededError(agenda_id, spent, budget) + if state["status"] == AGENDA_STATUS_PAUSED_BUDGET: + # Budget was raised above current spend: unpause and continue. 
+ _set_status(agenda_id, AGENDA_STATUS_ACTIVE) + + +def record_usage( + agenda_id: int, + operation: str, + tokens: int, + cost_usd: float | None = None, +) -> dict[str, Any] | None: + """Append a ledger row and bump research_agendas.token_spent. + + If the new total crosses the budget, the agenda is set to 'paused_budget' + so the next check_budget() stops further spending. Ledger insert, counter + update and the status flip commit together. + """ + tokens = int(tokens or 0) + agenda_id = int(agenda_id) + db.execute( + """ + INSERT INTO agenda_token_ledger (agenda_id, operation, tokens, cost_usd) + VALUES (?, ?, ?, ?) + """, + (agenda_id, str(operation or "llm_call"), tokens, cost_usd), + ) + db.execute( + "UPDATE research_agendas SET token_spent = COALESCE(token_spent, 0) + ?, " + "updated_at=CURRENT_TIMESTAMP WHERE id=?", + (tokens, agenda_id), + ) + state = get_budget_state(agenda_id) + if state and state["exhausted"] and state["status"] != AGENDA_STATUS_PAUSED_BUDGET: + db.execute( + "UPDATE research_agendas SET status=?, updated_at=CURRENT_TIMESTAMP WHERE id=?", + (AGENDA_STATUS_PAUSED_BUDGET, agenda_id), + ) + state["status"] = AGENDA_STATUS_PAUSED_BUDGET + db.commit() + return state + + +def resume_agenda(agenda_id: int, *, token_budget: int | None = None) -> dict[str, Any] | None: + """Reactivate a budget-paused agenda, optionally raising its budget.""" + agenda_id = int(agenda_id) + if token_budget is not None: + db.execute( + "UPDATE research_agendas SET token_budget=?, updated_at=CURRENT_TIMESTAMP WHERE id=?", + (int(token_budget), agenda_id), + ) + db.execute( + "UPDATE research_agendas SET status=?, updated_at=CURRENT_TIMESTAMP WHERE id=? 
AND status=?", + (AGENDA_STATUS_ACTIVE, agenda_id, AGENDA_STATUS_PAUSED_BUDGET), + ) + db.commit() + return get_budget_state(agenda_id) + + +def _set_status(agenda_id: int, status: str) -> None: + db.execute( + "UPDATE research_agendas SET status=?, updated_at=CURRENT_TIMESTAMP WHERE id=?", + (status, int(agenda_id)), + ) + db.commit() + + +def usage_summary() -> dict[str, Any]: + """Aggregate the ledger per agenda + overall totals (local accounting view).""" + per_agenda = db.fetchall( + """ + SELECT ra.id AS agenda_id, ra.name, ra.status, ra.submitter, + ra.token_budget, ra.token_spent, + COALESCE(SUM(l.tokens), 0) AS ledger_tokens, + SUM(l.cost_usd) AS ledger_cost_usd, + COUNT(l.id) AS ledger_entries + FROM research_agendas ra + LEFT JOIN agenda_token_ledger l ON l.agenda_id = ra.id + GROUP BY ra.id, ra.name, ra.status, ra.submitter, ra.token_budget, ra.token_spent + ORDER BY ra.id + """ + ) + rows = [] + total_tokens = 0 + total_cost = 0.0 + has_cost = False + for r in per_agenda: + spent = int(r.get("token_spent") or 0) + budget = effective_budget(r.get("token_budget")) + ledger_tokens = int(r.get("ledger_tokens") or 0) + cost = r.get("ledger_cost_usd") + if cost is not None: + has_cost = True + total_cost += float(cost) + total_tokens += ledger_tokens + rows.append( + { + "agenda_id": r.get("agenda_id"), + "name": r.get("name"), + "status": r.get("status") or AGENDA_STATUS_ACTIVE, + "submitter": r.get("submitter"), + "token_budget": budget, + "token_spent": spent, + "ledger_tokens": ledger_tokens, + "ledger_cost_usd": cost, + "ledger_entries": int(r.get("ledger_entries") or 0), + "remaining": max(0, budget - spent) if budget > 0 else None, + } + ) + return { + "agendas": rows, + "totals": { + "tokens": total_tokens, + "cost_usd": total_cost if has_cost else None, + "agenda_count": len(rows), + }, + } diff --git a/agents/agenda_loader.py b/agents/agenda_loader.py index 713d53f..312095a 100644 --- a/agents/agenda_loader.py +++ b/agents/agenda_loader.py @@ -8,9 
+8,13 @@ - save_agenda(agenda) -> int # insert; returns agenda_id - update_agenda(agenda_id, agenda) -> None - get_agenda(agenda_id) -> ResearchAgenda | None -- get_active_agenda() -> ResearchAgenda | None +- get_active_agenda() -> ResearchAgenda | None # newest active (several may be active) - list_agendas(*, only_active=False) -> list[ResearchAgenda] -- set_active_agenda(agenda_id) -> None # exclusive active flag +- set_active_agenda(agenda_id) -> None # mark one agenda active + +Multiple agendas may be active at the same time; callers that operate on a +specific agenda should pass agenda_id explicitly. get_active_agenda() is kept +as a convenience for single-agenda deployments and returns the newest active row. """ from __future__ import annotations @@ -54,6 +58,10 @@ def _decode(field_name: str, default: Any) -> Any: required_output=ensure_dict(_decode("required_output_json", {})), raw_config=ensure_dict(raw_config_obj), is_active=bool(row.get("is_active", 1)), + submitter=str(row.get("submitter") or ""), + token_budget=row.get("token_budget"), + token_spent=int(row.get("token_spent") or 0), + status=str(row.get("status") or "active"), ) agenda.validate() return agenda @@ -75,6 +83,10 @@ def parse_agenda(payload: Mapping[str, Any], *, agenda_id: int | None = None) -> required_output=ensure_dict(payload.get("required_output") or {}), raw_config=dict(payload), is_active=bool(payload.get("is_active", True)), + submitter=str(payload.get("submitter") or "").strip(), + token_budget=payload.get("token_budget"), + token_spent=int(payload.get("token_spent") or 0), + status=str(payload.get("status") or "active"), ) agenda.validate() return agenda @@ -99,14 +111,19 @@ def load_agenda_from_file(path: str | Path) -> ResearchAgenda: def save_agenda(agenda: ResearchAgenda) -> int: - """Insert a new agenda. Returns the new agenda_id.""" + """Insert a new agenda. Returns the new agenda_id. 
+ + Does not deactivate other agendas: several agendas may run concurrently, + each isolated by agenda_id and its own token budget. + """ agenda.validate() new_id = db.insert_returning_id( """ INSERT INTO research_agendas (version, name, description, focus_json, prefer_json, reject_json, - required_output_json, raw_config_json, is_active) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + required_output_json, raw_config_json, is_active, + submitter, token_budget, token_spent, status) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id """, ( @@ -119,10 +136,12 @@ def save_agenda(agenda: ResearchAgenda) -> int: json.dumps(agenda.required_output, ensure_ascii=False), json.dumps(agenda.raw_config, ensure_ascii=False), 1 if agenda.is_active else 0, + agenda.submitter or None, + agenda.token_budget, + int(agenda.token_spent or 0), + agenda.status, ), ) - if agenda.is_active: - set_active_agenda(new_id) db.commit() agenda.agenda_id = new_id return new_id @@ -136,7 +155,8 @@ def update_agenda(agenda_id: int, agenda: ResearchAgenda) -> None: UPDATE research_agendas SET version=?, name=?, description=?, focus_json=?, prefer_json=?, reject_json=?, required_output_json=?, raw_config_json=?, - is_active=?, updated_at=CURRENT_TIMESTAMP + is_active=?, submitter=?, token_budget=?, status=?, + updated_at=CURRENT_TIMESTAMP WHERE id=? """, ( @@ -149,11 +169,12 @@ def update_agenda(agenda_id: int, agenda: ResearchAgenda) -> None: json.dumps(agenda.required_output, ensure_ascii=False), json.dumps(agenda.raw_config, ensure_ascii=False), 1 if agenda.is_active else 0, + agenda.submitter or None, + agenda.token_budget, + agenda.status, agenda_id, ), ) - if agenda.is_active: - set_active_agenda(agenda_id) db.commit() @@ -168,8 +189,13 @@ def get_agenda(agenda_id: int) -> ResearchAgenda | None: def get_active_agenda() -> ResearchAgenda | None: + """Return the newest active agenda. 
+ + Several agendas may be active at once; this helper exists for + single-agenda deployments and callers without an explicit agenda_id. + """ row = db.fetchone( - "SELECT * FROM research_agendas WHERE is_active=1 ORDER BY created_at DESC LIMIT 1", + "SELECT * FROM research_agendas WHERE is_active=1 ORDER BY created_at DESC, id DESC LIMIT 1", (), ) if not row: @@ -192,13 +218,14 @@ def list_agendas(*, only_active: bool = False) -> list[ResearchAgenda]: def set_active_agenda(agenda_id: int) -> None: - """Mark a single agenda active; clear is_active on all others.""" - db.execute( - "UPDATE research_agendas SET is_active=0 WHERE id<>?", - (agenda_id,), - ) + """Mark an agenda active. + + Historically this cleared is_active on every other agenda (single-active + model). Agendas are now isolated per agenda_id, so activating one no + longer deactivates the rest. + """ db.execute( - "UPDATE research_agendas SET is_active=1 WHERE id=?", + "UPDATE research_agendas SET is_active=1, updated_at=CURRENT_TIMESTAMP WHERE id=?", (agenda_id,), ) db.commit() diff --git a/agents/agenda_selector.py b/agents/agenda_selector.py index 2cfd476..21edeac 100644 --- a/agents/agenda_selector.py +++ b/agents/agenda_selector.py @@ -27,6 +27,7 @@ from contracts.agenda import VALID_SELECTION_STATUS, AgendaSelection, ResearchAgenda from contracts.base import ContractValidationError, ensure_dict, ensure_list, ensure_string_list from db import database as db +from db.sql_dialect import escape_like # ---------- scoring weights ---------- @@ -236,19 +237,70 @@ def evaluate_candidates( # ---------- persistence + DB-facing API ---------- -def _fetch_insight_pool(limit: int = 200) -> list[dict[str, Any]]: - rows = db.fetchall( - """ +def agenda_scope_keywords(agenda: ResearchAgenda) -> list[str]: + """Keywords that define the agenda's topical scope: focus + prefer.keywords.""" + seen: set[str] = set() + out: list[str] = [] + for kw in list(agenda.focus or []) + ensure_string_list((agenda.prefer or 
{}).get("keywords")): + k = str(kw).strip().lower() + if k and k not in seen: + seen.add(k) + out.append(k) + return out + + +_POOL_SELECT = """ SELECT id, tier, status, title, problem_statement, formal_structure, existing_weakness, transformation, proposed_method, adversarial_score, novelty_status, resource_class, experimentability, submission_status, outcome FROM deep_insights - WHERE status IS NULL OR status NOT IN ('rejected', 'archived') - ORDER BY id DESC - LIMIT ? - """, - (int(limit),), +""" + + +def _fetch_insight_pool( + limit: int = 200, + agenda: ResearchAgenda | None = None, +) -> list[dict[str, Any]]: + """Fetch candidate insights, optionally scoped to one agenda. + + Without an agenda this is the original whole-table query (backward + compatible). With an agenda the pool is restricted to: + - insights tagged with this agenda_id (produced for this agenda), plus + - untagged insights (agenda_id IS NULL) whose text matches the agenda's + scope keywords (focus + prefer.keywords). + Insights tagged with a different agenda_id are always excluded. + """ + status_filter = "(status IS NULL OR status NOT IN ('rejected', 'archived'))" + if agenda is None: + rows = db.fetchall( + f"{_POOL_SELECT} WHERE {status_filter} ORDER BY id DESC LIMIT ?", + (int(limit),), + ) + return rows + + keywords = agenda_scope_keywords(agenda) + clauses: list[str] = [] + params: list[Any] = [] + if agenda.agenda_id: + clauses.append("agenda_id = ?") + params.append(int(agenda.agenda_id)) + if keywords: + likes = [] + for kw in keywords: + likes.append( + "LOWER(COALESCE(title, '') || ' ' || COALESCE(problem_statement, '') " + "|| ' ' || COALESCE(formal_structure, '')) LIKE ? ESCAPE '\\'" + ) + params.append(f"%{escape_like(kw)}%") + clauses.append(f"(agenda_id IS NULL AND ({' OR '.join(likes)}))") + else: + # No scope keywords: keep untagged insights visible (legacy data). 
+ clauses.append("agenda_id IS NULL") + rows = db.fetchall( + f"{_POOL_SELECT} WHERE {status_filter} AND ({' OR '.join(clauses)}) " + "ORDER BY id DESC LIMIT ?", + (*params, int(limit)), ) return rows @@ -276,14 +328,26 @@ def select_and_persist( *, limit: int = 200, pool: list[Mapping[str, Any]] | None = None, + scope_to_agenda: bool = False, ) -> AgendaSelection: - """Run selection over the deep_insights table, persist result, return contract.""" + """Run selection over the deep_insights table, persist result, return contract. + + With scope_to_agenda=True the candidate pool is pre-filtered to the + agenda's own insights + keyword-matching untagged ones (see + _fetch_insight_pool); default False keeps the historical whole-pool + behavior where scoring alone decides. + """ agenda.validate() if not agenda.agenda_id: raise ContractValidationError( "agenda must be persisted (have agenda_id) before running selection" ) - insight_pool = list(pool) if pool is not None else _fetch_insight_pool(limit=limit) + if pool is not None: + insight_pool = list(pool) + else: + insight_pool = _fetch_insight_pool( + limit=limit, agenda=agenda if scope_to_agenda else None + ) if not insight_pool: # persist an empty selection so the UI can see "no candidates" sel = AgendaSelection( diff --git a/agents/direction_intake.py b/agents/direction_intake.py new file mode 100644 index 0000000..0b71e02 --- /dev/null +++ b/agents/direction_intake.py @@ -0,0 +1,219 @@ +"""User research-direction intake: deterministic YAML -> ResearchAgenda mapping. + +Input schema (user-facing, see research_agendas/inbox/README.md): + + direction: "..." # required, natural language + keywords: [a, b, c] # optional + constraints: # optional, free text + compute: "..." + data: "..." + goal: experiment_plan # idea_only | experiment_plan | signal | verified_evidence + contact: "..." 
# required, nickname or email + token_budget: 500000 # optional token cap; default 0 = no cap, + # usage is recorded either way + +Mapping (rule-based, no LLM): + direction -> description (+ auto-generated slug name) + keywords -> focus + constraints.compute -> prefer.resource_class via keyword rules (original + text is preserved in raw_config) + goal -> required_output.goal + contact -> submitter + +`parse_direction_yaml` also returns an echo dict: a templated summary of what +the system understood, for the submitter to confirm. +""" + +from __future__ import annotations + +import hashlib +import re +from typing import Any, Mapping + +import yaml # type: ignore + +from agents.agenda_loader import parse_agenda +from contracts.agenda import ResearchAgenda +from contracts.base import ContractValidationError, ensure_string_list + + +VALID_GOALS = ("idea_only", "experiment_plan", "signal", "verified_evidence") +DEFAULT_GOAL = "experiment_plan" + +GOAL_LABELS_ZH = { + "idea_only": "仅研究想法", + "experiment_plan": "可执行实验计划", + "signal": "结构化信号报告", + "verified_evidence": "经验证的实验证据", +} + +# Compute-constraint keyword rules, checked in order; first match wins. +# Conservative on purpose: anything that sounds like "small machine" maps to +# the cheaper resource classes used by deep_insights.resource_class. +_COMPUTE_RULES: tuple[tuple[tuple[str, ...], list[str]], ...] 
= ( + ( + ("cpu", "无gpu", "no gpu", "笔记本", "laptop", "notebook"), + ["cpu"], + ), + ( + ( + "单卡", "单gpu", "single gpu", "single-gpu", "1 gpu", "one gpu", "1gpu", + "colab", "t4", "消费级", "consumer gpu", + ), + ["cpu", "gpu_small"], + ), +) + + +class DirectionParseError(ValueError): + """Raised when a direction submission cannot be mapped to an agenda.""" + + +def map_compute_constraint(text: Any) -> list[str] | None: + """Best-effort keyword mapping from a free-text compute constraint.""" + blob = str(text or "").strip().lower() + if not blob: + return None + for markers, resource_classes in _COMPUTE_RULES: + if any(marker in blob for marker in markers): + return list(resource_classes) + return None + + +def _ascii_tokens(text: str, *, min_len: int = 3, max_tokens: int = 8) -> list[str]: + tokens: list[str] = [] + seen: set[str] = set() + for tok in re.findall(r"[A-Za-z][A-Za-z0-9\-]*", str(text or "")): + tok = tok.lower() + if len(tok) < min_len or tok in seen: + continue + seen.add(tok) + tokens.append(tok) + if len(tokens) >= max_tokens: + break + return tokens + + +def _slug_name(direction: str, keywords: list[str]) -> str: + """Deterministic agenda name: slug of keywords/direction + short hash.""" + basis = keywords if keywords else _ascii_tokens(direction, min_len=3, max_tokens=4) + slug = "-".join(re.sub(r"[^a-z0-9]+", "-", k.lower()).strip("-") for k in basis[:4]) + slug = re.sub(r"-{2,}", "-", slug).strip("-")[:48] + digest = hashlib.sha1(direction.encode("utf-8")).hexdigest()[:8] + return f"direction-{slug}-{digest}" if slug else f"direction-{digest}" + + +def parse_direction_payload(payload: Mapping[str, Any]) -> ResearchAgenda: + """Map a parsed direction dict to a validated ResearchAgenda.""" + if not isinstance(payload, Mapping): + raise DirectionParseError("direction submission must be a YAML mapping") + + direction = str(payload.get("direction") or "").strip() + if not direction: + raise DirectionParseError("'direction' is required 
(natural-language research direction)") + + contact = str(payload.get("contact") or "").strip() + if not contact: + raise DirectionParseError("'contact' is required (nickname or email)") + + goal = str(payload.get("goal") or DEFAULT_GOAL).strip().lower() + if goal not in VALID_GOALS: + raise DirectionParseError( + f"'goal' must be one of {list(VALID_GOALS)}, got '{goal}'" + ) + + keywords = ensure_string_list(payload.get("keywords") or []) + focus = keywords or _ascii_tokens(direction) + + constraints = payload.get("constraints") + if constraints is not None and not isinstance(constraints, Mapping): + raise DirectionParseError("'constraints' must be a mapping of free-text fields") + constraints = dict(constraints or {}) + + prefer: dict[str, Any] = {} + resource_classes = map_compute_constraint(constraints.get("compute")) + if resource_classes: + prefer["resource_class"] = resource_classes + + if not focus and not prefer: + raise DirectionParseError( + "could not derive any scope keywords; add a 'keywords' list " + "(the direction text has no extractable terms)" + ) + + token_budget = payload.get("token_budget") + if token_budget is not None: + try: + token_budget = int(token_budget) + except (TypeError, ValueError) as exc: + raise DirectionParseError("'token_budget' must be an integer") from exc + + agenda_payload: dict[str, Any] = { + "version": "v1", + "name": _slug_name(direction, keywords), + "description": direction, + "focus": focus, + "prefer": prefer, + "required_output": {"goal": goal}, + "submitter": contact, + # Keep the original submission verbatim for auditability. 
+ "source": "direction_intake_v1", + "direction": direction, + "keywords": keywords, + "constraints": constraints, + "goal": goal, + "contact": contact, + } + if token_budget is not None: + agenda_payload["token_budget"] = token_budget + + try: + return parse_agenda(agenda_payload) + except ContractValidationError as exc: + raise DirectionParseError(f"mapped agenda failed validation: {exc}") from exc + + +def build_echo(agenda: ResearchAgenda, payload: Mapping[str, Any]) -> dict[str, Any]: + """Templated confirmation of what the system understood (for the submitter).""" + constraints = payload.get("constraints") + constraints = dict(constraints) if isinstance(constraints, Mapping) else {} + goal = str(agenda.required_output.get("goal") or DEFAULT_GOAL) + resource_classes = ensure_string_list((agenda.prefer or {}).get("resource_class")) + + parts = [ + f"已登记研究方向:{agenda.description}", + f"识别到的范围关键词:{'、'.join(agenda.focus) if agenda.focus else '(无,建议补充 keywords)'}", + f"目标产出:{GOAL_LABELS_ZH.get(goal, goal)}({goal})", + ] + if constraints.get("compute"): + mapped = "、".join(resource_classes) if resource_classes else "未识别(原文已保留)" + parts.append(f"算力约束:{constraints['compute']} → 资源档位 {mapped}") + if constraints.get("data"): + parts.append(f"数据约束:{constraints['data']}(原文保留,供研究执行时参考)") + parts.append(f"联系人:{agenda.submitter}") + + return { + "type": "direction_intake_echo", + "name": agenda.name, + "direction": agenda.description, + "focus": list(agenda.focus), + "goal": goal, + "constraints": constraints, + "resource_class": resource_classes or None, + "submitter": agenda.submitter, + "summary": ";".join(parts), + } + + +def parse_direction_yaml(text: str) -> tuple[ResearchAgenda, dict[str, Any]]: + """Parse a direction YAML document into (ResearchAgenda, echo dict).""" + if not str(text or "").strip(): + raise DirectionParseError("empty submission") + try: + payload = yaml.safe_load(text) + except yaml.YAMLError as exc: + raise DirectionParseError(f"invalid YAML: 
{exc}") from exc + if not isinstance(payload, Mapping): + raise DirectionParseError("direction submission must be a YAML mapping") + agenda = parse_direction_payload(payload) + return agenda, build_echo(agenda, payload) diff --git a/agents/llm_client.py b/agents/llm_client.py index ca9d256..e15f994 100644 --- a/agents/llm_client.py +++ b/agents/llm_client.py @@ -591,12 +591,37 @@ def is_llm_transient_provider_error(exc: Exception) -> bool: return any(marker in msg for marker in markers) +def _agenda_budget_scope(): + """Return the active agenda budget scope, if agenda accounting is in use. + + Imported lazily: llm_client stays importable without the DB layer, and the + hook costs nothing for calls that run outside an agenda scope. + """ + try: + from agents import agenda_budget + except Exception: # pragma: no cover - defensive: accounting is optional + return None + return agenda_budget.current_scope() + + def call_llm(system_prompt: str, user_prompt: str, temperature: float = 0.0, max_tokens: int = None) -> tuple[str, int]: - """Call LLM with automatic provider selection, rate limiting, and failover.""" + """Call LLM with automatic provider selection, rate limiting, and failover. + + When running inside agents.agenda_budget.agenda_scope(...), the call is + metered against that agenda: budget is checked before contacting any + provider (raising AgendaBudgetExceededError without spending tokens) and + usage is recorded in the agenda token ledger afterwards. + """ max_tokens = max_tokens or LLM_MAX_OUTPUT_TOKENS _init_providers() + budget_scope = _agenda_budget_scope() + if budget_scope is not None: + from agents import agenda_budget + # Raises AgendaBudgetExceededError before any provider request. 
+ agenda_budget.check_budget(budget_scope[0]) + last_error = None tried = set() MAX_429_RETRIES = 3 @@ -642,6 +667,18 @@ def call_llm(system_prompt: str, user_prompt: str, temperature: float = 0.0, stats["cached_tokens"] += cached_toks stats["input_tokens"] += input_toks _release_provider(provider["name"]) + if budget_scope is not None and tokens: + try: + from agents import agenda_budget + agenda_budget.record_usage(budget_scope[0], budget_scope[1], tokens) + except Exception as acct_err: # noqa: BLE001 + # Accounting failure must not discard a paid response; + # log it so the ledger gap is visible. + print( + f"[LLM] WARNING: agenda token accounting failed " + f"(agenda {budget_scope[0]}): {acct_err}", + flush=True, + ) return text, tokens except Exception as e: diff --git a/agents/paper_idea_agent.py b/agents/paper_idea_agent.py index e470dcc..aa65724 100644 --- a/agents/paper_idea_agent.py +++ b/agents/paper_idea_agent.py @@ -9,10 +9,11 @@ Call 3: Experimental Design — complete plan with baselines, datasets, ablations """ import json +from agents.agenda_budget import AgendaBudgetExceededError from agents.discovery_metadata import build_evidence_packet, enrich_deep_insight from agents.insight_validation import get_evosci_input_issue from agents.llm_client import call_llm_json, is_llm_auth_error, is_llm_provider_unavailable_error -from agents.signal_harvester import get_tier2_signals +from agents.signal_harvester import agenda_taxonomy_node_ids, get_tier2_signals from db import database as db @@ -398,11 +399,17 @@ def discover_paper_ideas( *, tier2_plateau_limit: int = 20, tier2_limitation_nodes: int = 15, + agenda=None, ) -> list[dict]: """Run the 3-stage paper idea discovery pipeline. Returns list of deep_insight dicts ready for storage. If max_papers is None, every sharpened problem (up to max_problems) is expanded. 
+ + With an agenda (contracts.agenda.ResearchAgenda), the signal scan is + circled to the matching taxonomy subgraph and produced ideas are tagged + with agenda_id. Budget exhaustion stops the loop cleanly, returning the + ideas accepted so far. """ if max_papers is None: max_papers = max_problems @@ -411,10 +418,25 @@ def discover_paper_ideas( total_tokens = 0 total_calls = 0 - # Stage 0: Gather signals + # Stage 0: Gather signals (scoped to the agenda's subgraph when known) + scope_node_ids = None + scope_keywords = None + if agenda is not None: + from agents.agenda_selector import agenda_scope_keywords + + scope_keywords = agenda_scope_keywords(agenda) or None + scope_node_ids = agenda_taxonomy_node_ids(scope_keywords or []) or None + if scope_node_ids is None: + print( + f"[PAPER_IDEA] Agenda '{agenda.name}' matched no taxonomy nodes; " + "falling back to global signal scan", + flush=True, + ) signals = get_tier2_signals( plateau_limit=tier2_plateau_limit, limitation_node_limit=tier2_limitation_nodes, + node_ids=scope_node_ids, + scope_keywords=scope_keywords, ) has_signals = ( signals["contradiction_clusters"] @@ -438,6 +460,9 @@ def discover_paper_ideas( result1, tokens1 = call_llm_json(PROBLEM_SHARPENING_SYSTEM, problem_prompt) total_tokens += tokens1 total_calls += 1 + except AgendaBudgetExceededError as e: + print(f"[PAPER_IDEA] Stopped before problem sharpening: {e}", flush=True) + return [] except Exception as e: if _llm_temporarily_unavailable(e): print(f"[PAPER_IDEA] Problem sharpening skipped: LLM unavailable ({e})", flush=True) @@ -473,6 +498,9 @@ def discover_paper_ideas( result2, tokens2 = call_llm_json(METHOD_INVENTION_SYSTEM, method_prompt) total_tokens += tokens2 total_calls += 1 + except AgendaBudgetExceededError as e: + print(f"[PAPER_IDEA] Stopped at method invention: {e}", flush=True) + break except Exception as e: if _llm_temporarily_unavailable(e): print(f"[PAPER_IDEA] Method invention paused: LLM unavailable ({e})", flush=True) @@ -498,6 
+526,9 @@ def discover_paper_ideas( result3, tokens3 = call_llm_json(EXPERIMENT_DESIGN_SYSTEM, exp_prompt) total_tokens += tokens3 total_calls += 1 + except AgendaBudgetExceededError as e: + print(f"[PAPER_IDEA] Stopped at experiment design: {e}", flush=True) + break except Exception as e: if _llm_temporarily_unavailable(e): print(f"[PAPER_IDEA] Experiment design skipped: LLM unavailable ({e})", flush=True) @@ -570,6 +601,7 @@ def discover_paper_ideas( "novelty_status": "unchecked", "generation_tokens": total_tokens, "llm_calls": total_calls, + "agenda_id": agenda.agenda_id if agenda is not None else None, } input_issue = get_evosci_input_issue(deep_insight, mode="verification") diff --git a/agents/paradigm_agent.py b/agents/paradigm_agent.py index 4dd2bf2..30c7dfc 100644 --- a/agents/paradigm_agent.py +++ b/agents/paradigm_agent.py @@ -20,8 +20,9 @@ is_llm_auth_error, is_llm_provider_unavailable_error, ) +from agents.agenda_budget import AgendaBudgetExceededError from contracts import DeepInsightSpec, normalize_deep_insight_storage -from agents.signal_harvester import get_tier1_signals +from agents.signal_harvester import agenda_taxonomy_node_ids, get_tier1_signals from config import LLM_MODEL, PROMPT_VERSION from db import database as db from db.insight_outcomes import new_generation_run_id, record_created @@ -326,11 +327,25 @@ def _call_with_provider(system: str, user: str, provider_name: str = None) -> tu _init_providers() for p in _providers: if p["name"] == provider_name: + from agents.agenda_budget import check_budget, current_scope, record_usage from agents.llm_client import _call_provider, _rate_limiters + budget_scope = current_scope() + if budget_scope is not None: + # Same metering as llm_client.call_llm: this path goes to a + # specific provider directly, so account for it here too. 
+ check_budget(budget_scope[0]) limiter = _rate_limiters.get(p["name"]) if limiter: limiter.wait() text, tokens, _, _ = _call_provider(p, system, user, 16_000) + if budget_scope is not None and tokens: + try: + record_usage(budget_scope[0], budget_scope[1], tokens) + except Exception as acct_err: # noqa: BLE001 + print( + f"[PARADIGM] WARNING: agenda token accounting failed: {acct_err}", + flush=True, + ) import re text = text.strip() text = re.sub(r'[\s\S]*?', '', text).strip() @@ -364,18 +379,38 @@ def discover_paradigm_insights( *, tier1_top_overlaps: int = 20, tier1_top_patterns: int = 15, + agenda=None, ) -> list[dict]: """Run the 3-stage paradigm discovery pipeline. Returns list of deep_insight dicts ready for storage. + + With an agenda (contracts.agenda.ResearchAgenda), the signal scan is + circled to the taxonomy subgraph matching the agenda's scope keywords and + every produced insight is tagged with agenda_id. Budget exhaustion + (AgendaBudgetExceededError from the metered LLM client) stops the loop + cleanly, returning the insights accepted so far. """ print(f"[PARADIGM] Starting Tier 1 discovery (max {max_candidates} candidates)...", flush=True) total_tokens = 0 total_calls = 0 - # Stage 0: Gather signals + # Stage 0: Gather signals (scoped to the agenda's subgraph when known) + scope_node_ids = None + if agenda is not None: + from agents.agenda_selector import agenda_scope_keywords + + scope_node_ids = agenda_taxonomy_node_ids(agenda_scope_keywords(agenda)) or None + if scope_node_ids is None: + print( + f"[PARADIGM] Agenda '{agenda.name}' matched no taxonomy nodes; " + "falling back to global signal scan", + flush=True, + ) signals = get_tier1_signals( - top_overlaps=tier1_top_overlaps, top_patterns=tier1_top_patterns + top_overlaps=tier1_top_overlaps, + top_patterns=tier1_top_patterns, + node_ids=scope_node_ids, ) if not signals["entity_overlaps"] and not signals["pattern_matches"]: print("[PARADIGM] No signals available. 
Run signal_harvester first.", flush=True) @@ -388,6 +423,9 @@ def discover_paradigm_insights( result1, tokens1 = call_llm_json(STRUCTURE_DETECTION_SYSTEM, structure_prompt) total_tokens += tokens1 total_calls += 1 + except AgendaBudgetExceededError as e: + print(f"[PARADIGM] Stopped before structure detection: {e}", flush=True) + return [] except Exception as e: if _llm_temporarily_unavailable(e): print(f"[PARADIGM] Structure detection skipped: LLM unavailable ({e})", flush=True) @@ -423,6 +461,9 @@ def discover_paradigm_insights( result2, tokens2 = call_llm_json(FORMALIZATION_SYSTEM, formal_prompt) total_tokens += tokens2 total_calls += 1 + except AgendaBudgetExceededError as e: + print(f"[PARADIGM] Stopped at formalization: {e}", flush=True) + break except Exception as e: if _llm_temporarily_unavailable(e): print(f"[PARADIGM] Formalization paused: LLM unavailable ({e})", flush=True) @@ -446,6 +487,9 @@ def discover_paradigm_insights( ADVERSARIAL_SYSTEM, adversarial_prompt, provider_name="minimax") total_tokens += tokens3 total_calls += 1 + except AgendaBudgetExceededError as e: + print(f"[PARADIGM] Stopped at adversarial challenge: {e}", flush=True) + break except Exception as e: if _llm_temporarily_unavailable(e): print(f"[PARADIGM] Adversarial challenge skipped: LLM unavailable ({e})", flush=True) @@ -513,6 +557,7 @@ def discover_paradigm_insights( "novelty_status": "unchecked", "generation_tokens": total_tokens, "llm_calls": total_calls, + "agenda_id": agenda.agenda_id if agenda is not None else None, } # Also store minimal experiment info in experimental_plan @@ -576,8 +621,8 @@ def store_deep_insight(insight: dict) -> int: generation_tokens, llm_calls, generation_run_id, source_signal_ids, source_paper_ids, prompt_version, model_version, exemplars_used, - token_cost_usd, wall_clock_seconds, outcome) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ token_cost_usd, wall_clock_seconds, outcome, agenda_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id""", ( insight.get("tier", 1), @@ -619,6 +664,7 @@ def store_deep_insight(insight: dict) -> int: insight.get("token_cost_usd"), insight.get("wall_clock_seconds"), insight.get("outcome", "pending"), + insight.get("agenda_id"), ), ) db.commit() diff --git a/agents/signal_harvester.py b/agents/signal_harvester.py index 2eceb0a..d1d4292 100644 --- a/agents/signal_harvester.py +++ b/agents/signal_harvester.py @@ -14,6 +14,7 @@ from difflib import SequenceMatcher from contracts import DiscoverySignalBundle from db import database as db +from db.sql_dialect import escape_like from db.insight_outcomes import record_harvester_run @@ -708,20 +709,81 @@ def harvest_all() -> dict: return stats -def get_tier1_signals(top_overlaps: int = 20, top_patterns: int = 15) -> DiscoverySignalBundle: - """Assemble signals for Tier 1 (Paradigm Discovery) agent.""" - overlaps = db.fetchall( - """SELECT * FROM node_entity_overlap - ORDER BY overlap_score DESC LIMIT ?""", (top_overlaps,)) +def agenda_taxonomy_node_ids(keywords: list[str]) -> list[str]: + """Circle the taxonomy subgraph matching the given scope keywords. - pattern_ms = db.fetchall(""" - SELECT pm.*, pa.pattern_text as text_a, pa.pattern_type as type_a, - pb.pattern_text as text_b, pb.pattern_type as type_b - FROM pattern_matches pm - JOIN patterns pa ON pm.pattern_a_id = pa.id - JOIN patterns pb ON pm.pattern_b_id = pb.id - ORDER BY pm.similarity_score DESC LIMIT ? - """, (top_patterns,)) + Deterministic SQL LIKE over taxonomy node name/description; used to + restrict Tier 1/2 signal queries to one agenda's topical area. + """ + cleaned = [str(k).strip().lower() for k in (keywords or []) if str(k).strip()] + if not cleaned: + return [] + likes = " OR ".join( + ["LOWER(name || ' ' || COALESCE(description, '')) LIKE ? 
ESCAPE '\\'"] + * len(cleaned) + ) + rows = db.fetchall( + f"SELECT id FROM taxonomy_nodes WHERE {likes} ORDER BY id", + tuple(f"%{escape_like(kw)}%" for kw in cleaned), + ) + return [str(r["id"]) for r in rows] + + +def _node_in_clause(columns: list[str], node_ids: list[str]) -> tuple[str, list[str]]: + """Build '(col_a IN (...) OR col_b IN (...))' + params for node scoping.""" + placeholders = ", ".join(["?"] * len(node_ids)) + clause = " OR ".join(f"{col} IN ({placeholders})" for col in columns) + params: list[str] = [] + for _ in columns: + params.extend(node_ids) + return f"({clause})", params + + +def get_tier1_signals( + top_overlaps: int = 20, + top_patterns: int = 15, + *, + node_ids: list[str] | None = None, +) -> DiscoverySignalBundle: + """Assemble signals for Tier 1 (Paradigm Discovery) agent. + + node_ids, when given, restricts entity overlaps and pattern matches to the + taxonomy subgraph (agenda scoping); default None keeps the global view. + """ + if node_ids: + ov_clause, ov_params = _node_in_clause(["node_a_id", "node_b_id"], node_ids) + overlaps = db.fetchall( + f"""SELECT * FROM node_entity_overlap + WHERE {ov_clause} + ORDER BY overlap_score DESC LIMIT ?""", + (*ov_params, top_overlaps), + ) + pm_clause, pm_params = _node_in_clause(["pm.node_a_id", "pm.node_b_id"], node_ids) + pattern_ms = db.fetchall( + f""" + SELECT pm.*, pa.pattern_text as text_a, pa.pattern_type as type_a, + pb.pattern_text as text_b, pb.pattern_type as type_b + FROM pattern_matches pm + JOIN patterns pa ON pm.pattern_a_id = pa.id + JOIN patterns pb ON pm.pattern_b_id = pb.id + WHERE {pm_clause} + ORDER BY pm.similarity_score DESC LIMIT ? 
+ """, + (*pm_params, top_patterns), + ) + else: + overlaps = db.fetchall( + """SELECT * FROM node_entity_overlap + ORDER BY overlap_score DESC LIMIT ?""", (top_overlaps,)) + + pattern_ms = db.fetchall(""" + SELECT pm.*, pa.pattern_text as text_a, pa.pattern_type as type_a, + pb.pattern_text as text_b, pb.pattern_type as type_b + FROM pattern_matches pm + JOIN patterns pa ON pm.pattern_a_id = pa.id + JOIN patterns pb ON pm.pattern_b_id = pb.id + ORDER BY pm.similarity_score DESC LIMIT ? + """, (top_patterns,)) clusters = db.fetchall( "SELECT * FROM contradiction_clusters WHERE cluster_size >= 2 ORDER BY cluster_size DESC") @@ -760,44 +822,94 @@ def get_tier2_signals( *, plateau_limit: int = 20, limitation_node_limit: int = 15, + node_ids: list[str] | None = None, + scope_keywords: list[str] | None = None, ) -> DiscoverySignalBundle: - """Assemble signals for Tier 2 (Paper-Ready Ideas) agent.""" + """Assemble signals for Tier 2 (Paper-Ready Ideas) agent. + + node_ids / scope_keywords, when given, restrict node-keyed signals + (plateaus, limitation clusters) and tier-1 insight seeds to one agenda's + topical area; defaults keep the global view. 
+ """ clusters = db.fetchall( "SELECT * FROM contradiction_clusters ORDER BY cluster_size DESC") - plateaus = db.fetchall( - "SELECT * FROM performance_plateaus ORDER BY method_count DESC LIMIT ?", - (plateau_limit,), - ) + if node_ids: + pl_clause, pl_params = _node_in_clause(["node_id"], node_ids) + plateaus = db.fetchall( + f"SELECT * FROM performance_plateaus WHERE {pl_clause} " + "ORDER BY method_count DESC LIMIT ?", + (*pl_params, plateau_limit), + ) + lim_clause, lim_params = _node_in_clause(["node_id"], node_ids) + limitation_clusters = db.fetchall( + f""" + SELECT node_id, COUNT(*) as lim_count, + GROUP_CONCAT(paper_id) as paper_ids + FROM ( + SELECT pt.node_id, pi.paper_id + FROM paper_insights pi + JOIN paper_taxonomy pt ON pt.paper_id = pi.paper_id + WHERE pi.limitations IS NOT NULL AND pi.limitations != '[]' + ) + WHERE {lim_clause} + GROUP BY node_id + HAVING COUNT(*) >= 3 + ORDER BY lim_count DESC + LIMIT ? + """, + (*lim_params, limitation_node_limit), + ) + else: + plateaus = db.fetchall( + "SELECT * FROM performance_plateaus ORDER BY method_count DESC LIMIT ?", + (plateau_limit,), + ) - limitation_clusters = db.fetchall( - """ - SELECT node_id, COUNT(*) as lim_count, - GROUP_CONCAT(paper_id) as paper_ids - FROM ( - SELECT pt.node_id, pi.paper_id - FROM paper_insights pi - JOIN paper_taxonomy pt ON pt.paper_id = pi.paper_id - WHERE pi.limitations IS NOT NULL AND pi.limitations != '[]' + limitation_clusters = db.fetchall( + """ + SELECT node_id, COUNT(*) as lim_count, + GROUP_CONCAT(paper_id) as paper_ids + FROM ( + SELECT pt.node_id, pi.paper_id + FROM paper_insights pi + JOIN paper_taxonomy pt ON pt.paper_id = pi.paper_id + WHERE pi.limitations IS NOT NULL AND pi.limitations != '[]' + ) + GROUP BY node_id + HAVING COUNT(*) >= 3 + ORDER BY lim_count DESC + LIMIT ? + """, + (limitation_node_limit,), ) - GROUP BY node_id - HAVING COUNT(*) >= 3 - ORDER BY lim_count DESC - LIMIT ? 
- """, - (limitation_node_limit,), - ) try: + keyword_filter = "" + keyword_params: list[str] = [] + cleaned_keywords = [ + str(k).strip().lower() for k in (scope_keywords or []) if str(k).strip() + ] + if cleaned_keywords: + likes = " OR ".join( + [ + "LOWER(COALESCE(title, '') || ' ' " + "|| COALESCE(evidence_summary, '')) LIKE ? ESCAPE '\\'" + ] + * len(cleaned_keywords) + ) + keyword_filter = f" AND ({likes})" + keyword_params = [f"%{escape_like(kw)}%" for kw in cleaned_keywords] high_insights = db.fetchall( - """ + f""" SELECT id, title, mechanism_type, evidence_packet, adversarial_score, evidence_summary, experimental_plan, signal_mix, resource_class FROM deep_insights - WHERE tier=1 + WHERE tier=1{keyword_filter} ORDER BY COALESCE(adversarial_score, 0) DESC, created_at DESC LIMIT 10 - """ + """, + tuple(keyword_params), ) except Exception: high_insights = [] diff --git a/config.py b/config.py index 528351a..a34ce62 100644 --- a/config.py +++ b/config.py @@ -244,6 +244,15 @@ def _split_csv(value: str | list | tuple | None) -> list[str]: DISCOVERY_BULK_TIER2_LIMIT_NODES = _env_int("DEEPGRAPH_BULK_TIER2_LIMIT_NODES", 30, "discovery.bulk_tier2_limit_nodes") EVOSCI_VERIFY_TIMEOUT = _env_int("DEEPGRAPH_EVOSCI_VERIFY_TIMEOUT", 900, "experiment.evosci_verify_timeout") +# Research agendas: default per-agenda LLM token budget. Used when a +# research_agendas row has token_budget NULL. <= 0 disables the cap (no +# enforcement, never pauses an agenda) while usage accounting stays on: +# every scoped call still writes agenda_token_ledger and bumps token_spent. +# Default 0 = accounting only; set a positive value to enforce a cap. 
+AGENDA_TOKEN_BUDGET_DEFAULT = _env_int( + "DEEPGRAPH_AGENDA_TOKEN_BUDGET_DEFAULT", 0, "agenda.token_budget_default" +) + # SciForge Experiment Validation EXPERIMENT_TIME_BUDGET = _env_int("SCIFORGE_TIME_BUDGET", 300, "experiment.time_budget_seconds") EXPERIMENT_MAX_ITERATIONS = _env_int("SCIFORGE_MAX_ITERATIONS", 100, "experiment.max_iterations") diff --git a/contracts/agenda.py b/contracts/agenda.py index 4e37da7..5e7a25b 100644 --- a/contracts/agenda.py +++ b/contracts/agenda.py @@ -25,6 +25,13 @@ ) +VALID_AGENDA_STATUS = { + "active", + # Token budget exhausted; the agenda stops consuming LLM calls until the + # budget is raised or the resume endpoint is called. + "paused_budget", +} + VALID_REVIEW_RECOMMENDATIONS = { "accept", "minor_revision", @@ -63,6 +70,11 @@ class ResearchAgenda(ContractRecord): raw_config: dict[str, Any] = field(default_factory=dict) agenda_id: int | None = None is_active: bool = True + # Multi-agenda isolation + token budget (see db/schema_agenda.sql) + submitter: str = "" + token_budget: int | None = None + token_spent: int = 0 + status: str = "active" def validate(self) -> None: super().validate() @@ -73,6 +85,12 @@ def validate(self) -> None: self.reject = ensure_dict(self.reject) self.required_output = ensure_dict(self.required_output) self.raw_config = ensure_dict(self.raw_config) + self.token_budget = coerce_optional_int(self.token_budget) + self.token_spent = coerce_optional_int(self.token_spent) or 0 + if self.status not in VALID_AGENDA_STATUS: + raise ContractValidationError( + f"ResearchAgenda status '{self.status}' not in {sorted(VALID_AGENDA_STATUS)}" + ) if not self.focus and not self.prefer: raise ContractValidationError( "ResearchAgenda needs at least one focus keyword or prefer rule" diff --git a/db/database.py b/db/database.py index 83fca77..af3f39a 100644 --- a/db/database.py +++ b/db/database.py @@ -360,6 +360,36 @@ def _ensure_vnext_migrations() -> None: conn.commit() +def _ensure_agenda_isolation_schema() -> None: + 
"""Multi-agenda isolation: budget columns + per-agenda insight tagging (existing DBs). + + New columns also appear in schema_agenda*.sql for fresh databases; this + migration upgrades databases created before they were added. + """ + _ensure_columns( + "research_agendas", + { + "submitter": "TEXT", + "token_budget": "INTEGER", + "token_spent": "INTEGER DEFAULT 0", + "status": "TEXT DEFAULT 'active'", + }, + ) + _ensure_columns( + "deep_insights", + { + "agenda_id": "INTEGER", + }, + ) + conn = get_conn() + _execute_startup_statement( + conn, + "CREATE INDEX IF NOT EXISTS idx_deep_insights_agenda ON deep_insights(agenda_id)", + best_effort_if_locked=_use_pg(), + ) + conn.commit() + + def _ensure_grounding_schema() -> None: """Add cite-and-verify columns for claims and results (existing DBs).""" _ensure_columns( @@ -652,6 +682,7 @@ def init_db(): except Exception: pass _ensure_vnext_migrations() + _ensure_agenda_isolation_schema() _ensure_grounding_schema() schema_feedback = Path(__file__).parent / "schema_insight_feedback.sql" if schema_feedback.exists(): @@ -702,6 +733,7 @@ def init_db(): if schema_format_lint_path.exists(): conn.executescript(schema_format_lint_path.read_text(encoding="utf-8")) _ensure_vnext_migrations() + _ensure_agenda_isolation_schema() _ensure_grounding_schema() schema_feedback = Path(__file__).parent / "schema_insight_feedback.sql" if schema_feedback.exists(): diff --git a/db/schema_agenda.sql b/db/schema_agenda.sql index 1e88d49..07a57b8 100644 --- a/db/schema_agenda.sql +++ b/db/schema_agenda.sql @@ -13,6 +13,15 @@ CREATE TABLE IF NOT EXISTS research_agendas ( required_output_json TEXT NOT NULL DEFAULT '{}', raw_config_json TEXT, is_active INTEGER NOT NULL DEFAULT 1, + -- Multi-agenda isolation + token budgets: + -- submitter: who asked for this direction (nickname / email) + -- token_budget: max LLM tokens this agenda may spend (NULL -> config default) + -- token_spent: running total, maintained by agents.agenda_budget + -- status: 
'active' | 'paused_budget' (budget exhausted, resume to continue) + submitter TEXT, + token_budget INTEGER, + token_spent INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'active', created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP ); @@ -97,3 +106,18 @@ CREATE INDEX IF NOT EXISTS idx_agenda_evidence_gates_selection ON agenda_evidence_gates(selection_id, created_at DESC); CREATE INDEX IF NOT EXISTS idx_agenda_evidence_gates_status ON agenda_evidence_gates(status); + +-- Per-agenda LLM token accounting. One row per metered call; token_spent on +-- research_agendas is the running aggregate (kept in sync by agents.agenda_budget). +CREATE TABLE IF NOT EXISTS agenda_token_ledger ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + agenda_id INTEGER NOT NULL, + operation TEXT NOT NULL, + tokens INTEGER NOT NULL DEFAULT 0, + cost_usd REAL, + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (agenda_id) REFERENCES research_agendas(id) +); + +CREATE INDEX IF NOT EXISTS idx_agenda_token_ledger_agenda + ON agenda_token_ledger(agenda_id, created_at DESC); diff --git a/db/schema_agenda_postgres.sql b/db/schema_agenda_postgres.sql index 3dc8a44..a87e164 100644 --- a/db/schema_agenda_postgres.sql +++ b/db/schema_agenda_postgres.sql @@ -13,6 +13,11 @@ CREATE TABLE IF NOT EXISTS research_agendas ( required_output_json TEXT NOT NULL DEFAULT '{}', raw_config_json TEXT, is_active INTEGER NOT NULL DEFAULT 1, + -- Multi-agenda isolation + token budgets (see schema_agenda.sql for docs) + submitter TEXT, + token_budget INTEGER, + token_spent INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'active', created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); @@ -88,3 +93,16 @@ CREATE INDEX IF NOT EXISTS idx_agenda_evidence_gates_selection ON agenda_evidence_gates(selection_id, created_at DESC); CREATE INDEX IF NOT EXISTS idx_agenda_evidence_gates_status ON 
agenda_evidence_gates(status); + +-- Per-agenda LLM token accounting (see schema_agenda.sql for docs) +CREATE TABLE IF NOT EXISTS agenda_token_ledger ( + id BIGSERIAL PRIMARY KEY, + agenda_id INTEGER NOT NULL, + operation TEXT NOT NULL, + tokens INTEGER NOT NULL DEFAULT 0, + cost_usd REAL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_agenda_token_ledger_agenda + ON agenda_token_ledger(agenda_id, created_at DESC); diff --git a/db/sql_dialect.py b/db/sql_dialect.py index cb91f70..f2831e8 100644 --- a/db/sql_dialect.py +++ b/db/sql_dialect.py @@ -5,6 +5,23 @@ import re +def escape_like(term: str) -> str: + """Escape LIKE wildcards in a user-supplied term bound as a parameter. + + `%`, `_` and the escape character `\\` itself are prefixed with `\\` so + the term matches literally instead of widening the pattern. The SQL clause + must declare the escape character explicitly -- write ``LIKE ? ESCAPE '\\'`` + -- which is valid on both SQLite and PostgreSQL (PostgreSQL defaults to + backslash; SQLite has no default escape character). + """ + return ( + str(term) + .replace("\\", "\\\\") + .replace("%", "\\%") + .replace("_", "\\_") + ) + + def to_postgres(sql: str) -> str: """Best-effort ? -> %s and common SQLite idioms. Review generated SQL for edge cases.""" out = sql.replace("?", "%s") diff --git a/deploy/agenda-inbox-watcher.example b/deploy/agenda-inbox-watcher.example new file mode 100644 index 0000000..69ec01c --- /dev/null +++ b/deploy/agenda-inbox-watcher.example @@ -0,0 +1,35 @@ +# Run the agenda inbox watcher on a schedule (pick ONE of the two options). +# Adjust WorkingDirectory / paths / user to your deployment. 
+ +# --- Option A: systemd service + timer ------------------------------------- +# /etc/systemd/system/deepgraph-agenda-inbox.service +# +# [Unit] +# Description=DeepGraph agenda inbox watcher (single scan) +# After=network.target +# +# [Service] +# Type=oneshot +# User=deepgraph +# WorkingDirectory=/opt/deepgraph +# ExecStart=/usr/bin/python3 -m scripts.agenda_inbox_watcher --once +# +# /etc/systemd/system/deepgraph-agenda-inbox.timer +# +# [Unit] +# Description=Scan the agenda inbox every minute +# +# [Timer] +# OnBootSec=2min +# OnUnitActiveSec=1min +# +# [Install] +# WantedBy=timers.target +# +# Enable with: +# systemctl enable --now deepgraph-agenda-inbox.timer + +# --- Option B: cron --------------------------------------------------------- +# crontab -e (as the deepgraph user): +# +# * * * * * cd /opt/deepgraph && /usr/bin/python3 -m scripts.agenda_inbox_watcher --once >> /var/log/deepgraph/agenda-inbox.log 2>&1 diff --git a/orchestrator/discovery_scheduler.py b/orchestrator/discovery_scheduler.py index 600a4b7..6979e6a 100644 --- a/orchestrator/discovery_scheduler.py +++ b/orchestrator/discovery_scheduler.py @@ -48,11 +48,23 @@ def _init_schema_v2(): db.init_db() -def harvest_signals() -> dict: - """Run signal harvesting (SQL only, no LLM cost).""" +def harvest_signals(agenda=None) -> dict: + """Run signal harvesting (SQL only, no LLM cost). + + The optional agenda is accepted for call-site symmetry with the Tier 1/2 + entries. Signal tables are shared whole-corpus aggregates (no per-agenda + column), so harvesting itself stays global; agenda scoping is applied + where the signals are consumed (get_tier1_signals / get_tier2_signals). 
+ """ _init_schema_v2() from agents.signal_harvester import harvest_all - log_event("discovery", {"step": "signal_harvest_start"}) + log_event( + "discovery", + { + "step": "signal_harvest_start", + **({"agenda_id": agenda.agenda_id} if agenda is not None else {}), + }, + ) stats = harvest_all() log_event("discovery", {"step": "signal_harvest_done", **stats}) return stats @@ -62,9 +74,17 @@ def run_tier1_discovery( max_candidates: int | None = None, *, bulk: bool = False, + agenda=None, ) -> list[dict]: - """Run Tier 1 (Paradigm) discovery. Returns stored insight IDs.""" + """Run Tier 1 (Paradigm) discovery. Returns stored insight IDs. + + With an agenda (contracts.agenda.ResearchAgenda) the scan is restricted to + the agenda's taxonomy subgraph, LLM spend is metered against its token + budget, and stored insights carry agenda_id. Without one, behavior is + unchanged. + """ _init_schema_v2() + from agents.agenda_budget import agenda_scope from agents.paradigm_agent import discover_paradigm_insights, store_deep_insight if max_candidates is None: @@ -82,16 +102,26 @@ def run_tier1_discovery( "bulk": bulk, "signal_overlaps": top_ov, "signal_patterns": top_pat, + **({"agenda_id": agenda.agenda_id} if agenda is not None else {}), }, ) print("[DISCOVERY] Starting Tier 1 (Paradigm) discovery...", flush=True) try: - insights = discover_paradigm_insights( - max_candidates=max_candidates, - tier1_top_overlaps=top_ov, - tier1_top_patterns=top_pat, - ) + if agenda is not None: + with agenda_scope(agenda.agenda_id, "tier1_discovery"): + insights = discover_paradigm_insights( + max_candidates=max_candidates, + tier1_top_overlaps=top_ov, + tier1_top_patterns=top_pat, + agenda=agenda, + ) + else: + insights = discover_paradigm_insights( + max_candidates=max_candidates, + tier1_top_overlaps=top_ov, + tier1_top_patterns=top_pat, + ) stored = [] for ins in insights: insight_id = store_deep_insight(ins) @@ -131,12 +161,16 @@ def run_tier2_discovery( max_papers: int | None = None, *, 
bulk: bool = False, + agenda=None, ) -> list[dict]: """Run Tier 2 (Paper Ideas) discovery. Returns stored insight IDs. In bulk mode, expands every sharpened problem (max_papers follows max_problems). + With an agenda the scan is restricted to its taxonomy subgraph, LLM spend + is metered against its token budget, and stored ideas carry agenda_id. """ _init_schema_v2() + from agents.agenda_budget import agenda_scope from agents.paradigm_agent import store_deep_insight from agents.paper_idea_agent import discover_paper_ideas @@ -153,17 +187,32 @@ def run_tier2_discovery( log_event( "discovery", - {"step": "tier2_start", "max_problems": max_problems, "bulk": bulk}, + { + "step": "tier2_start", + "max_problems": max_problems, + "bulk": bulk, + **({"agenda_id": agenda.agenda_id} if agenda is not None else {}), + }, ) print("[DISCOVERY] Starting Tier 2 (Paper Ideas) discovery...", flush=True) try: - insights = discover_paper_ideas( - max_problems=max_problems, - max_papers=mpapers, - tier2_plateau_limit=plateaus, - tier2_limitation_nodes=lim_nodes, - ) + if agenda is not None: + with agenda_scope(agenda.agenda_id, "tier2_discovery"): + insights = discover_paper_ideas( + max_problems=max_problems, + max_papers=mpapers, + tier2_plateau_limit=plateaus, + tier2_limitation_nodes=lim_nodes, + agenda=agenda, + ) + else: + insights = discover_paper_ideas( + max_problems=max_problems, + max_papers=mpapers, + tier2_plateau_limit=plateaus, + tier2_limitation_nodes=lim_nodes, + ) stored = [] for ins in insights: insight_id = store_deep_insight(ins) @@ -209,21 +258,27 @@ def run_full_discovery( tier2_papers: int | None = None, *, bulk: bool = False, + agenda=None, ) -> dict: """Run the full discovery pipeline: harvest → Tier 1 → Tier 2.""" results = {"started_at": datetime.utcnow().isoformat(), "bulk": bulk} + if agenda is not None: + results["agenda_id"] = agenda.agenda_id # Step 1: Harvest signals - results["signals"] = harvest_signals() + results["signals"] = 
harvest_signals(agenda=agenda) # Step 2: Tier 1 - results["tier1"] = run_tier1_discovery(max_candidates=tier1_candidates, bulk=bulk) + results["tier1"] = run_tier1_discovery( + max_candidates=tier1_candidates, bulk=bulk, agenda=agenda + ) # Step 3: Tier 2 results["tier2"] = run_tier2_discovery( max_problems=tier2_problems, max_papers=tier2_papers, bulk=bulk, + agenda=agenda, ) results["completed_at"] = datetime.utcnow().isoformat() diff --git a/research_agendas/inbox/README.md b/research_agendas/inbox/README.md new file mode 100644 index 0000000..f3df520 --- /dev/null +++ b/research_agendas/inbox/README.md @@ -0,0 +1,15 @@ +# Direction inbox + +Drop a direction file (`*.yaml`, schema below) into this directory, then run `python3 -m scripts.agenda_inbox_watcher --once` (or let the systemd timer / cron job do it — see `deploy/agenda-inbox-watcher.example`). +Processed files move to `processed/` with a `.echo.json` confirmation next to them; rejected files move to `failed/` with a `.error.txt` explaining why. + +```yaml +direction: "用扩散模型做小样本医学影像分割,关注跨中心泛化" # required, natural language +keywords: [medical imaging, diffusion, few-shot] # optional +constraints: # optional, free text + compute: "单卡以内" + data: "仅公开数据集" +goal: experiment_plan # idea_only | experiment_plan | signal | verified_evidence +contact: "nickname or email" # required +token_budget: 500000 # optional LLM token cap; default 0 = no cap (usage is still recorded) +``` diff --git a/scripts/agenda_inbox_watcher.py b/scripts/agenda_inbox_watcher.py new file mode 100644 index 0000000..be32ec2 --- /dev/null +++ b/scripts/agenda_inbox_watcher.py @@ -0,0 +1,137 @@ +"""Watch research_agendas/inbox/ for user direction submissions (.yaml). 
+ +For each new YAML file: +- reject files larger than MAX_SUBMISSION_BYTES without reading their content, +- parse it with agents.direction_intake (deterministic, no LLM), +- persist the mapped agenda via agents.agenda_loader (direct DB write — works + whether or not the web app is running), +- write .echo.json (what the system understood, for the submitter), +- move the file to inbox/processed/ on success or inbox/failed/ (plus + .error.txt) on failure. + +Usage: + python3 -m scripts.agenda_inbox_watcher --once # single scan + python3 -m scripts.agenda_inbox_watcher # poll every 60s + python3 -m scripts.agenda_inbox_watcher --interval 30 + +See deploy/agenda-inbox-watcher.example for a systemd timer / cron setup. +""" + +from __future__ import annotations + +import argparse +import json +import shutil +import sys +import time +from datetime import datetime, timezone +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parent.parent +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +DEFAULT_INBOX = REPO_ROOT / "research_agendas" / "inbox" + +# Direction files are short YAML documents; anything bigger is malformed or +# abusive. Oversized files are quarantined without their content being read. +MAX_SUBMISSION_BYTES = 256 * 1024 + + +def _unique_target(directory: Path, name: str) -> Path: + """Destination path inside directory; suffix a timestamp on collision.""" + target = directory / name + if not target.exists(): + return target + stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") + p = Path(name) + return directory / f"{p.stem}.{stamp}{p.suffix}" + + +def process_file(path: Path, processed_dir: Path, failed_dir: Path) -> dict: + """Process one submission file. 
Returns a small result dict for logging.""" + from agents import agenda_loader + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + try: + size = path.stat().st_size + if size > MAX_SUBMISSION_BYTES: + raise DirectionParseError( + f"file too large: {size} bytes (limit {MAX_SUBMISSION_BYTES}); " + "content not read" + ) + text = path.read_text(encoding="utf-8") + agenda, echo = parse_direction_yaml(text) + agenda_id = agenda_loader.save_agenda(agenda) + echo["agenda_id"] = agenda_id + target = _unique_target(processed_dir, path.name) + echo_path = target.parent / (target.name + ".echo.json") + echo_path.write_text( + json.dumps(echo, ensure_ascii=False, indent=2) + "\n", encoding="utf-8" + ) + shutil.move(str(path), str(target)) + print(f"[INBOX] OK {path.name} -> agenda #{agenda_id} ({agenda.name})", flush=True) + return {"file": path.name, "status": "ok", "agenda_id": agenda_id} + except (DirectionParseError, OSError, UnicodeDecodeError) as exc: + reason = str(exc) + except Exception as exc: # noqa: BLE001 - keep the watcher alive + reason = f"{type(exc).__name__}: {exc}" + + target = _unique_target(failed_dir, path.name) + try: + shutil.move(str(path), str(target)) + (target.parent / (target.name + ".error.txt")).write_text( + reason + "\n", encoding="utf-8" + ) + except OSError as move_err: + print(f"[INBOX] Could not quarantine {path.name}: {move_err}", flush=True) + print(f"[INBOX] FAILED {path.name}: {reason}", flush=True) + return {"file": path.name, "status": "failed", "error": reason} + + +def scan_inbox(inbox: Path) -> list[dict]: + """One scan pass over the inbox. 
Creates the directory layout if missing.""" + processed_dir = inbox / "processed" + failed_dir = inbox / "failed" + for d in (inbox, processed_dir, failed_dir): + d.mkdir(parents=True, exist_ok=True) + + results = [] + for path in sorted(inbox.iterdir()): + if not path.is_file() or path.suffix.lower() not in (".yaml", ".yml"): + continue + results.append(process_file(path, processed_dir, failed_dir)) + return results + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description=__doc__.split("\n", 1)[0]) + parser.add_argument("--once", action="store_true", help="scan once and exit") + parser.add_argument( + "--interval", type=int, default=60, help="poll interval in seconds (default 60)" + ) + parser.add_argument( + "--inbox", type=Path, default=DEFAULT_INBOX, help="inbox directory to watch" + ) + args = parser.parse_args(argv) + + from db import database as db + + db.init_db() + + if args.once: + results = scan_inbox(args.inbox) + print(f"[INBOX] Scan done: {len(results)} file(s) handled", flush=True) + return 0 + + print( + f"[INBOX] Watching {args.inbox} every {args.interval}s (Ctrl-C to stop)", + flush=True, + ) + while True: + scan_inbox(args.inbox) + time.sleep(max(1, args.interval)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/test_agenda_budget.py b/tests/test_agenda_budget.py new file mode 100644 index 0000000..4a6ddad --- /dev/null +++ b/tests/test_agenda_budget.py @@ -0,0 +1,346 @@ +"""Tests for per-agenda token budgets (agents.agenda_budget + llm_client hook). + +Covers ledger accounting, pause on budget exhaustion, the check-before-call +guard (no provider request once paused), resume (explicit endpoint or raised +budget), and the /api/token_usage summary endpoint. 
+""" + +from __future__ import annotations + +import os +import tempfile +import unittest +from pathlib import Path + +os.environ.setdefault("DEEPGRAPH_DATABASE_URL", "") + + +SAMPLE_AGENDA = { + "version": "v1", + "name": "budget_test_agenda", + "focus": ["long context"], + "prefer": {"keywords": ["linear attention"]}, + "submitter": "alice", + "token_budget": 1500, +} + + +class BudgetTestBase(unittest.TestCase): + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + os.environ["DEEPGRAPH_DB_PATH"] = str(Path(self._tmpdir.name) / "test.db") + from db import database as db + + self._original_db_path = db.DB_PATH + for attr in ("sqlite_conn", "pg_conn", "conn"): + if hasattr(db._local, attr): + try: + getattr(db._local, attr).close() + except Exception: + pass + delattr(db._local, attr) + db.DB_PATH = Path(os.environ["DEEPGRAPH_DB_PATH"]) + db.init_db() + self.db = db + + def tearDown(self): + from db import database as db + + for attr in ("sqlite_conn", "pg_conn", "conn"): + if hasattr(db._local, attr): + try: + getattr(db._local, attr).close() + except Exception: + pass + delattr(db._local, attr) + db.DB_PATH = self._original_db_path + self._tmpdir.cleanup() + os.environ.pop("DEEPGRAPH_DB_PATH", None) + + def _save_agenda(self, **overrides): + from agents.agenda_loader import parse_agenda, save_agenda + + agenda = parse_agenda({**SAMPLE_AGENDA, **overrides}) + save_agenda(agenda) + return agenda + + +class BudgetAccountingTests(BudgetTestBase): + def test_default_budget_comes_from_config(self): + from agents import agenda_budget + from config import AGENDA_TOKEN_BUDGET_DEFAULT + + agenda = self._save_agenda(name="no_budget", token_budget=None) + state = agenda_budget.get_budget_state(agenda.agenda_id) + self.assertEqual(state["token_budget"], AGENDA_TOKEN_BUDGET_DEFAULT) + self.assertEqual(state["token_spent"], 0) + self.assertEqual(state["status"], "active") + + def test_record_usage_accumulates_and_writes_ledger(self): + from agents import 
agenda_budget + + agenda = self._save_agenda() + agenda_budget.record_usage(agenda.agenda_id, "tier1_discovery", 400) + agenda_budget.record_usage(agenda.agenda_id, "tier2_discovery", 300) + + state = agenda_budget.get_budget_state(agenda.agenda_id) + self.assertEqual(state["token_spent"], 700) + self.assertEqual(state["status"], "active") + + rows = self.db.fetchall( + "SELECT operation, tokens FROM agenda_token_ledger WHERE agenda_id=? ORDER BY id", + (agenda.agenda_id,), + ) + self.assertEqual( + [(r["operation"], r["tokens"]) for r in rows], + [("tier1_discovery", 400), ("tier2_discovery", 300)], + ) + + def test_exceeding_budget_pauses_agenda_and_blocks_checks(self): + from agents import agenda_budget + + agenda = self._save_agenda() # budget 1500 + agenda_budget.record_usage(agenda.agenda_id, "op", 1600) + + state = agenda_budget.get_budget_state(agenda.agenda_id) + self.assertEqual(state["status"], "paused_budget") + with self.assertRaises(agenda_budget.AgendaBudgetExceededError): + agenda_budget.check_budget(agenda.agenda_id) + + def test_resume_reactivates(self): + from agents import agenda_budget + + agenda = self._save_agenda() + agenda_budget.record_usage(agenda.agenda_id, "op", 1600) + # Resume with a raised budget: checks pass again + state = agenda_budget.resume_agenda(agenda.agenda_id, token_budget=5000) + self.assertEqual(state["status"], "active") + agenda_budget.check_budget(agenda.agenda_id) # must not raise + + def test_raising_budget_alone_unblocks(self): + from agents import agenda_budget + + agenda = self._save_agenda() + agenda_budget.record_usage(agenda.agenda_id, "op", 1600) + self.db.execute( + "UPDATE research_agendas SET token_budget=10000 WHERE id=?", + (agenda.agenda_id,), + ) + self.db.commit() + agenda_budget.check_budget(agenda.agenda_id) # must not raise + state = agenda_budget.get_budget_state(agenda.agenda_id) + self.assertEqual(state["status"], "active") + + def test_zero_budget_disables_cap(self): + from agents import 
agenda_budget + + agenda = self._save_agenda(name="uncapped", token_budget=0) + agenda_budget.record_usage(agenda.agenda_id, "op", 10_000_000) + agenda_budget.check_budget(agenda.agenda_id) # must not raise + + def test_zero_budget_still_accounts_usage(self): + """budget 0 = no cap, never paused, but ledger + token_spent keep counting.""" + from agents import agenda_budget + + agenda = self._save_agenda(name="uncapped_accounting", token_budget=0) + agenda_budget.record_usage(agenda.agenda_id, "tier1_discovery", 4_000_000) + agenda_budget.record_usage(agenda.agenda_id, "tier2_discovery", 6_000_000) + agenda_budget.check_budget(agenda.agenda_id) # must not raise + + state = agenda_budget.get_budget_state(agenda.agenda_id) + self.assertEqual(state["token_spent"], 10_000_000) + self.assertEqual(state["status"], "active") # never paused_budget + self.assertFalse(state["exhausted"]) + rows = self.db.fetchall( + "SELECT operation, tokens FROM agenda_token_ledger WHERE agenda_id=? ORDER BY id", + (agenda.agenda_id,), + ) + self.assertEqual( + [(r["operation"], r["tokens"]) for r in rows], + [("tier1_discovery", 4_000_000), ("tier2_discovery", 6_000_000)], + ) + + def test_null_budget_with_zero_default_is_uncapped(self): + """token_budget NULL + default 0: enforcement off, accounting on.""" + import unittest as _ut + + from agents import agenda_budget + from config import AGENDA_TOKEN_BUDGET_DEFAULT + + if AGENDA_TOKEN_BUDGET_DEFAULT > 0: + raise _ut.SkipTest("env overrides the shipped default of 0") + + agenda = self._save_agenda(name="null_budget", token_budget=None) + agenda_budget.record_usage(agenda.agenda_id, "op", 3_000_000) + agenda_budget.check_budget(agenda.agenda_id) # must not raise + state = agenda_budget.get_budget_state(agenda.agenda_id) + self.assertEqual(state["token_spent"], 3_000_000) + self.assertEqual(state["status"], "active") + + +class LlmClientHookTests(BudgetTestBase): + """call_llm must meter scoped calls and stop cleanly once the budget is 
spent.""" + + def setUp(self): + super().setUp() + import agents.llm_client as llm + + self.llm = llm + self._saved = ( + llm._providers, + llm._provider_stats, + llm._rate_limiters, + llm._provider_cooldown, + llm._call_provider, + ) + llm._providers = [ + { + "name": "fake", + "base_url": "http://localhost", + "api_key": "k", + "model": "m", + "protocol": "chat_completions", + "rpm": 0, + } + ] + llm._provider_stats = { + "fake": { + "calls": 0, "tokens": 0, "errors": 0, "total_latency": 0, + "in_flight": 0, "cached_tokens": 0, "input_tokens": 0, + } + } + llm._rate_limiters = {} + llm._provider_cooldown = {} + self.provider_calls = [] + + def _fake_call(provider, system_prompt, user_prompt, max_tokens): + self.provider_calls.append(provider["name"]) + return ("response text long enough to count", 1000, 0, 500) + + llm._call_provider = _fake_call + + def tearDown(self): + llm = self.llm + ( + llm._providers, + llm._provider_stats, + llm._rate_limiters, + llm._provider_cooldown, + llm._call_provider, + ) = self._saved + super().tearDown() + + def test_scoped_calls_are_metered_and_capped(self): + from agents import agenda_budget + + agenda = self._save_agenda() # budget 1500, fake call = 1000 tokens + + with agenda_budget.agenda_scope(agenda.agenda_id, "unit_test"): + self.llm.call_llm("sys", "user") # spend 1000 (< 1500: allowed) + self.llm.call_llm("sys", "user") # spend 2000 (pauses the agenda) + with self.assertRaises(agenda_budget.AgendaBudgetExceededError): + self.llm.call_llm("sys", "user") # blocked BEFORE the provider + + # Third call never reached the provider: clean stop, no token spent + self.assertEqual(len(self.provider_calls), 2) + state = agenda_budget.get_budget_state(agenda.agenda_id) + self.assertEqual(state["token_spent"], 2000) + self.assertEqual(state["status"], "paused_budget") + ledger = self.db.fetchall( + "SELECT operation, tokens FROM agenda_token_ledger WHERE agenda_id=?", + (agenda.agenda_id,), + ) + self.assertEqual(len(ledger), 2) + 
self.assertEqual(ledger[0]["operation"], "unit_test") + + def test_resume_allows_further_calls(self): + from agents import agenda_budget + + agenda = self._save_agenda() + with agenda_budget.agenda_scope(agenda.agenda_id, "unit_test"): + self.llm.call_llm("sys", "user") + self.llm.call_llm("sys", "user") + with self.assertRaises(agenda_budget.AgendaBudgetExceededError): + self.llm.call_llm("sys", "user") + + agenda_budget.resume_agenda(agenda.agenda_id, token_budget=10_000) + self.llm.call_llm("sys", "user") # works again + + self.assertEqual(len(self.provider_calls), 3) + state = agenda_budget.get_budget_state(agenda.agenda_id) + self.assertEqual(state["token_spent"], 3000) + self.assertEqual(state["status"], "active") + + def test_unscoped_calls_not_metered(self): + self.llm.call_llm("sys", "user") + rows = self.db.fetchall("SELECT * FROM agenda_token_ledger") + self.assertEqual(rows, []) + + +class BudgetRoutesTests(BudgetTestBase): + def setUp(self): + super().setUp() + from web import app as app_module + + self.client = app_module.app.test_client() + + def test_resume_endpoint_and_token_usage(self): + from agents import agenda_budget + + agenda = self._save_agenda() + agenda_budget.record_usage(agenda.agenda_id, "tier1_discovery", 1600) + + r = self.client.get("/api/token_usage") + self.assertEqual(r.status_code, 200) + body = r.get_json() + self.assertEqual(body["totals"]["tokens"], 1600) + row = next(a for a in body["agendas"] if a["agenda_id"] == agenda.agenda_id) + self.assertEqual(row["status"], "paused_budget") + self.assertEqual(row["token_spent"], 1600) + self.assertEqual(row["ledger_entries"], 1) + + r = self.client.post( + f"/api/research_agenda/{agenda.agenda_id}/resume", + json={"token_budget": 5000}, + ) + self.assertEqual(r.status_code, 200) + self.assertEqual(r.get_json()["budget"]["status"], "active") + self.assertEqual(r.get_json()["budget"]["token_budget"], 5000) + + def test_resume_endpoint_validates_input(self): + agenda = 
self._save_agenda() + r = self.client.post( + f"/api/research_agenda/{agenda.agenda_id}/resume", + json={"token_budget": "lots"}, + ) + self.assertEqual(r.status_code, 400) + r = self.client.post("/api/research_agenda/999999/resume", json={}) + self.assertEqual(r.status_code, 404) + + def test_agenda_insights_endpoint_isolated(self): + agenda_a = self._save_agenda(name="iso_a") + agenda_b = self._save_agenda(name="iso_b") + self.db.execute( + "INSERT INTO deep_insights (id, tier, status, title, agenda_id) " + "VALUES (1, 2, 'candidate', 'insight for A', ?)", + (agenda_a.agenda_id,), + ) + self.db.execute( + "INSERT INTO deep_insights (id, tier, status, title, agenda_id) " + "VALUES (2, 2, 'candidate', 'insight for B', ?)", + (agenda_b.agenda_id,), + ) + self.db.commit() + + r = self.client.get(f"/api/research_agenda/{agenda_a.agenda_id}/insights") + self.assertEqual(r.status_code, 200) + titles = [i["title"] for i in r.get_json()["insights"]] + self.assertEqual(titles, ["insight for A"]) + + r = self.client.get(f"/api/research_agenda/{agenda_b.agenda_id}/selections") + self.assertEqual(r.status_code, 200) + self.assertEqual(r.get_json()["selections"], []) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_agenda_contract.py b/tests/test_agenda_contract.py index e8647f5..6f64289 100644 --- a/tests/test_agenda_contract.py +++ b/tests/test_agenda_contract.py @@ -216,15 +216,25 @@ def test_save_and_get_agenda(self): rows = list_agendas() self.assertEqual(len(rows), 1) - def test_active_flag_is_exclusive(self): - from agents.agenda_loader import get_active_agenda, parse_agenda, save_agenda + def test_multiple_agendas_stay_active(self): + from agents.agenda_loader import ( + get_active_agenda, + get_agenda, + list_agendas, + parse_agenda, + save_agenda, + ) a1 = save_agenda(parse_agenda(dict(SAMPLE_AGENDA_DICT, name="a1"))) a2 = save_agenda(parse_agenda(dict(SAMPLE_AGENDA_DICT, name="a2"))) - # second insert with is_active=True should become the 
sole active row - active = get_active_agenda() - self.assertEqual(active.agenda_id, a2) - self.assertNotEqual(active.agenda_id, a1) + # Both agendas run concurrently; saving the second one must not + # deactivate the first (isolation is per agenda_id, not per flag). + self.assertTrue(get_agenda(a1).is_active) + self.assertTrue(get_agenda(a2).is_active) + active_ids = {a.agenda_id for a in list_agendas(only_active=True)} + self.assertEqual(active_ids, {a1, a2}) + # The single-agenda convenience accessor returns the newest active row. + self.assertEqual(get_active_agenda().agenda_id, a2) def test_schema_tables_present(self): from db import database as db diff --git a/tests/test_agenda_scope.py b/tests/test_agenda_scope.py new file mode 100644 index 0000000..d663ad6 --- /dev/null +++ b/tests/test_agenda_scope.py @@ -0,0 +1,245 @@ +"""Acceptance tests for agenda-scoped candidate circling (multi-agenda isolation). + +Two agendas in different domains (medical imaging vs cryptography) must not +leak into each other's candidate pool: scoped queries return only records from +the agenda's own keyword domain or tagged with its agenda_id. 
+""" + +from __future__ import annotations + +import os +import tempfile +import unittest +from pathlib import Path + +os.environ.setdefault("DEEPGRAPH_DATABASE_URL", "") + + +MEDICAL_AGENDA = { + "version": "v1", + "name": "medical_imaging_v1", + "description": "few-shot medical image segmentation", + "focus": ["medical imaging", "segmentation"], + "prefer": {"keywords": ["diffusion model"]}, +} + +CRYPTO_AGENDA = { + "version": "v1", + "name": "cryptography_v1", + "description": "post-quantum cryptography", + "focus": ["cryptography", "lattice"], + "prefer": {"keywords": ["post-quantum"]}, +} + +MEDICAL_TERMS = ("medical", "segmentation", "diffusion", "imaging") +CRYPTO_TERMS = ("cryptography", "lattice", "post-quantum", "encryption") + + +class AgendaScopeTestBase(unittest.TestCase): + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + os.environ["DEEPGRAPH_DB_PATH"] = str(Path(self._tmpdir.name) / "test.db") + from db import database as db + + self._original_db_path = db.DB_PATH + for attr in ("sqlite_conn", "pg_conn", "conn"): + if hasattr(db._local, attr): + try: + getattr(db._local, attr).close() + except Exception: + pass + delattr(db._local, attr) + db.DB_PATH = Path(os.environ["DEEPGRAPH_DB_PATH"]) + db.init_db() + self.db = db + + def tearDown(self): + from db import database as db + + for attr in ("sqlite_conn", "pg_conn", "conn"): + if hasattr(db._local, attr): + try: + getattr(db._local, attr).close() + except Exception: + pass + delattr(db._local, attr) + db.DB_PATH = self._original_db_path + self._tmpdir.cleanup() + os.environ.pop("DEEPGRAPH_DB_PATH", None) + + def _seed_insight(self, insight_id, title, problem, agenda_id=None): + self.db.execute( + """ + INSERT INTO deep_insights + (id, tier, status, title, problem_statement, adversarial_score, + novelty_status, resource_class, experimentability, agenda_id) + VALUES (?, 2, 'candidate', ?, ?, 7.0, 'novel', 'cpu', 'easy', ?) 
+ """, + (insight_id, title, problem, agenda_id), + ) + self.db.commit() + + +class InsightPoolIsolationTests(AgendaScopeTestBase): + def setUp(self): + super().setUp() + from agents.agenda_loader import parse_agenda, save_agenda + + self.medical = parse_agenda(MEDICAL_AGENDA) + save_agenda(self.medical) + self.crypto = parse_agenda(CRYPTO_AGENDA) + save_agenda(self.crypto) + + # Untagged corpus: two domains + self._seed_insight(1, "Diffusion model for medical imaging segmentation", + "Few-shot generalization across hospitals") + self._seed_insight(2, "Cross-center medical imaging benchmark", + "Segmentation models break across centers") + self._seed_insight(3, "Lattice-based post-quantum cryptography scheme", + "Encryption against quantum adversaries") + self._seed_insight(4, "Side-channel analysis of lattice cryptography", + "Lattice implementations leak timing information") + # Tagged outputs: one per agenda + self._seed_insight(10, "Agenda-produced medical insight", + "Generated under the medical agenda", + agenda_id=self.medical.agenda_id) + self._seed_insight(11, "Agenda-produced crypto insight", + "Generated under the crypto agenda", + agenda_id=self.crypto.agenda_id) + + def _titles(self, rows): + return [str(r["title"]).lower() for r in rows] + + def test_pools_do_not_cross_domains(self): + from agents.agenda_selector import _fetch_insight_pool + + medical_pool = self._titles(_fetch_insight_pool(agenda=self.medical)) + crypto_pool = self._titles(_fetch_insight_pool(agenda=self.crypto)) + + self.assertTrue(medical_pool) + self.assertTrue(crypto_pool) + # No crypto-domain term in the medical pool, and vice versa + for title in medical_pool: + for term in CRYPTO_TERMS: + self.assertNotIn(term, title) + for title in crypto_pool: + for term in MEDICAL_TERMS: + self.assertNotIn(term, title) + + def test_tagged_insights_stay_with_their_agenda(self): + from agents.agenda_selector import _fetch_insight_pool + + medical_ids = {r["id"] for r in 
_fetch_insight_pool(agenda=self.medical)} + crypto_ids = {r["id"] for r in _fetch_insight_pool(agenda=self.crypto)} + + self.assertIn(10, medical_ids) + self.assertNotIn(11, medical_ids) + self.assertIn(11, crypto_ids) + self.assertNotIn(10, crypto_ids) + + def test_unscoped_pool_unchanged(self): + from agents.agenda_selector import _fetch_insight_pool + + ids = {r["id"] for r in _fetch_insight_pool()} + self.assertEqual(ids, {1, 2, 3, 4, 10, 11}) + + def test_scoped_selection_picks_within_domain(self): + from agents.agenda_selector import select_and_persist + + sel_med = select_and_persist(self.medical, scope_to_agenda=True) + sel_cry = select_and_persist(self.crypto, scope_to_agenda=True) + self.assertIn(sel_med.selected_insight_id, {1, 2, 10}) + self.assertIn(sel_cry.selected_insight_id, {3, 4, 11}) + + +class TaxonomyCircleTests(AgendaScopeTestBase): + def _seed_node(self, node_id, name, description=""): + self.db.execute( + "INSERT INTO taxonomy_nodes (id, name, description, parent_id, depth) " + "VALUES (?, ?, ?, NULL, 1)", + (node_id, name, description), + ) + self.db.commit() + + def test_node_circle_matches_keywords_only(self): + from agents.signal_harvester import agenda_taxonomy_node_ids + + self._seed_node("ml.medimg", "Medical Imaging", "segmentation and diagnosis") + self._seed_node("ml.crypto", "Cryptography", "lattice-based encryption") + self._seed_node("ml.nlp", "Natural Language Processing", "") + + medical_nodes = agenda_taxonomy_node_ids(["medical imaging", "segmentation"]) + crypto_nodes = agenda_taxonomy_node_ids(["cryptography", "lattice"]) + + self.assertEqual(medical_nodes, ["ml.medimg"]) + self.assertEqual(crypto_nodes, ["ml.crypto"]) + self.assertEqual(agenda_taxonomy_node_ids([]), []) + + def test_wildcard_keyword_does_not_match_everything(self): + from agents.signal_harvester import agenda_taxonomy_node_ids + + self._seed_node("ml.medimg", "Medical Imaging", "segmentation") + self._seed_node("ml.crypto", "Cryptography", "lattice") + 
self._seed_node("ml.quant", "Quantization", "keeps 99% accuracy at 4-bit") + + # LIKE wildcards in user keywords must match literally, not widen + # scope: '%' only hits the node whose text contains a literal '%'. + self.assertEqual(agenda_taxonomy_node_ids(["%"]), ["ml.quant"]) + self.assertEqual(agenda_taxonomy_node_ids(["_"]), []) + self.assertEqual(agenda_taxonomy_node_ids(["99% accuracy"]), ["ml.quant"]) + + +class LikeWildcardEscapeTests(AgendaScopeTestBase): + """User-supplied scope keywords go into SQL LIKE patterns; wildcard + characters must not widen the candidate pool beyond the literal term.""" + + def test_escape_like_escapes_wildcards(self): + from db.sql_dialect import escape_like + + self.assertEqual(escape_like(r"50%_\x"), r"50\%\_\\x") + self.assertEqual(escape_like("plain term"), "plain term") + + def test_percent_keyword_does_not_widen_insight_pool(self): + from agents.agenda_loader import parse_agenda + from agents.agenda_selector import _fetch_insight_pool + + self._seed_insight(1, "Diffusion model for medical imaging", + "Few-shot generalization") + self._seed_insight(2, "Lattice-based cryptography scheme", + "Post-quantum encryption") + + wild = parse_agenda({ + "version": "v1", + "name": "wildcard_probe", + "focus": ["%"], + }) + # Pre-escape this pattern ('%%%') matched the whole table + self.assertEqual(_fetch_insight_pool(agenda=wild), []) + + underscore = parse_agenda({ + "version": "v1", + "name": "underscore_probe", + "focus": ["________"], + }) + self.assertEqual(_fetch_insight_pool(agenda=underscore), []) + + def test_literal_percent_keyword_matches_only_literal_text(self): + from agents.agenda_loader import parse_agenda + from agents.agenda_selector import _fetch_insight_pool + + self._seed_insight(1, "Quantization keeps 99% accuracy at 4-bit", + "Compression without quality loss") + self._seed_insight(2, "Reaching 99 points of accuracy with ensembling", + "Ensembles for tabular data") + + agenda = parse_agenda({ + "version": 
"v1", + "name": "literal_percent", + "focus": ["99% accuracy"], + }) + ids = {r["id"] for r in _fetch_insight_pool(agenda=agenda)} + self.assertEqual(ids, {1}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_direction_intake.py b/tests/test_direction_intake.py new file mode 100644 index 0000000..82d3398 --- /dev/null +++ b/tests/test_direction_intake.py @@ -0,0 +1,259 @@ +"""Tests for agents.direction_intake + scripts.agenda_inbox_watcher. + +Covers the deterministic YAML -> ResearchAgenda mapping (good input, defaults, +compute-constraint mapping, echo content) and bad-input handling, plus the +inbox watcher's processed/failed file flow against a temp database. +""" + +from __future__ import annotations + +import json +import os +import sys +import tempfile +import unittest +from pathlib import Path + +os.environ.setdefault("DEEPGRAPH_DATABASE_URL", "") + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts")) + + +SAMPLE_DIRECTION_YAML = """ +direction: "用扩散模型做小样本医学影像分割,关注跨中心泛化" +keywords: [medical imaging, diffusion, few-shot] +constraints: + compute: "单卡以内" + data: "仅公开数据集" +goal: experiment_plan +contact: "alice@example.com" +""" + + +class DirectionParsingTests(unittest.TestCase): + def test_full_sample_maps_all_fields(self): + from agents.direction_intake import parse_direction_yaml + + agenda, echo = parse_direction_yaml(SAMPLE_DIRECTION_YAML) + self.assertEqual(agenda.description, "用扩散模型做小样本医学影像分割,关注跨中心泛化") + self.assertEqual(agenda.focus, ["medical imaging", "diffusion", "few-shot"]) + self.assertEqual(agenda.required_output, {"goal": "experiment_plan"}) + self.assertEqual(agenda.submitter, "alice@example.com") + # "单卡以内" -> conservative resource classes + self.assertEqual(agenda.prefer.get("resource_class"), ["cpu", "gpu_small"]) + # Original submission preserved verbatim + self.assertEqual(agenda.raw_config["constraints"]["compute"], "单卡以内") + self.assertEqual(agenda.raw_config["constraints"]["data"], 
"仅公开数据集") + self.assertTrue(agenda.name.startswith("direction-medical-imaging")) + # Echo summarises what was understood + self.assertEqual(echo["type"], "direction_intake_echo") + self.assertEqual(echo["focus"], agenda.focus) + self.assertEqual(echo["goal"], "experiment_plan") + self.assertIn("单卡以内", echo["summary"]) + self.assertIn("alice@example.com", echo["summary"]) + + def test_name_is_deterministic(self): + from agents.direction_intake import parse_direction_yaml + + a1, _ = parse_direction_yaml(SAMPLE_DIRECTION_YAML) + a2, _ = parse_direction_yaml(SAMPLE_DIRECTION_YAML) + self.assertEqual(a1.name, a2.name) + + def test_goal_defaults_to_experiment_plan(self): + from agents.direction_intake import parse_direction_yaml + + agenda, _ = parse_direction_yaml( + "direction: few-shot segmentation with diffusion models\n" + "contact: bob\n" + ) + self.assertEqual(agenda.required_output, {"goal": "experiment_plan"}) + + def test_focus_falls_back_to_direction_tokens(self): + from agents.direction_intake import parse_direction_yaml + + agenda, _ = parse_direction_yaml( + "direction: cross-center generalization for medical segmentation\n" + "contact: bob\n" + ) + self.assertIn("medical", agenda.focus) + self.assertIn("segmentation", agenda.focus) + + def test_compute_mapping_rules(self): + from agents.direction_intake import map_compute_constraint + + self.assertEqual(map_compute_constraint("单卡以内"), ["cpu", "gpu_small"]) + self.assertEqual(map_compute_constraint("Single GPU please"), ["cpu", "gpu_small"]) + self.assertEqual(map_compute_constraint("笔记本就能跑"), ["cpu"]) + self.assertEqual(map_compute_constraint("CPU only"), ["cpu"]) + self.assertIsNone(map_compute_constraint("8x H100 cluster")) + self.assertIsNone(map_compute_constraint("")) + + # ---------- bad input ---------- + + def test_missing_direction_rejected(self): + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + with self.assertRaises(DirectionParseError): + 
parse_direction_yaml("keywords: [a, b]\ncontact: bob\n") + + def test_missing_contact_rejected(self): + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + with self.assertRaises(DirectionParseError): + parse_direction_yaml("direction: some research direction\n") + + def test_invalid_goal_rejected(self): + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + with self.assertRaises(DirectionParseError): + parse_direction_yaml( + "direction: some research direction\ncontact: bob\ngoal: world_peace\n" + ) + + def test_non_mapping_yaml_rejected(self): + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + with self.assertRaises(DirectionParseError): + parse_direction_yaml("- just\n- a list\n") + + def test_empty_text_rejected(self): + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + with self.assertRaises(DirectionParseError): + parse_direction_yaml(" \n") + + def test_broken_yaml_rejected(self): + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + with self.assertRaises(DirectionParseError): + parse_direction_yaml("direction: [unclosed\ncontact: bob\n") + + def test_chinese_only_direction_without_keywords_rejected(self): + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + # No keywords, no extractable ASCII terms, no compute constraint: + # nothing to scope on -> reject with guidance instead of a blind agenda. 
+ with self.assertRaises(DirectionParseError): + parse_direction_yaml("direction: 量子计算研究\ncontact: bob\n") + + def test_bad_token_budget_rejected(self): + from agents.direction_intake import DirectionParseError, parse_direction_yaml + + with self.assertRaises(DirectionParseError): + parse_direction_yaml( + "direction: diffusion segmentation\ncontact: bob\ntoken_budget: lots\n" + ) + + +class InboxWatcherTests(unittest.TestCase): + def setUp(self): + self._tmpdir = tempfile.TemporaryDirectory() + os.environ["DEEPGRAPH_DB_PATH"] = str(Path(self._tmpdir.name) / "test.db") + from db import database as db + + self._original_db_path = db.DB_PATH + for attr in ("sqlite_conn", "pg_conn", "conn"): + if hasattr(db._local, attr): + try: + getattr(db._local, attr).close() + except Exception: + pass + delattr(db._local, attr) + db.DB_PATH = Path(os.environ["DEEPGRAPH_DB_PATH"]) + db.init_db() + self.db = db + self.inbox = Path(self._tmpdir.name) / "inbox" + + def tearDown(self): + from db import database as db + + for attr in ("sqlite_conn", "pg_conn", "conn"): + if hasattr(db._local, attr): + try: + getattr(db._local, attr).close() + except Exception: + pass + delattr(db._local, attr) + db.DB_PATH = self._original_db_path + self._tmpdir.cleanup() + os.environ.pop("DEEPGRAPH_DB_PATH", None) + + def test_good_file_is_processed_with_echo(self): + import agenda_inbox_watcher as watcher + + self.inbox.mkdir(parents=True) + (self.inbox / "alice.yaml").write_text(SAMPLE_DIRECTION_YAML, encoding="utf-8") + + results = watcher.scan_inbox(self.inbox) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["status"], "ok") + agenda_id = results[0]["agenda_id"] + + # File moved to processed/ + echo written + self.assertFalse((self.inbox / "alice.yaml").exists()) + processed = self.inbox / "processed" / "alice.yaml" + self.assertTrue(processed.exists()) + echo = json.loads( + (self.inbox / "processed" / "alice.yaml.echo.json").read_text(encoding="utf-8") + ) + 
self.assertEqual(echo["agenda_id"], agenda_id) + self.assertEqual(echo["focus"], ["medical imaging", "diffusion", "few-shot"]) + + # Agenda actually persisted + row = self.db.fetchone("SELECT * FROM research_agendas WHERE id=?", (agenda_id,)) + self.assertIsNotNone(row) + self.assertEqual(row["submitter"], "alice@example.com") + + def test_bad_file_is_quarantined_with_error(self): + import agenda_inbox_watcher as watcher + + self.inbox.mkdir(parents=True) + (self.inbox / "broken.yaml").write_text("keywords: [a]\n", encoding="utf-8") + + results = watcher.scan_inbox(self.inbox) + self.assertEqual(results[0]["status"], "failed") + self.assertFalse((self.inbox / "broken.yaml").exists()) + failed = self.inbox / "failed" / "broken.yaml" + self.assertTrue(failed.exists()) + error_text = (self.inbox / "failed" / "broken.yaml.error.txt").read_text(encoding="utf-8") + self.assertIn("direction", error_text) + # Nothing persisted + row = self.db.fetchone("SELECT COUNT(*) AS c FROM research_agendas") + self.assertEqual(row["c"], 0) + + def test_oversized_file_quarantined_without_parsing(self): + import agenda_inbox_watcher as watcher + + self.inbox.mkdir(parents=True) + big = "# padding\n" * (watcher.MAX_SUBMISSION_BYTES // 10 + 1) + path = self.inbox / "huge.yaml" + path.write_text(big, encoding="utf-8") + self.assertGreater(path.stat().st_size, watcher.MAX_SUBMISSION_BYTES) + + results = watcher.scan_inbox(self.inbox) + self.assertEqual(results[0]["status"], "failed") + self.assertIn("too large", results[0]["error"]) + self.assertFalse(path.exists()) + failed = self.inbox / "failed" / "huge.yaml" + self.assertTrue(failed.exists()) + error_text = (self.inbox / "failed" / "huge.yaml.error.txt").read_text( + encoding="utf-8" + ) + self.assertIn("too large", error_text) + self.assertIn(str(watcher.MAX_SUBMISSION_BYTES), error_text) + # Nothing persisted + row = self.db.fetchone("SELECT COUNT(*) AS c FROM research_agendas") + self.assertEqual(row["c"], 0) + + def 
test_non_yaml_files_ignored(self): + import agenda_inbox_watcher as watcher + + self.inbox.mkdir(parents=True) + (self.inbox / "notes.txt").write_text("hello", encoding="utf-8") + results = watcher.scan_inbox(self.inbox) + self.assertEqual(results, []) + self.assertTrue((self.inbox / "notes.txt").exists()) + + +if __name__ == "__main__": + unittest.main() diff --git a/web/agenda_routes.py b/web/agenda_routes.py index 9065c98..b380509 100644 --- a/web/agenda_routes.py +++ b/web/agenda_routes.py @@ -10,6 +10,12 @@ POST /api/research_agenda/selection//review run reviewer POST /api/research_agenda/selection//plan build revision plan GET /api/research_agenda/loop/ full end-to-end snapshot + +Multi-agenda isolation + budgets: + GET /api/research_agenda//insights insights tagged with this agenda + GET /api/research_agenda//selections selections of this agenda + POST /api/research_agenda//resume reactivate a budget-paused agenda + GET /api/token_usage token ledger summary (all agendas) """ from __future__ import annotations @@ -19,13 +25,14 @@ import yaml # type: ignore from flask import Blueprint, jsonify, request -from agents import agenda_loader, agenda_orchestrator, agenda_selector, evidence_gate, reviewer_adapter, revision_planner +from agents import agenda_budget, agenda_loader, agenda_orchestrator, agenda_selector, evidence_gate, reviewer_adapter, revision_planner from contracts.agenda import LoopInspectionSnapshot from contracts.base import ContractValidationError from db import database as db bp = Blueprint("research_agenda", __name__, url_prefix="/api/research_agenda") +usage_bp = Blueprint("token_usage", __name__, url_prefix="/api") def _agenda_to_dict(agenda): @@ -42,6 +49,10 @@ def _agenda_to_dict(agenda): "required_output": agenda.required_output, "is_active": agenda.is_active, "raw_config": agenda.raw_config, + "submitter": agenda.submitter, + "token_budget": agenda_budget.effective_budget(agenda.token_budget), + "token_spent": agenda.token_spent, + 
"status": agenda.status, } @@ -96,6 +107,63 @@ def current_agenda(): return jsonify({"agenda": _agenda_to_dict(agenda)}) +# ---------- multi-agenda isolation + budget ---------- + + +@bp.route("/<int:agenda_id>/insights", methods=["GET"]) +def agenda_insights(agenda_id: int): + """Insights produced for this agenda (tagged with agenda_id).""" + if agenda_loader.get_agenda(agenda_id) is None: + return jsonify({"error": "agenda_not_found", "agenda_id": agenda_id}), 404 + limit = max(1, min(request.args.get("limit", default=100, type=int), 500)) + rows = db.fetchall( + "SELECT id, tier, status, title, adversarial_score, novelty_status, " + "resource_class, experimentability, submission_status, outcome, created_at " + "FROM deep_insights WHERE agenda_id=? ORDER BY id DESC LIMIT ?", + (agenda_id, limit), + ) + return jsonify({"agenda_id": agenda_id, "insights": rows}) + + +@bp.route("/<int:agenda_id>/selections", methods=["GET"]) +def agenda_selections(agenda_id: int): + """Selections belonging to this agenda only.""" + if agenda_loader.get_agenda(agenda_id) is None: + return jsonify({"error": "agenda_not_found", "agenda_id": agenda_id}), 404 + limit = max(1, min(request.args.get("limit", default=50, type=int), 500)) + rows = db.fetchall( + "SELECT * FROM agenda_selections WHERE agenda_id=?
" + "ORDER BY created_at DESC, id DESC LIMIT ?", + (agenda_id, limit), + ) + return jsonify({ + "agenda_id": agenda_id, + "selections": [agenda_selector._row_to_selection_dict(r) for r in rows], + }) + + +@bp.route("/<int:agenda_id>/resume", methods=["POST"]) +def resume_agenda_endpoint(agenda_id: int): + """Reactivate a budget-paused agenda; optionally raise its token budget.""" + if agenda_loader.get_agenda(agenda_id) is None: + return jsonify({"error": "agenda_not_found", "agenda_id": agenda_id}), 404 + body = request.get_json(silent=True) or {} + token_budget = body.get("token_budget") + if token_budget is not None: + if isinstance(token_budget, bool) or not isinstance(token_budget, int): + return jsonify({"error": "invalid_request", "message": "token_budget must be an integer"}), 400 + if token_budget < 0: + return jsonify({"error": "invalid_request", "message": "token_budget must be >= 0"}), 400 + state = agenda_budget.resume_agenda(agenda_id, token_budget=token_budget) + return jsonify({"agenda_id": agenda_id, "budget": state}) + + +@usage_bp.route("/token_usage", methods=["GET"]) +def token_usage(): + """Local token accounting per agenda + totals (read-only).""" + return jsonify(agenda_budget.usage_summary()) + + # ---------- selection ---------- @@ -104,6 +172,7 @@ def trigger_selection(): body = request.get_json(silent=True) or {} agenda_id = body.get("agenda_id") dispatch_mode = body.get("dispatch_mode", "auto") + scoped = bool(body.get("scoped")) if dispatch_mode not in ("auto", "link", "enqueue", "bench", "none"): return ( jsonify({"error": "invalid_dispatch_mode", "message": "must be auto|link|enqueue|bench|none"}), @@ -118,7 +187,7 @@ def trigger_selection(): if agenda is None: return jsonify({"error": "no_active_agenda"}), 404 - selection = agenda_selector.select_and_persist(agenda) + selection = agenda_selector.select_and_persist(agenda, scope_to_agenda=scoped) dispatch_result = None dispatch_succeeded = None if dispatch_mode != "none" and
selection.selected_insight_id: @@ -350,3 +419,4 @@ def loop_inspection(selection_id: int): def register(app): app.register_blueprint(bp) + app.register_blueprint(usage_bp)