-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
347 lines (286 loc) · 13.2 KB
/
app.py
File metadata and controls
347 lines (286 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
"""StoryForge — thin entry point.
Starts FastAPI, mounts API routes and static files.
The UI is served from web/ as a static Alpine.js SPA.
CORS policy:
Allowed origins are read from the STORYFORGE_ALLOWED_ORIGINS env var
(comma-separated list). Defaults to http://localhost:7860 and http://127.0.0.1:7860 only.
Wildcard "*" is intentionally NOT used — set explicit origins for production.
"""
import logging
import logging.handlers
import os
import sys
import time
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from config import ConfigManager
# Logging
# Apply the project's logging configuration at import time, before any module
# logger in this file is used.
# NOTE(review): presumably installs handlers on the root logger (the code
# right after this inspects root-logger handlers) — confirm in
# services.structured_logger.
from services.structured_logger import configure_logging
configure_logging()
# Replace the plain FileHandler with a RotatingFileHandler
# so log files never grow unbounded (D4: log rotation).
def _install_rotating_file_handler(
    target: "logging.Logger | None" = None,
    *,
    max_bytes: int = 10 * 1024 * 1024,  # 10 MB per file
    backup_count: int = 5,
) -> None:
    """Swap the first plain FileHandler on *target* for a RotatingFileHandler.

    The replacement writes to the SAME file as the handler it replaces and
    keeps its formatter, level and filters, so enabling rotation neither
    redirects logs to a different path nor drops per-handler filtering.
    Handlers that already rotate (any ``BaseRotatingHandler``, size- or
    time-based) are left untouched. Only the first matching handler is
    replaced; calling this again is a no-op once it has run.

    Args:
        target: Logger to modify; defaults to the root logger.
        max_bytes: Size at which the log file is rolled over.
        backup_count: Number of rolled-over files to keep.
    """
    log = logging.getLogger() if target is None else target
    for old in list(log.handlers):
        if not isinstance(old, logging.FileHandler):
            continue
        if isinstance(old, logging.handlers.BaseRotatingHandler):
            # Already rotating — nothing to do for this handler.
            continue

        # Capture the settings before closing the old handler.
        path = old.baseFilename
        formatter = old.formatter
        level = old.level
        filters = list(old.filters)

        log.removeHandler(old)
        old.close()

        new = logging.handlers.RotatingFileHandler(
            path,
            maxBytes=max_bytes,
            backupCount=backup_count,
            # Always UTF-8: story content is Vietnamese, and the platform
            # default encoding may not be able to represent it.
            encoding="utf-8",
        )
        new.setFormatter(formatter)
        new.setLevel(level)
        for flt in filters:
            new.addFilter(flt)
        log.addHandler(new)
        break


_install_rotating_file_handler()
logger = logging.getLogger(__name__)

# Uptime tracking: wall-clock time captured at import; /api/health reports
# uptime_seconds relative to it.
_START_TIME = time.time()

# ---------------------------------------------------------------------------
# CORS configuration helpers
# ---------------------------------------------------------------------------
#
# SEC-5 CORS Audit (Sprint 15) — verified safe:
#   - No wildcard '*' is used in production. A '*' entry in
#     STORYFORGE_ALLOWED_ORIGINS is detected and rejected with a warning
#     (the safe defaults are used instead).
#   - Default fallback is http://localhost:7860 and http://127.0.0.1:7860
#     only (safe for development).
#   - Production deployments MUST set STORYFORGE_ALLOWED_ORIGINS to the
#     explicit list of frontend origins, e.g.:
#       STORYFORGE_ALLOWED_ORIGINS=https://app.storyforge.io,https://www.storyforge.io
#   - Credentials are allowed (allow_credentials=True), requiring explicit
#     origins — this is incompatible with '*' by the CORS spec.
#   - Allowed methods: GET, POST, PUT, DELETE, OPTIONS (no TRACE/CONNECT).
#   - Allowed headers: Authorization, Content-Type, Accept, X-CSRF-Token
#     (minimal set; X-CSRF-Token is for the double-submit CSRF middleware).
#
_DEFAULT_ORIGINS = ["http://localhost:7860", "http://127.0.0.1:7860"]


def _get_allowed_origins() -> list[str]:
    """Read allowed CORS origins from the STORYFORGE_ALLOWED_ORIGINS env var.

    The variable is a comma-separated list. Each entry is stripped of
    surrounding whitespace and of trailing slashes, and empty entries are
    dropped. Browsers send the Origin header as ``scheme://host[:port]``
    with no trailing slash, so ``https://app.example/`` would otherwise
    never match.

    Falls back to a fresh copy of the localhost:7860 defaults when:
      * the variable is unset or blank,
      * it contains no usable entry (e.g. ``",,"``), or
      * it contains the wildcard ``*`` (rejected with a warning).

    Production usage:
        export STORYFORGE_ALLOWED_ORIGINS="https://app.storyforge.io,https://cdn.storyforge.io"

    Returns:
        A new list of origin strings. Callers may mutate it without
        affecting ``_DEFAULT_ORIGINS``.
    """
    raw = os.environ.get("STORYFORGE_ALLOWED_ORIGINS", "")
    if not raw.strip():
        return list(_DEFAULT_ORIGINS)

    origins = [
        origin
        for entry in raw.split(",")
        if (origin := entry.strip().rstrip("/"))
    ]

    if "*" in origins:
        logger.warning(
            "STORYFORGE_ALLOWED_ORIGINS contains '*' — ignoring and using "
            "safe defaults instead. Set explicit origins for production."
        )
        return list(_DEFAULT_ORIGINS)
    if not origins:
        logger.warning(
            "STORYFORGE_ALLOWED_ORIGINS has no usable entries — using safe defaults."
        )
        return list(_DEFAULT_ORIGINS)
    return origins
def _preflight_check() -> bool:
    """Probe the backing services once, before the server starts.

    Returns:
        ``True`` when startup may continue, ``False`` when a required
        service could not be reached.

    The database is mandatory: any status other than ``"ok"`` or
    ``"not_configured"`` fails the check. Redis fails the check only when
    ``STORYFORGE_REDIS_REQUIRED`` is ``"1"`` or ``"true"`` (case-insensitive).
    Otherwise an unreachable Redis is logged as a warning and ignored.
    """
    from api.health_routes import _check_database, _check_redis

    ok = True

    # Database: required dependency.
    database = _check_database()
    state = database.get("status")
    if state not in ("ok", "not_configured"):
        logger.error(
            "Preflight: database UNREACHABLE (%s) — cannot start",
            database.get("detail", state),
        )
        ok = False
    elif state == "ok":
        logger.info("Preflight: database OK")
    else:
        logger.info("Preflight: database not configured — skipping")

    # Redis: optional cache unless the operator marks it as required.
    cache = _check_redis()
    cache_state = cache.get("status")
    strict = os.environ.get("STORYFORGE_REDIS_REQUIRED", "").lower() in ("1", "true")
    if cache_state == "ok":
        logger.info("Preflight: Redis OK")
    elif cache_state == "not_configured":
        logger.info("Preflight: Redis not configured — running without cache")
    else:
        reason = cache.get("detail", cache_state)
        if strict:
            logger.error(
                "Preflight: Redis UNREACHABLE (%s) and STORYFORGE_REDIS_REQUIRED=1 — cannot start",
                reason,
            )
            ok = False
        else:
            logger.warning(
                "Preflight: Redis unavailable (%s) — continuing without cache (set "
                "STORYFORGE_REDIS_REQUIRED=1 to make this fatal)",
                reason,
            )

    return ok
def main():
    """Launch StoryForge — Alpine.js Web UI at /.

    Builds the FastAPI application: API routers, the middleware stack,
    exception handlers, static assets and the ``/api/health`` probe. It then
    runs ``_preflight_check()`` and, if that passes, serves the app with
    uvicorn on 0.0.0.0:7860. Calls ``sys.exit(1)`` when the preflight fails.

    Middleware ordering: Starlette wraps each middleware added with
    ``add_middleware`` AROUND the ones added before it. The LAST call is
    therefore the OUTERMOST layer: first to see a request, last to see the
    response. The calls below go from innermost to outermost.
    """
    from fastapi import Request
    from fastapi.responses import JSONResponse
    from starlette.middleware.base import BaseHTTPMiddleware

    from api import api_router

    main_app = FastAPI(
        title="StoryForge",
        description=(
            "AI-powered story generation platform. "
            "Generate long-form Vietnamese stories with multi-layer pipeline: "
            "story generation, drama simulation, and video storyboarding."
        ),
        version="3.0.0",
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_tags=[
            {"name": "pipeline", "description": "Run and manage story generation pipelines"},
            {"name": "config", "description": "Manage application configuration and model presets"},
            {"name": "export", "description": "Export stories to PDF, EPUB, and other formats"},
            {"name": "analytics", "description": "Usage analytics and story statistics"},
            {"name": "metrics", "description": "System performance metrics"},
            {"name": "dashboard", "description": "Dashboard summary data"},
            {"name": "auth", "description": "Authentication and user management"},
            {"name": "ab", "description": "A/B testing for pipeline variants"},
            {"name": "branch", "description": "Story branching and alternate paths"},
            {"name": "audio", "description": "Text-to-speech and audio generation"},
        ],
    )

    # --- CSRF protection middleware (double-submit cookie) ---
    from middleware.csrf import CSRFMiddleware
    main_app.add_middleware(CSRFMiddleware)

    # --- Security headers middleware (CSP, X-Frame-Options, etc.) ---
    from middleware.security_headers import SecurityHeadersMiddleware
    main_app.add_middleware(SecurityHeadersMiddleware)

    # --- Input sanitization middleware (prompt injection detection) ---
    from middleware.sanitization import SanitizationMiddleware
    main_app.add_middleware(SanitizationMiddleware)

    # --- Rate limiting middleware (Redis or in-memory, per-IP) ---
    from middleware.rate_limiter import RateLimitMiddleware
    main_app.add_middleware(RateLimitMiddleware)

    # --- Audit logging middleware ---
    from middleware.audit_middleware import AuditMiddleware
    main_app.add_middleware(AuditMiddleware)

    # --- Request metrics middleware ---
    from middleware.metrics_middleware import MetricsMiddleware
    main_app.add_middleware(MetricsMiddleware)

    # --- Exception handlers ---
    from errors.exceptions import StoryForgeError
    from errors.handlers import storyforge_error_handler
    main_app.add_exception_handler(StoryForgeError, storyforge_error_handler)

    # Global exception handler: log full traceback, return generic 500.
    # Starlette selects a handler by walking the raised exception's class
    # hierarchy (MRO), so more specific handlers keep precedence.
    from api import register_exception_handlers
    register_exception_handlers(main_app)

    from services.security.input_sanitizer import InjectionBlockedError

    @main_app.exception_handler(InjectionBlockedError)
    async def injection_blocked_handler(request, exc):
        return JSONResponse(status_code=422, content={"detail": str(exc)})

    # Graceful shutdown: cancel and await active pipeline tasks
    @main_app.on_event("shutdown")
    async def on_shutdown():
        from api.pipeline_routes import shutdown_pipeline_tasks
        await shutdown_pipeline_tasks(timeout=30)

    # API routes
    main_app.include_router(api_router)

    # --- API v1 versioned routes (mirrors /api/ with version header) ---
    from api.v1 import v1_router, DeprecationMiddleware
    main_app.include_router(v1_router)
    main_app.add_middleware(DeprecationMiddleware)

    # --- Body size limit (wraps every application-level middleware above) ---
    MAX_BODY_SIZE = 10 * 1024 * 1024  # 10 MB

    class BodySizeLimitMiddleware(BaseHTTPMiddleware):
        """Reject requests whose declared Content-Length is invalid or too large.

        Only the Content-Length header is inspected. A body sent without
        that header (e.g. chunked transfer encoding) is not limited here.
        """

        async def dispatch(self, request: Request, call_next):
            content_length = request.headers.get("content-length")
            if content_length:
                try:
                    declared_size = int(content_length)
                except ValueError:
                    # A non-numeric header used to raise here and surface as a 500.
                    return JSONResponse(
                        status_code=400,
                        content={"detail": "Invalid Content-Length header."},
                    )
                if declared_size > MAX_BODY_SIZE:
                    return JSONResponse(
                        status_code=413,
                        content={"detail": "Request body too large. Maximum size is 10MB."},
                    )
            return await call_next(request)

    main_app.add_middleware(BodySizeLimitMiddleware)

    # --- CORS middleware (restrictive: explicit origins only, no wildcard) ---
    # Added after the layers above so that it wraps them. Preflight OPTIONS
    # requests are answered before CSRF, rate limiting or body-size checks
    # run. Error responses produced by those layers (e.g. the 413 above)
    # still carry CORS headers, so browsers can read them.
    allowed_origins = _get_allowed_origins()
    logger.info("CORS allowed origins: %s", allowed_origins)
    main_app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        allow_headers=["Authorization", "Content-Type", "Accept", "X-CSRF-Token"],
    )

    # --- Request trace ID middleware ---
    # Added LAST so it is the outermost layer. Every other layer (audit,
    # metrics, rate limiting, ...) runs after it has processed the request.
    from middleware.trace_id import TraceIDMiddleware
    main_app.add_middleware(TraceIDMiddleware)

    # Static files
    base_dir = os.path.dirname(os.path.abspath(__file__))
    web_dir = os.path.join(base_dir, "web")
    locales_dir = os.path.join(base_dir, "locales")

    # Mount locales FIRST (more specific path takes precedence)
    if os.path.isdir(locales_dir):
        main_app.mount("/static/locales", StaticFiles(directory=locales_dir), name="locales")
    # Then mount web/ for remaining static files
    main_app.mount("/static", StaticFiles(directory=web_dir), name="static")

    # Serve index.html at root
    @main_app.get("/")
    async def serve_index():
        return FileResponse(os.path.join(web_dir, "index.html"))

    @main_app.get("/favicon.svg")
    async def serve_favicon():
        return FileResponse(os.path.join(web_dir, "favicon.svg"), media_type="image/svg+xml")

    # Health check — lightweight with cached DB/Redis probes (30s TTL)
    _health_cache: dict = {}
    _HEALTH_CACHE_TTL = 30  # seconds

    def _cached_check(name: str, check_fn) -> dict:
        """Return ``check_fn()``'s result, re-probing at most once per TTL."""
        cached = _health_cache.get(name)
        # Monotonic clock: the TTL is unaffected by wall-clock adjustments.
        now = time.monotonic()
        if cached and now - cached["ts"] < _HEALTH_CACHE_TTL:
            return cached["result"]
        result = check_fn()
        _health_cache[name] = {"result": result, "ts": now}
        return result

    # Plain ``def`` (not ``async def``): FastAPI runs sync endpoints in its
    # threadpool. The probes are blocking calls (_preflight_check calls them
    # synchronously too), so this keeps them and ConfigManager() off the
    # event loop. Two concurrent requests may occasionally both refresh the
    # cache; that is harmless.
    @main_app.get("/api/health")
    def health():
        from api.health_routes import _check_database, _check_redis

        cfg = ConfigManager()
        llm_ok = bool(cfg.llm.api_key)

        db_status = _cached_check("database", _check_database)
        redis_status = _cached_check("redis", _check_redis)

        db_ok = db_status.get("status") == "ok"
        redis_str = redis_status.get("status", "unknown")
        # Redis is optional (Phase 3) — report "fallback" not "error" when not required
        if redis_str == "error" and os.environ.get(
            "STORYFORGE_REDIS_REQUIRED", ""
        ).lower() not in ("1", "true"):
            redis_str = "fallback"

        degraded = not db_ok and db_status.get("status") != "not_configured"
        status = "degraded" if degraded else "ok"
        return JSONResponse(
            status_code=503 if degraded else 200,
            content={
                "status": status,
                "version": "3.0",
                "uptime_seconds": round(time.time() - _START_TIME, 1),
                "services": {
                    "llm": llm_ok,
                    "database": db_status.get("status", "unknown"),
                    "redis": redis_str,
                },
            },
        )

    _secret = os.environ.get("STORYFORGE_SECRET_KEY", "")
    if _secret in ("", "change-me-in-production"):
        logger.warning(
            "STORYFORGE_SECRET_KEY is not set or still default — "
            "secrets at rest will NOT be encrypted. "
            "Set a strong key for production use."
        )

    if not _preflight_check():
        logger.error("Preflight checks failed — aborting startup")
        sys.exit(1)

    logger.info("StoryForge starting — Web UI at http://localhost:7860")
    uvicorn.run(main_app, host="0.0.0.0", port=7860, log_level="info")
# Start the server only when this file is executed directly (e.g.
# ``python app.py``), not when it is imported.
if __name__ == "__main__":
    main()