diff --git a/workers/well-known-cache/src/index.ts b/workers/well-known-cache/src/index.ts index f2591472..395b99a6 100644 --- a/workers/well-known-cache/src/index.ts +++ b/workers/well-known-cache/src/index.ts @@ -5,12 +5,33 @@ interface Env { STELLAR_TOML_STALE_WHILE_REVALIDATE: string; DEFAULT_MAX_AGE: string; DEFAULT_STALE_WHILE_REVALIDATE: string; + ALLOWED_ORIGINS: string; } +function parseAllowedOrigins(raw: string): Set { + return new Set( + raw + .split(",") + .map((s) => s.trim()) + .filter(Boolean), + ); +} + +function getCorsHeaders( + requestOrigin: string | null, + allowedOrigins: Set, +): Record { + const headers: Record = { + "Access-Control-Allow-Methods": "GET, HEAD, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type", + }; -const CORS_HEADERS: Record = { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, HEAD, OPTIONS", - "Access-Control-Allow-Headers": "Content-Type", + if (requestOrigin && allowedOrigins.has(requestOrigin)) { + headers["Access-Control-Allow-Origin"] = requestOrigin; + headers["Vary"] = "Origin"; + } + + return headers; +} }; interface ErrorResponse { @@ -20,7 +41,7 @@ interface ErrorResponse { timestamp: string; } -function errorResponse(status: number, error: string, message: string): Response { +function errorResponse(status: number, error: string, message: string, corsHeaders: Record = {}): Response { const body: ErrorResponse = { status, error, @@ -31,7 +52,7 @@ function errorResponse(status: number, error: string, message: string): Response status, headers: { "Content-Type": "application/json", - ...CORS_HEADERS, + ...corsHeaders, }, }); } @@ -65,15 +86,20 @@ function logMetrics(metrics: RequestMetrics): void { export default { async fetch(request: Request, env: Env): Promise { + const allowedOrigins = parseAllowedOrigins(env.ALLOWED_ORIGINS ?? ""); + const requestOrigin = request.headers.get("Origin"); + const corsHeaders = getCorsHeaders(requestOrigin, allowedOrigins); + if (request.method === "OPTIONS") { - return new Response(null, { status: 204, headers: CORS_HEADERS }); + return new Response(null, { status: 204, headers: corsHeaders }); } if (!["GET", "HEAD"].includes(request.method)) { return errorResponse( 405, "Method Not Allowed", - `HTTP method ${request.method} is not supported. Use GET or HEAD.` + `HTTP method ${request.method} is not supported. Use GET or HEAD.`, + corsHeaders ); } @@ -81,7 +107,7 @@ export default { try { url = new URL(request.url); } catch { - return errorResponse(400, "Bad Request", "Invalid request URL."); + return errorResponse(400, "Bad Request", "Invalid request URL.", corsHeaders); } try { @@ -91,7 +117,7 @@ export default { if (cached) { const res = new Response(cached.body, cached); res.headers.set("cf-cache-status", "HIT"); - for (const [k, v] of Object.entries(CORS_HEADERS)) res.headers.set(k, v); + for (const [k, v] of Object.entries(corsHeaders)) res.headers.set(k, v); return res; } @@ -100,20 +126,21 @@ export default { return errorResponse( origin.status, origin.statusText || "Upstream Error", - `Origin server returned ${origin.status} for ${url.pathname}.` + `Origin server returned ${origin.status} for ${url.pathname}.`, + corsHeaders ); } const res = new Response(origin.body, origin); res.headers.set("Cache-Control", cacheControlFor(url.pathname)); res.headers.set("cf-cache-status", "MISS"); - for (const [k, v] of Object.entries(CORS_HEADERS)) res.headers.set(k, v); + for (const [k, v] of Object.entries(corsHeaders)) res.headers.set(k, v); await cache.put(request, res.clone()); return res; } catch (err) { const message = err instanceof Error ? err.message : "An unexpected error occurred."; - return errorResponse(502, "Bad Gateway", `Failed to fetch origin: ${message}`); + return errorResponse(502, "Bad Gateway", `Failed to fetch origin: ${message}`, corsHeaders); } }, } satisfies ExportedHandler; diff --git a/wrangler.toml b/wrangler.toml index c1f9c9c6..2bdcd620 100644 --- a/wrangler.toml +++ b/wrangler.toml @@ -14,3 +14,5 @@ STELLAR_TOML_STALE_WHILE_REVALIDATE = "86400" # Cache TTL for other .well-known paths (seconds) DEFAULT_MAX_AGE = "300" DEFAULT_STALE_WHILE_REVALIDATE = "3600" +# Comma-separated list of allowed CORS origins (exact match) +ALLOWED_ORIGINS = "https://yourdomain.com"