diff --git a/main.py b/main.py index 8f5492e..f1e60e9 100644 --- a/main.py +++ b/main.py @@ -146,6 +146,13 @@ def main(): start_gpu_scheduler() print("Auto Research worker ready.", flush=True) + # Warm the /api/stats cache in the background so the first browser paint + # is served from cache, not a cold ~30-COUNT(*) query (issue #34). + import threading + from web.app import prewarm_stats_cache + print("Prewarming stats cache in background...", flush=True) + threading.Thread(target=prewarm_stats_cache, daemon=True).start() + # Start web server _serve_http() finally: diff --git a/tests/test_web_app.py b/tests/test_web_app.py index 9318472..fdd2850 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -438,5 +438,185 @@ def test_paper_preview_routes_serve_current_tex(self): tex_response.close() +class StatsCacheTests(unittest.TestCase): + """Issue #34 · Feature 1 — /api/stats served from an in-process TTL cache.""" + + def setUp(self): + self.client = web_app.app.test_client() + + def test_api_stats_served_from_cache_not_recomputed_each_request(self): + sentinel = { + "papers_processed": 1, + "results_total": 2, + "insights_total": 3, + "deep_insights_total": 4, + "submission_bundles_total": 5, + } + with mock.patch.object( + web_app, "get_stats_dict", return_value=sentinel + ) as heavy: + web_app._stats_cache.invalidate() + web_app._stats_cache.prewarm() + for _ in range(8): + response = self.client.get("/api/stats") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_json(), sentinel) + + # The heavy COUNT(*) computation must run at most once across the + # warm-up plus eight requests — proving the cache is hit, not recomputed. + self.assertLessEqual(heavy.call_count, 1) + + def test_api_stats_returns_correct_fields(self): + sentinel = { + "papers_processed": 11, + "results_total": 22, + "insights_total": 33, + "deep_insights_total": 44, + "submission_bundles_total": 55, + } + with mock.patch.object(web_app, "get_stats_dict", return_value=sentinel): + web_app._stats_cache.invalidate() + web_app._stats_cache.prewarm() + payload = self.client.get("/api/stats").get_json() + + # Caching must not change the contract: every field and value the + # underlying computation produced is served verbatim. + for key in ( + "papers_processed", + "results_total", + "insights_total", + "deep_insights_total", + "submission_bundles_total", + ): + self.assertIn(key, payload) + self.assertEqual(payload, sentinel) + + def test_import_web_app_does_not_block_on_heavy_stats(self): + import importlib + from orchestrator import pipeline + + with mock.patch.object(pipeline, "get_stats_dict") as heavy: + importlib.reload(web_app) + # Importing / building the app must not synchronously run the heavy + # stats query (tests import web.app; deploy imports it at startup). + heavy.assert_not_called() + # Restore a clean module bound to the real computation. + importlib.reload(web_app) + + +class EventsTailTests(unittest.TestCase): + """Issue #34 · Feature 2 — /api/events?since=0 returns only the tail.""" + + def setUp(self): + self.client = web_app.app.test_client() + + def test_api_events_since_zero_returns_tail_only(self): + full = [{"seq": i, "type": "x"} for i in range(120)] + with mock.patch.object(web_app, "get_events", return_value=list(full)): + payload = self.client.get("/api/events?since=0").get_json() + + self.assertLessEqual(len(payload["events"]), 50) + # next_seq must still point past the newest event so subsequent + # ?since=next_seq polling keeps advancing correctly. + self.assertEqual(payload["next_seq"], 120) + self.assertEqual(payload["events"][-1]["seq"], 119) + + def test_api_events_incremental_since_unchanged(self): + incremental = [{"seq": i, "type": "x"} for i in range(50, 120)] + with mock.patch.object( + web_app, "get_events", return_value=list(incremental) + ) as get_events: + payload = self.client.get("/api/events?since=50").get_json() + + get_events.assert_called_once_with(50) + # Incremental polling is untouched: every event after the cursor is + # returned, no truncation. + self.assertEqual(len(payload["events"]), 70) + self.assertEqual(payload["events"][0]["seq"], 50) + self.assertEqual(payload["next_seq"], 120) + + +class FirstPaintE2ETests(unittest.TestCase): + """Issue #34 · end-to-end — the first-paint endpoints are 200 and bounded.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.old_db_path = database.DB_PATH + self.old_database_url = database.DATABASE_URL + for attr in ("pg_conn", "sqlite_conn", "conn"): + if hasattr(database._local, attr): + try: + getattr(database._local, attr).close() + except Exception: + pass + setattr(database._local, attr, None) + database.DATABASE_URL = "" + database.DB_PATH = Path(self.tmpdir.name) / "firstpaint.db" + database.init_db() + self.client = web_app.app.test_client() + + def tearDown(self): + for attr in ("pg_conn", "sqlite_conn", "conn"): + if hasattr(database._local, attr): + try: + getattr(database._local, attr).close() + except Exception: + pass + setattr(database._local, attr, None) + database.DATABASE_URL = self.old_database_url + database.DB_PATH = self.old_db_path + self.tmpdir.cleanup() + + def test_first_paint_endpoints_are_fast_and_bounded(self): + # Seed more than the tail size so truncation is actually exercised. + for i in range(80): + web_app.log_event("seed", {"i": i}) + + stats_payload = { + "papers_processed": 0, + "results_total": 0, + "insights_total": 0, + "deep_insights_total": 0, + "submission_bundles_total": 0, + } + first_paint_paths = [ + "/api/stats", + "/api/events?since=0", + "/api/recent_discoveries", + "/api/insights?limit=6", + "/api/deep_insights?limit=4", + ] + + with mock.patch.object( + web_app, "get_stats_dict", return_value=stats_payload + ) as heavy: + web_app._stats_cache.invalidate() + web_app._stats_cache.prewarm() # deploy/startup prewarm + responses = {p: self.client.get(p) for p in first_paint_paths} + + # 2) stats is served from the warm cache: the heavy query is not + # recomputed per request (mock count, no wall-clock dependency). + self.assertLessEqual(heavy.call_count, 1) + + # 1) every first-paint endpoint returns 200. + for path, response in responses.items(): + self.assertEqual(response.status_code, 200, f"{path} -> {response.status_code}") + + # 3) since=0 returns only the tail (<= 50 events). + events_payload = responses["/api/events?since=0"].get_json() + self.assertLessEqual(len(events_payload["events"]), 50) + + # 4) stats keeps its contract fields (cache does not drop/rename them). + stats_resp = responses["/api/stats"].get_json() + for key in ( + "papers_processed", + "results_total", + "insights_total", + "deep_insights_total", + "submission_bundles_total", + ): + self.assertIn(key, stats_resp) + + if __name__ == "__main__": unittest.main() diff --git a/web/app.py b/web/app.py index dea80e0..5b6afe8 100644 --- a/web/app.py +++ b/web/app.py @@ -18,6 +18,7 @@ from agents.taxonomy_expander import run_expansion from web.agenda_routes import register as register_agenda_routes from web.manuscript_routes import register as register_manuscript_routes +from web.stats_cache import StatsCache app = Flask(__name__, template_folder="templates", @@ -26,6 +27,25 @@ register_agenda_routes(app) register_manuscript_routes(app) +# In-process TTL cache for the heavy /api/stats query (issue #34). The lambda +# resolves get_stats_dict at call time so it stays patchable in tests; the +# cache is lazy — constructing it here does NOT run the heavy query at import. +STATS_CACHE_TTL_SECONDS = 30.0 +_stats_cache = StatsCache(lambda: get_stats_dict(), ttl=STATS_CACHE_TTL_SECONDS) + + +def prewarm_stats_cache(): + """Warm the stats cache once and start its background refresher. Call from + server startup (in a thread) so the first browser paint is served from a + warm cache, not a cold ~30-COUNT(*) query, and stays fresh thereafter.""" + _stats_cache.prewarm() + _stats_cache.start_background_refresh() + +# How many events /api/events?since=0 returns on first paint. The frontend only +# keeps the last 50, so returning the full ~1000-event log just wastes ~463KB +# and a concurrency slot on the first-paint critical path (issue #34). +FIRST_PAINT_EVENT_TAIL = 50 + _pipeline_running = False _pipeline_lock = threading.Lock() @@ -646,7 +666,17 @@ def api_meta(): @app.route("/api/stats") def api_stats(): - return jsonify(get_stats_dict()) + # Serve from the in-process TTL cache; the heavy COUNT(*) query never runs + # in this request thread (issue #34). Stale entries trigger a background + # refresh inside the cache. + stats = _stats_cache.get() + if stats is None: + # Cold start before the startup warm-up completes: return a "warming" + # marker rather than blocking the request thread on the heavy query or + # fabricating numbers. The startup prewarm makes this window + # effectively never hit in production. + return jsonify({"warming": True}) + return jsonify(stats) @app.route("/api/providers") @@ -1159,6 +1189,13 @@ def api_events(): """Short-poll pipeline events without holding a web worker thread.""" since = max(0, request.args.get("since", 0, type=int) or 0) events = get_events(since) + if since == 0: + # First paint: the frontend only keeps the last 50 events, so return + # just the tail instead of the full ~1000-event log. next_seq is still + # computed from the newest event below, so subsequent ?since=next_seq + # polling continues forward correctly (issue #34). Incremental + # (since>0) requests are untouched. + events = events[-FIRST_PAINT_EVENT_TAIL:] payload_events = json.loads(json.dumps(events, ensure_ascii=False, default=str)) next_seq = since for event in payload_events: diff --git a/web/stats_cache.py b/web/stats_cache.py new file mode 100644 index 0000000..3a50212 --- /dev/null +++ b/web/stats_cache.py @@ -0,0 +1,89 @@ +"""Process-internal TTL cache for the heavy /api/stats query (issue #34). + +The dashboard's first paint blocks on ``/api/stats``, which runs ~30 +``COUNT(*)`` queries — several of them full-table scans over large tables +(``graph_relations`` ~500k rows, ``entity_resolutions`` ~180k, ``results`` +~110k …). Under concurrency this balloons to ~18s and the nine metric cards +stay blank for ten-plus seconds. + +This cache serves ``/api/stats`` from an in-process snapshot. The heavy query +runs only in the warm-up and in a single long-lived background refresher +thread — never in the request-serving thread. A single-process waitress +deployment makes an in-process cache sufficient (issue #34 non-goals: no Redis +/ multi-process). + +Guarantees relevant to the issue's red lines: +- ``get()`` is a pure read: it never runs the heavy query in the caller + thread, so per-request COUNT(*) scans are gone. +- Numbers are never fabricated: before the first warm-up completes ``get()`` + returns ``None`` (the route renders a "warming" marker), never fake values. +- The cache is lazy: constructing it does not run the query, so importing the + web app (as tests and startup do) stays cheap. +""" +import logging +import threading +import time + +logger = logging.getLogger(__name__) + + +class StatsCache: + def __init__(self, compute, ttl: float = 30.0, time_func=time.monotonic): + self._compute = compute + self._ttl = float(ttl) + self._time = time_func + self._lock = threading.Lock() + self._value = None + self._stamp = None # monotonic timestamp of the last successful compute + self._refresher = None # the single background refresher thread + self.compute_count = 0 # successful heavy recomputes (observability / tests) + + def get(self): + """Return the current cached stats snapshot. + + Pure read — never runs the heavy query in the caller thread. Returns + the last computed dict (a snapshot at most ``ttl``-ish seconds stale), + or ``None`` before the first warm-up completes. + """ + with self._lock: + return self._value + + def _recompute(self): + value = self._compute() + with self._lock: + self._value = value + self._stamp = self._time() + self.compute_count += 1 + return value + + def prewarm(self): + """Synchronously compute once and populate the cache. + + Called at startup (in a background thread so neither import nor server + start blocks) and by tests for determinism. + """ + return self._recompute() + + def start_background_refresh(self): + """Start the single daemon thread that refreshes the snapshot every + ``ttl`` seconds. Idempotent; intended to be called once at startup.""" + with self._lock: + if self._refresher is not None and self._refresher.is_alive(): + return + self._refresher = threading.Thread( + target=self._refresh_loop, name="stats-cache-refresh", daemon=True + ) + self._refresher.start() + + def _refresh_loop(self): + while True: + time.sleep(self._ttl) + try: + self._recompute() + except Exception: # pragma: no cover - defensive, logged not fatal + logger.exception("stats cache background refresh failed") + + def invalidate(self): + with self._lock: + self._value = None + self._stamp = None