Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ def main():
start_gpu_scheduler()
print("Auto Research worker ready.", flush=True)

# Warm the /api/stats cache in the background so the first browser paint
# is served from cache, not a cold ~30-COUNT(*) query (issue #34).
import threading
from web.app import prewarm_stats_cache
print("Prewarming stats cache in background...", flush=True)
threading.Thread(target=prewarm_stats_cache, daemon=True).start()

# Start web server
_serve_http()
finally:
Expand Down
180 changes: 180 additions & 0 deletions tests/test_web_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,5 +438,185 @@ def test_paper_preview_routes_serve_current_tex(self):
tex_response.close()


class StatsCacheTests(unittest.TestCase):
    """Issue #34 · Feature 1 — /api/stats served from an in-process TTL cache.

    These tests overwrite the module-level ``web_app._stats_cache`` with
    sentinel payloads. ``setUp`` therefore registers a cleanup that empties the
    cache again, so the fake numbers are never served to a later test that
    happens to hit ``/api/stats``.
    """

    # The contract fields the tests below seed and check on /api/stats.
    STAT_FIELDS = (
        "papers_processed",
        "results_total",
        "insights_total",
        "deep_insights_total",
        "submission_bundles_total",
    )

    def setUp(self):
        self.client = web_app.app.test_client()
        # Look the cache up at cleanup time rather than binding it now: the
        # reload test below replaces ``web_app._stats_cache`` with a new object.
        self.addCleanup(lambda: web_app._stats_cache.invalidate())

    def test_api_stats_served_from_cache_not_recomputed_each_request(self):
        sentinel = {
            "papers_processed": 1,
            "results_total": 2,
            "insights_total": 3,
            "deep_insights_total": 4,
            "submission_bundles_total": 5,
        }
        with mock.patch.object(
            web_app, "get_stats_dict", return_value=sentinel
        ) as heavy:
            web_app._stats_cache.invalidate()
            web_app._stats_cache.prewarm()
            for _ in range(8):
                response = self.client.get("/api/stats")
                self.assertEqual(response.status_code, 200)
                self.assertEqual(response.get_json(), sentinel)

            # The heavy COUNT(*) computation must run at most once across the
            # warm-up plus eight requests — proving the cache is hit, not
            # recomputed.
            self.assertLessEqual(heavy.call_count, 1)

    def test_api_stats_returns_correct_fields(self):
        sentinel = {
            "papers_processed": 11,
            "results_total": 22,
            "insights_total": 33,
            "deep_insights_total": 44,
            "submission_bundles_total": 55,
        }
        with mock.patch.object(web_app, "get_stats_dict", return_value=sentinel):
            web_app._stats_cache.invalidate()
            web_app._stats_cache.prewarm()
            payload = self.client.get("/api/stats").get_json()

        # Caching must not change the contract: every field and value the
        # underlying computation produced is served verbatim.
        for key in self.STAT_FIELDS:
            self.assertIn(key, payload)
        self.assertEqual(payload, sentinel)

    def test_import_web_app_does_not_block_on_heavy_stats(self):
        import importlib
        from orchestrator import pipeline

        # Registered BEFORE the patched reload: cleanups run after the test
        # body (and therefore after the patch is undone) even when the
        # assertion below fails, so the module is always re-bound to the real
        # computation instead of being left pointing at the mock.
        self.addCleanup(importlib.reload, web_app)

        with mock.patch.object(pipeline, "get_stats_dict") as heavy:
            importlib.reload(web_app)
            # Importing / building the app must not synchronously run the heavy
            # stats query (tests import web.app; deploy imports it at startup).
            heavy.assert_not_called()


class EventsTailTests(unittest.TestCase):
    """Issue #34 · Feature 2 — /api/events?since=0 returns only the tail."""

    def setUp(self):
        self.client = web_app.app.test_client()

    def test_api_events_since_zero_returns_tail_only(self):
        total = 120
        full = [{"seq": i, "type": "x"} for i in range(total)]
        with mock.patch.object(web_app, "get_events", return_value=list(full)):
            payload = self.client.get("/api/events?since=0").get_json()

        returned_seqs = [event["seq"] for event in payload["events"]]
        # The frontend keeps at most 50 events, so the first-paint payload must
        # never exceed that regardless of the configured tail size.
        self.assertLessEqual(len(returned_seqs), 50)
        # A bare upper bound would also pass if the route returned a single
        # event; pin the exact slice — the newest events, oldest first.
        tail = web_app.FIRST_PAINT_EVENT_TAIL
        self.assertEqual(returned_seqs, list(range(total - tail, total)))
        # next_seq must still point past the newest event so subsequent
        # ?since=next_seq polling keeps advancing correctly.
        self.assertEqual(payload["next_seq"], total)

    def test_api_events_incremental_since_unchanged(self):
        incremental = [{"seq": i, "type": "x"} for i in range(50, 120)]
        with mock.patch.object(
            web_app, "get_events", return_value=list(incremental)
        ) as get_events:
            payload = self.client.get("/api/events?since=50").get_json()

        get_events.assert_called_once_with(50)
        # Incremental polling is untouched: every event after the cursor is
        # returned, in order, with no truncation.
        self.assertEqual(
            [event["seq"] for event in payload["events"]], list(range(50, 120))
        )
        self.assertEqual(payload["next_seq"], 120)


class FirstPaintE2ETests(unittest.TestCase):
    """Issue #34 · end-to-end — the first-paint endpoints are 200 and bounded.

    Each test runs against a throwaway SQLite file. Global ``database`` state
    is restored through ``addCleanup`` rather than ``tearDown`` because
    ``tearDown`` is skipped when ``setUp`` itself raises (for example inside
    ``database.init_db()``), which would leave ``DB_PATH`` / ``DATABASE_URL``
    pointing at the temp directory for every later test.
    """

    # Attributes on ``database._local`` that this class closes and clears.
    _CONNECTION_ATTRS = ("pg_conn", "sqlite_conn", "conn")

    @classmethod
    def _reset_thread_local_connections(cls):
        """Best-effort close and clear of the handles held on ``database._local``."""
        for attr in cls._CONNECTION_ATTRS:
            if hasattr(database._local, attr):
                try:
                    getattr(database._local, attr).close()
                except Exception:
                    # Deliberately best-effort: the slot may hold None or an
                    # already-closed handle; either way it is cleared below.
                    pass
                setattr(database._local, attr, None)

    def _restore_database(self):
        """Undo everything ``setUp`` changed, in the original tearDown order."""
        self._reset_thread_local_connections()
        database.DATABASE_URL = self.old_database_url
        database.DB_PATH = self.old_db_path
        self.tmpdir.cleanup()
        # The test seeds the shared stats cache with a fake payload; drop it so
        # later tests do not read it back from /api/stats.
        web_app._stats_cache.invalidate()

    def setUp(self):
        self.tmpdir = tempfile.TemporaryDirectory()
        self.old_db_path = database.DB_PATH
        self.old_database_url = database.DATABASE_URL
        # Register the restore step before mutating any global state.
        self.addCleanup(self._restore_database)

        self._reset_thread_local_connections()
        database.DATABASE_URL = ""
        database.DB_PATH = Path(self.tmpdir.name) / "firstpaint.db"
        database.init_db()
        self.client = web_app.app.test_client()

    def test_first_paint_endpoints_are_fast_and_bounded(self):
        # Seed more than the tail size so truncation is actually exercised.
        for i in range(80):
            web_app.log_event("seed", {"i": i})

        stats_payload = {
            "papers_processed": 0,
            "results_total": 0,
            "insights_total": 0,
            "deep_insights_total": 0,
            "submission_bundles_total": 0,
        }
        first_paint_paths = [
            "/api/stats",
            "/api/events?since=0",
            "/api/recent_discoveries",
            "/api/insights?limit=6",
            "/api/deep_insights?limit=4",
        ]

        with mock.patch.object(
            web_app, "get_stats_dict", return_value=stats_payload
        ) as heavy:
            web_app._stats_cache.invalidate()
            web_app._stats_cache.prewarm()  # deploy/startup prewarm
            responses = {p: self.client.get(p) for p in first_paint_paths}

            # 1) stats is served from the warm cache: the heavy query is not
            #    recomputed per request (mock count, no wall-clock dependency).
            self.assertLessEqual(heavy.call_count, 1)

        # 2) every first-paint endpoint returns 200.
        for path, response in responses.items():
            self.assertEqual(
                response.status_code, 200, f"{path} -> {response.status_code}"
            )

        # 3) since=0 returns only the tail (<= 50 events).
        events_payload = responses["/api/events?since=0"].get_json()
        self.assertLessEqual(len(events_payload["events"]), 50)

        # 4) stats keeps its contract fields (cache does not drop/rename them).
        stats_resp = responses["/api/stats"].get_json()
        for key in (
            "papers_processed",
            "results_total",
            "insights_total",
            "deep_insights_total",
            "submission_bundles_total",
        ):
            self.assertIn(key, stats_resp)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
39 changes: 38 additions & 1 deletion web/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from agents.taxonomy_expander import run_expansion
from web.agenda_routes import register as register_agenda_routes
from web.manuscript_routes import register as register_manuscript_routes
from web.stats_cache import StatsCache

app = Flask(__name__,
template_folder="templates",
Expand All @@ -26,6 +27,25 @@
register_agenda_routes(app)
register_manuscript_routes(app)

# In-process TTL cache for the heavy /api/stats query (issue #34).
#
# STATS_CACHE_TTL_SECONDS is handed to StatsCache as ``ttl``: the pause, in
# seconds, between two runs of its background refresher.
#
# The lambda resolves get_stats_dict at call time so it stays patchable in
# tests; the cache is lazy — constructing it here does NOT run the heavy query
# at import.
STATS_CACHE_TTL_SECONDS = 30.0
_stats_cache = StatsCache(lambda: get_stats_dict(), ttl=STATS_CACHE_TTL_SECONDS)


def prewarm_stats_cache():
    """Warm the stats cache once, then start its background refresher.

    Call from server startup (in a thread) so the first browser paint is served
    from a warm cache, not a cold ~30-COUNT(*) query, and stays fresh
    thereafter.

    A failed warm-up (for example the database is not reachable yet) is logged
    instead of raised. This function is run as a thread target, so an exception
    escaping here would end the thread before the refresher is started and
    ``/api/stats`` would answer ``{"warming": true}`` for the lifetime of the
    process. The refresher is therefore started unconditionally; it retries the
    computation on its own schedule.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    try:
        _stats_cache.prewarm()
    except Exception:
        logging.getLogger(__name__).exception(
            "stats cache prewarm failed; relying on the background refresher"
        )
    _stats_cache.start_background_refresh()
Comment on lines +38 to +42

# How many events /api/events?since=0 returns on first paint. The frontend only
# keeps the last 50, so returning the full ~1000-event log just wastes ~463KB
# and a concurrency slot on the first-paint critical path (issue #34).
# Incremental requests (since > 0) are not truncated.
FIRST_PAINT_EVENT_TAIL = 50

# NOTE(review): presumably a "pipeline is running" flag and the lock guarding
# it; their readers and writers are outside this view — confirm before relying
# on that.
_pipeline_running = False
_pipeline_lock = threading.Lock()

Expand Down Expand Up @@ -646,7 +666,17 @@ def api_meta():

@app.route("/api/stats")
def api_stats():
    """Return the dashboard stats from the in-process snapshot cache (issue #34).

    Reading the cache is a plain lookup, so the heavy COUNT(*) query never runs
    in this request thread. The snapshot is produced by the startup warm-up and
    kept fresh by the cache's periodic background refresher; a request itself
    does not trigger a refresh.
    """
    snapshot = _stats_cache.get()
    # No snapshot yet means the startup warm-up has not finished. Answer with a
    # "warming" marker rather than blocking on the heavy query or fabricating
    # numbers; with the startup prewarm this window is effectively never hit in
    # production.
    payload = {"warming": True} if snapshot is None else snapshot
    return jsonify(payload)


@app.route("/api/providers")
Expand Down Expand Up @@ -1159,6 +1189,13 @@ def api_events():
"""Short-poll pipeline events without holding a web worker thread."""
since = max(0, request.args.get("since", 0, type=int) or 0)
events = get_events(since)
if since == 0:
# First paint: the frontend only keeps the last 50 events, so return
# just the tail instead of the full ~1000-event log. next_seq is still
# computed from the newest event below, so subsequent ?since=next_seq
# polling continues forward correctly (issue #34). Incremental
# (since>0) requests are untouched.
events = events[-FIRST_PAINT_EVENT_TAIL:]
payload_events = json.loads(json.dumps(events, ensure_ascii=False, default=str))
next_seq = since
for event in payload_events:
Expand Down
89 changes: 89 additions & 0 deletions web/stats_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Process-internal TTL cache for the heavy /api/stats query (issue #34).

The dashboard's first paint blocks on ``/api/stats``, which runs ~30
``COUNT(*)`` queries — several of them full-table scans over large tables
(``graph_relations`` ~500k rows, ``entity_resolutions`` ~180k, ``results``
~110k …). Under concurrency this balloons to ~18s and the nine metric cards
stay blank for ten-plus seconds.

This cache serves ``/api/stats`` from an in-process snapshot. The heavy query
runs only in the warm-up and in a single long-lived background refresher
thread — never in the request-serving thread. A single-process waitress
deployment makes an in-process cache sufficient (issue #34 non-goals: no Redis
/ multi-process).

Guarantees relevant to the issue's red lines:
- ``get()`` is a pure read: it never runs the heavy query in the caller
thread, so per-request COUNT(*) scans are gone.
- Numbers are never fabricated: before the first warm-up completes ``get()``
returns ``None`` (the route renders a "warming" marker), never fake values.
- The cache is lazy: constructing it does not run the query, so importing the
web app (as tests and startup do) stays cheap.
"""
import logging
import threading
import time

logger = logging.getLogger(__name__)


class StatsCache:
    """Thread-safe holder for the most recent result of an expensive callable.

    ``compute`` is only ever invoked by :meth:`prewarm` and by the background
    refresher thread; :meth:`get` just hands back the stored snapshot.

    Args:
        compute: Zero-argument callable producing the value to cache.
        ttl: Seconds the refresher pauses between recomputations. A value that
            is not greater than zero falls back to ``MIN_REFRESH_INTERVAL``
            inside the refresher (see :meth:`_refresh_interval`).
        time_func: Clock used to stamp successful computations and to measure
            :meth:`age`; injectable so tests can control time.
    """

    # Pause used by the refresher when ``ttl`` is not a positive number.
    # Without it ``ttl == 0`` would re-run the heavy query in a tight loop and
    # ``ttl < 0`` would make ``time.sleep`` raise and end the refresher thread.
    MIN_REFRESH_INTERVAL = 1.0

    def __init__(self, compute, ttl: float = 30.0, time_func=time.monotonic):
        self._compute = compute
        self._ttl = float(ttl)
        self._time = time_func
        self._lock = threading.Lock()
        self._value = None
        self._stamp = None  # time_func() reading at the last successful compute
        self._refresher = None  # the single background refresher thread
        self.compute_count = 0  # successful heavy recomputes (observability / tests)

    def get(self):
        """Return the current cached stats snapshot.

        Pure read — never runs the heavy query in the caller thread. Returns
        the last successfully computed value, or ``None`` before the first
        warm-up completes and after :meth:`invalidate`. The stored object
        itself is returned, not a copy.
        """
        with self._lock:
            return self._value

    def age(self):
        """Return seconds since the last successful compute, or ``None``.

        ``None`` means there is no snapshot (never computed, or invalidated).
        Failed background refreshes are only logged, so this is the way to
        notice that the served snapshot has grown older than ``ttl``.
        """
        with self._lock:
            if self._stamp is None:
                return None
            return self._time() - self._stamp

    def _recompute(self):
        """Run ``compute`` and publish its result; exceptions propagate."""
        # The heavy call happens outside the lock so readers are never blocked
        # behind it; only the publication of the result is serialized.
        value = self._compute()
        with self._lock:
            self._value = value
            self._stamp = self._time()
            self.compute_count += 1
        return value

    def prewarm(self):
        """Synchronously compute once and populate the cache.

        Called at startup (in a background thread so neither import nor server
        start blocks) and by tests for determinism. Returns the computed value;
        any exception raised by ``compute`` propagates to the caller.
        """
        return self._recompute()

    def start_background_refresh(self):
        """Start the single daemon thread that refreshes the snapshot
        periodically. Idempotent while that thread is alive; intended to be
        called once at startup."""
        with self._lock:
            if self._refresher is not None and self._refresher.is_alive():
                return
            self._refresher = threading.Thread(
                target=self._refresh_loop, name="stats-cache-refresh", daemon=True
            )
            self._refresher.start()

    def _refresh_interval(self):
        """Return the pause between refreshes: ``ttl`` when positive, else the
        ``MIN_REFRESH_INTERVAL`` fallback."""
        return self._ttl if self._ttl > 0 else self.MIN_REFRESH_INTERVAL

    def _refresh_loop(self):
        """Body of the refresher thread: sleep, recompute, repeat forever."""
        interval = self._refresh_interval()
        while True:
            time.sleep(interval)
            try:
                self._recompute()
            except Exception:  # pragma: no cover - defensive, logged not fatal
                # Keep serving the previous snapshot and try again next round.
                logger.exception("stats cache background refresh failed")

    def invalidate(self):
        """Drop the snapshot so :meth:`get` returns ``None`` until the next
        successful compute."""
        with self._lock:
            self._value = None
            self._stamp = None
Loading