diff --git a/backend/README.md b/backend/README.md index 2cc093f..78916fa 100644 --- a/backend/README.md +++ b/backend/README.md @@ -6,10 +6,17 @@ High-performance Rust backend for Log-Based Alerting. - **Axum**: High-performance web framework. - **SQLx**: Async PostgreSQL driver with compile-time checked queries. - **Redis**: Caching and threshold tracking. +- **Upload Validation**: Safe file upload validation with size, name, and MIME checks. - **Tracing**: Observability and structured logging. +- **Error Handling**: Structured `AppError` responses for HTTP clients. ## API Endpoints +### Health & Observability +- `GET /health/live` - Liveness probe for process health. +- `GET /health/ready` - Readiness probe for PostgreSQL + Redis connectivity. +- `GET /metrics` - Prometheus metrics exposition endpoint. + ### Rules Management - `GET /api/alerts/rules` - List all alerting rules. - `POST /api/alerts/rules` - Create a new alerting rule. @@ -99,11 +106,15 @@ docker compose up -d --build # Check service health docker compose ps -# Test the health endpoint -curl http://localhost:8080/health +# Test the health endpoints +curl http://localhost:8080/health/live +curl http://localhost:8080/health/ready + +# Expected readiness response: +# {"status":"healthy","database":"healthy","cache":"healthy","version":"0.1.0"} -# Expected response: -# {"status":"ok","version":"0.1.0","database":"healthy","redis":"healthy"} +# Test the metrics endpoint +curl http://localhost:8080/metrics # Test the API status endpoint curl http://localhost:8080/api/v1/status diff --git a/backend/src/api/handlers/health.rs b/backend/src/api/handlers/health.rs new file mode 100644 index 0000000..834ac32 --- /dev/null +++ b/backend/src/api/handlers/health.rs @@ -0,0 +1,238 @@ +//! Health check endpoints. +//! +//! Provides two endpoints: +//! +//! - `GET /health/live` — liveness probe: returns 200 if the process is running. +//! 
- `GET /health/ready` — readiness probe: returns 200 only when PostgreSQL and +//! Redis are reachable; returns 503 otherwise. +//! +//! Both endpoints return a JSON body; the readiness body adds per-component status +//! so that operators can quickly identify which dependency is unhealthy. + +use axum::{extract::State, http::StatusCode, response::IntoResponse, Json}; +use redis::aio::ConnectionManager; +use redis::AsyncCommands; +use serde::Serialize; +use sqlx::PgPool; +use tracing::{debug, instrument, warn}; + +/// Minimal application state required by health check handlers. +#[derive(Clone)] +pub struct HealthState { + pub db: PgPool, + pub redis: ConnectionManager, +} + +// --------------------------------------------------------------------------- +// Response types +// --------------------------------------------------------------------------- + +/// Status of a single dependency. +#[derive(Debug, Serialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ComponentStatus { + Healthy, + Unhealthy, +} + +/// Response body for the readiness probe. +#[derive(Debug, Serialize)] +pub struct ReadinessResponse { + /// Overall status: `"healthy"` or `"degraded"`. + pub status: String, + /// PostgreSQL connectivity. + pub database: ComponentStatus, + /// Redis connectivity. + pub cache: ComponentStatus, + /// Application version from `CARGO_PKG_VERSION`. + pub version: String, +} + +/// Response body for the liveness probe. +#[derive(Debug, Serialize)] +pub struct LivenessResponse { + pub status: &'static str, + pub version: String, +} + +// --------------------------------------------------------------------------- +// Handlers +// --------------------------------------------------------------------------- + +/// `GET /health/live` — liveness probe. +/// +/// Always returns `200 OK` as long as the process is running. Kubernetes uses +/// this to decide whether to restart the container. 
+#[instrument(skip_all)] +pub async fn liveness() -> impl IntoResponse { + debug!("Liveness probe"); + ( + StatusCode::OK, + Json(LivenessResponse { + status: "ok", + version: env!("CARGO_PKG_VERSION").to_string(), + }), + ) +} + +/// `GET /health/ready` — readiness probe. +/// +/// Checks PostgreSQL and Redis connectivity. Returns `200 OK` when all +/// dependencies are healthy, or `503 Service Unavailable` when any are not. +/// Kubernetes uses this to decide whether to route traffic to the pod. +#[instrument(skip_all)] +pub async fn readiness(State(state): State) -> impl IntoResponse { + let db_status = check_database(&state).await; + let cache_status = check_cache(&state).await; + + let all_healthy = + db_status == ComponentStatus::Healthy && cache_status == ComponentStatus::Healthy; + + let status_code = if all_healthy { + StatusCode::OK + } else { + StatusCode::SERVICE_UNAVAILABLE + }; + + ( + status_code, + Json(ReadinessResponse { + status: if all_healthy { "healthy".into() } else { "degraded".into() }, + database: db_status, + cache: cache_status, + version: env!("CARGO_PKG_VERSION").to_string(), + }), + ) +} + +// --------------------------------------------------------------------------- +// Dependency checks +// --------------------------------------------------------------------------- + +async fn check_database(state: &HealthState) -> ComponentStatus { + match sqlx::query_scalar::<_, i32>("SELECT 1") + .fetch_one(&state.db) + .await + { + Ok(_) => { + debug!("Database health check passed"); + ComponentStatus::Healthy + } + Err(e) => { + warn!("Database health check failed: {e}"); + ComponentStatus::Unhealthy + } + } +} + +async fn check_cache(state: &HealthState) -> ComponentStatus { + let mut conn = state.redis.clone(); + match redis::cmd("PING").query_async::(&mut conn).await { + Ok(_) => { + debug!("Cache health check passed"); + ComponentStatus::Healthy + } + Err(e) => { + warn!("Cache health check failed: {e}"); + ComponentStatus::Unhealthy + } + 
} +} + +// --------------------------------------------------------------------------- +// Router helper +// --------------------------------------------------------------------------- + +/// Returns an Axum router with the health check routes mounted. +/// +/// Mount this under `/health` in the main application router: +/// +/// ```rust,no_run +/// use axum::Router; +/// use backend::api::handlers::health; +/// +/// let app: Router = Router::new() +/// .nest("/health", health::router()); +/// ``` +pub fn router() -> axum::Router { + use axum::routing::get; + axum::Router::new() + .route("/live", get(liveness)) + .route("/ready", get(readiness)) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use axum::{body::Body, http::Request}; + use tower::ServiceExt; + + /// Build a minimal router with only the liveness endpoint (no AppState needed). 
+ fn liveness_app() -> axum::Router { + use axum::routing::get; + axum::Router::new().route("/live", get(liveness)) + } + + #[tokio::test] + async fn liveness_returns_200() { + let app = liveness_app(); + let response = app + .oneshot(Request::builder().uri("/live").body(Body::empty()).unwrap()) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::OK); + } + + #[tokio::test] + async fn liveness_body_contains_ok() { + let app = liveness_app(); + let response = app + .oneshot(Request::builder().uri("/live").body(Body::empty()).unwrap()) + .await + .unwrap(); + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(json["status"], "ok"); + assert!(json["version"].is_string()); + } + + #[test] + fn readiness_response_serializes_healthy() { + let resp = ReadinessResponse { + status: "healthy".into(), + database: ComponentStatus::Healthy, + cache: ComponentStatus::Healthy, + version: "0.1.0".into(), + }; + let json = serde_json::to_value(&resp).unwrap(); + assert_eq!(json["status"], "healthy"); + assert_eq!(json["database"], "healthy"); + assert_eq!(json["cache"], "healthy"); + } + + #[test] + fn readiness_response_serializes_degraded() { + let resp = ReadinessResponse { + status: "degraded".into(), + database: ComponentStatus::Unhealthy, + cache: ComponentStatus::Healthy, + version: "0.1.0".into(), + }; + let json = serde_json::to_value(&resp).unwrap(); + assert_eq!(json["status"], "degraded"); + assert_eq!(json["database"], "unhealthy"); + assert_eq!(json["cache"], "healthy"); + } + + #[test] + fn component_status_eq() { + assert_eq!(ComponentStatus::Healthy, ComponentStatus::Healthy); + assert_ne!(ComponentStatus::Healthy, ComponentStatus::Unhealthy); + } +} diff --git a/backend/src/api/handlers/mod.rs b/backend/src/api/handlers/mod.rs index 6b82151..4a2fb94 100644 --- a/backend/src/api/handlers/mod.rs +++ b/backend/src/api/handlers/mod.rs 
@@ -1,3 +1,4 @@ pub mod dashboard; +pub mod health; pub mod profiling; pub mod stellar; diff --git a/backend/src/services/metrics.rs b/backend/src/services/metrics.rs new file mode 100644 index 0000000..3e4f860 --- /dev/null +++ b/backend/src/services/metrics.rs @@ -0,0 +1,491 @@ +//! Prometheus metrics collection service. +//! +//! Provides a lightweight, zero-dependency metrics registry that tracks HTTP +//! request counts, latencies, database pool stats, and cache hit/miss rates. +//! Metrics are exposed in the Prometheus text exposition format via +//! `GET /metrics`. +//! +//! # Example +//! +//! ```rust,no_run +//! use backend::services::metrics::MetricsRegistry; +//! +//! let registry = MetricsRegistry::new(); +//! registry.record_request("GET", "/health", 200, 1.5); +//! let output = registry.render(); +//! assert!(output.contains("http_requests_total")); +//! ``` + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Instant; + +use axum::{extract::State, http::StatusCode, response::IntoResponse}; +use tracing::{debug, instrument}; + +// --------------------------------------------------------------------------- +// Counter +// --------------------------------------------------------------------------- + +/// A thread-safe monotonically increasing counter with optional labels. +#[derive(Debug, Default)] +pub struct Counter { + inner: Mutex>, +} + +impl Counter { + pub fn new() -> Self { + Self::default() + } + + /// Increment the counter for the given label set. + pub fn inc(&self, labels: &str) { + let mut map = self.inner.lock().unwrap(); + *map.entry(labels.to_string()).or_insert(0) += 1; + } + + /// Increment by `n`. + pub fn inc_by(&self, labels: &str, n: u64) { + let mut map = self.inner.lock().unwrap(); + *map.entry(labels.to_string()).or_insert(0) += n; + } + + /// Snapshot all label→value pairs. 
+ pub fn snapshot(&self) -> HashMap { + self.inner.lock().unwrap().clone() + } +} + +// --------------------------------------------------------------------------- +// Gauge +// --------------------------------------------------------------------------- + +/// A thread-safe gauge (can go up or down). +#[derive(Debug, Default)] +pub struct Gauge { + inner: Mutex>, +} + +impl Gauge { + pub fn new() -> Self { + Self::default() + } + + pub fn set(&self, labels: &str, value: f64) { + let mut map = self.inner.lock().unwrap(); + map.insert(labels.to_string(), value); + } + + pub fn snapshot(&self) -> HashMap { + self.inner.lock().unwrap().clone() + } +} + +// --------------------------------------------------------------------------- +// Histogram (fixed buckets) +// --------------------------------------------------------------------------- + +/// Observation buckets for latency histograms (milliseconds). +pub const LATENCY_BUCKETS_MS: &[f64] = &[5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0, 2500.0, f64::INFINITY]; + +/// A simple histogram with fixed upper-bound buckets. +#[derive(Debug)] +pub struct Histogram { + buckets: Vec, + inner: Mutex>, +} + +#[derive(Debug, Clone, Default)] +struct HistogramData { + counts: Vec, + sum: f64, + total: u64, +} + +impl Histogram { + pub fn new(buckets: Vec) -> Self { + Self { buckets, inner: Mutex::new(HashMap::new()) } + } + + /// Record an observation (value in milliseconds). 
+ pub fn observe(&self, labels: &str, value_ms: f64) { + let mut map = self.inner.lock().unwrap(); + let data = map.entry(labels.to_string()).or_insert_with(|| HistogramData { + counts: vec![0; self.buckets.len()], + sum: 0.0, + total: 0, + }); + for (i, &bound) in self.buckets.iter().enumerate() { + if value_ms <= bound { + data.counts[i] += 1; + } + } + data.sum += value_ms; + data.total += 1; + } + + pub fn snapshot(&self) -> HashMap, f64, u64)> { + let map = self.inner.lock().unwrap(); + map.iter() + .map(|(k, v)| { + let buckets = self.buckets.iter().copied().zip(v.counts.iter().copied()).collect(); + (k.clone(), (buckets, v.sum, v.total)) + }) + .collect() + } +} + +// --------------------------------------------------------------------------- +// Registry +// --------------------------------------------------------------------------- + +/// Central metrics registry for the Crucible backend. +#[derive(Debug)] +pub struct MetricsRegistry { + /// Total HTTP requests, labelled by method, path, and status code. + pub http_requests_total: Counter, + /// HTTP request duration histogram (milliseconds). + pub http_request_duration_ms: Histogram, + /// Total errors, labelled by error kind. + pub errors_total: Counter, + /// Database connection pool size (active connections). + pub db_pool_connections: Gauge, + /// Cache hit counter. + pub cache_hits_total: Counter, + /// Cache miss counter. + pub cache_misses_total: Counter, + /// Total file uploads processed. + pub file_uploads_total: Counter, + /// Total bytes uploaded. + pub file_upload_bytes_total: Counter, + /// Application start timestamp (Unix seconds). 
+ pub process_start_time_seconds: AtomicU64, +} + +impl MetricsRegistry { + pub fn new() -> Self { + use std::time::{SystemTime, UNIX_EPOCH}; + let start = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + Self { + http_requests_total: Counter::new(), + http_request_duration_ms: Histogram::new(LATENCY_BUCKETS_MS.to_vec()), + errors_total: Counter::new(), + db_pool_connections: Gauge::new(), + cache_hits_total: Counter::new(), + cache_misses_total: Counter::new(), + file_uploads_total: Counter::new(), + file_upload_bytes_total: Counter::new(), + process_start_time_seconds: AtomicU64::new(start), + } + } + + /// Record an HTTP request completion. + pub fn record_request(&self, method: &str, path: &str, status: u16, duration_ms: f64) { + let labels = format!(r#"method="{method}",path="{path}",status="{status}""#); + self.http_requests_total.inc(&labels); + let hist_labels = format!(r#"method="{method}",path="{path}""#); + self.http_request_duration_ms.observe(&hist_labels, duration_ms); + } + + /// Record an error by kind. + pub fn record_error(&self, kind: &str) { + self.errors_total.inc(&format!(r#"kind="{kind}""#)); + } + + /// Record a cache hit. + pub fn record_cache_hit(&self, cache: &str) { + self.cache_hits_total.inc(&format!(r#"cache="{cache}""#)); + } + + /// Record a cache miss. + pub fn record_cache_miss(&self, cache: &str) { + self.cache_misses_total.inc(&format!(r#"cache="{cache}""#)); + } + + /// Record a file upload. + pub fn record_file_upload(&self, mime: &str, bytes: u64) { + let label = format!(r#"mime="{mime}""#); + self.file_uploads_total.inc(&label); + self.file_upload_bytes_total.inc_by(&label, bytes); + } + + /// Record the current active database pool connection count. + pub fn record_db_pool_connections(&self, active_connections: u64) { + self.db_pool_connections.set(r#"pool="active""#, active_connections as f64); + } + + /// Render all metrics in Prometheus text exposition format. 
+ #[instrument(skip(self))] + pub fn render(&self) -> String { + let mut out = String::with_capacity(4096); + + // process_start_time_seconds + let start = self.process_start_time_seconds.load(Ordering::Relaxed); + out.push_str("# HELP process_start_time_seconds Unix timestamp of process start.\n"); + out.push_str("# TYPE process_start_time_seconds gauge\n"); + out.push_str(&format!("process_start_time_seconds {start}\n\n")); + + // http_requests_total + out.push_str("# HELP http_requests_total Total HTTP requests by method, path, and status.\n"); + out.push_str("# TYPE http_requests_total counter\n"); + for (labels, count) in self.http_requests_total.snapshot() { + out.push_str(&format!("http_requests_total{{{labels}}} {count}\n")); + } + out.push('\n'); + + // http_request_duration_ms + out.push_str("# HELP http_request_duration_ms HTTP request duration in milliseconds.\n"); + out.push_str("# TYPE http_request_duration_ms histogram\n"); + for (labels, (buckets, sum, count)) in self.http_request_duration_ms.snapshot() { + for (bound, bucket_count) in &buckets { + let le = if bound.is_infinite() { "+Inf".to_string() } else { bound.to_string() }; + out.push_str(&format!( + "http_request_duration_ms_bucket{{{labels},le=\"{le}\"}} {bucket_count}\n" + )); + } + out.push_str(&format!("http_request_duration_ms_sum{{{labels}}} {sum}\n")); + out.push_str(&format!("http_request_duration_ms_count{{{labels}}} {count}\n")); + } + out.push('\n'); + + // errors_total + out.push_str("# HELP errors_total Total application errors by kind.\n"); + out.push_str("# TYPE errors_total counter\n"); + for (labels, count) in self.errors_total.snapshot() { + out.push_str(&format!("errors_total{{{labels}}} {count}\n")); + } + out.push('\n'); + + // db_pool_connections + out.push_str("# HELP db_pool_connections Active database pool connections.\n"); + out.push_str("# TYPE db_pool_connections gauge\n"); + for (labels, value) in self.db_pool_connections.snapshot() { + 
out.push_str(&format!("db_pool_connections{{{labels}}} {value}\n")); + } + out.push('\n'); + + // cache_hits_total / cache_misses_total + out.push_str("# HELP cache_hits_total Total cache hits.\n"); + out.push_str("# TYPE cache_hits_total counter\n"); + for (labels, count) in self.cache_hits_total.snapshot() { + out.push_str(&format!("cache_hits_total{{{labels}}} {count}\n")); + } + out.push('\n'); + + out.push_str("# HELP cache_misses_total Total cache misses.\n"); + out.push_str("# TYPE cache_misses_total counter\n"); + for (labels, count) in self.cache_misses_total.snapshot() { + out.push_str(&format!("cache_misses_total{{{labels}}} {count}\n")); + } + out.push('\n'); + + // file_uploads_total / file_upload_bytes_total + out.push_str("# HELP file_uploads_total Total file uploads by MIME type.\n"); + out.push_str("# TYPE file_uploads_total counter\n"); + for (labels, count) in self.file_uploads_total.snapshot() { + out.push_str(&format!("file_uploads_total{{{labels}}} {count}\n")); + } + out.push('\n'); + + out.push_str("# HELP file_upload_bytes_total Total bytes uploaded by MIME type.\n"); + out.push_str("# TYPE file_upload_bytes_total counter\n"); + for (labels, count) in self.file_upload_bytes_total.snapshot() { + out.push_str(&format!("file_upload_bytes_total{{{labels}}} {count}\n")); + } + + debug!("Rendered {} bytes of metrics", out.len()); + out + } +} + +impl Default for MetricsRegistry { + fn default() -> Self { + Self::new() + } +} + +/// Shared, cheaply-cloneable handle to the metrics registry. +pub type SharedMetrics = Arc; + +/// Axum handler: `GET /metrics` — returns Prometheus text format. +#[instrument(skip(metrics))] +pub async fn metrics_handler( + State(metrics): State, +) -> impl IntoResponse { + let body = metrics.render(); + ( + StatusCode::OK, + [("content-type", "text/plain; version=0.0.4; charset=utf-8")], + body, + ) +} + +/// Timing guard: records request duration on drop. 
+pub struct RequestTimer<'a> { + registry: &'a MetricsRegistry, + method: String, + path: String, + status: u16, + start: Instant, +} + +impl<'a> RequestTimer<'a> { + pub fn new(registry: &'a MetricsRegistry, method: &str, path: &str, status: u16) -> Self { + Self { + registry, + method: method.to_string(), + path: path.to_string(), + status, + start: Instant::now(), + } + } +} + +impl Drop for RequestTimer<'_> { + fn drop(&mut self) { + let elapsed_ms = self.start.elapsed().as_secs_f64() * 1000.0; + self.registry.record_request(&self.method, &self.path, self.status, elapsed_ms); + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use axum::{body::Body, http::Request}; + + #[test] + fn counter_increments() { + let c = Counter::new(); + c.inc("a"); + c.inc("a"); + c.inc("b"); + let snap = c.snapshot(); + assert_eq!(snap["a"], 2); + assert_eq!(snap["b"], 1); + } + + #[test] + fn counter_inc_by() { + let c = Counter::new(); + c.inc_by("x", 10); + assert_eq!(c.snapshot()["x"], 10); + } + + #[test] + fn gauge_set() { + let g = Gauge::new(); + g.set("pool", 5.0); + g.set("pool", 7.0); + assert_eq!(g.snapshot()["pool"], 7.0); + } + + #[test] + fn histogram_observe() { + let h = Histogram::new(LATENCY_BUCKETS_MS.to_vec()); + h.observe("route", 20.0); + h.observe("route", 80.0); + let snap = h.snapshot(); + let (buckets, sum, count) = &snap["route"]; + assert_eq!(*count, 2); + assert!((sum - 100.0).abs() < f64::EPSILON); + // 20ms falls in ≤25ms bucket + let bucket_25 = buckets.iter().find(|(b, _)| *b == 25.0).unwrap(); + assert_eq!(bucket_25.1, 1); + } + + #[test] + fn registry_record_request() { + let r = MetricsRegistry::new(); + r.record_request("GET", "/health", 200, 5.0); + let snap = r.http_requests_total.snapshot(); + assert_eq!(snap[r#"method="GET",path="/health",status="200""#], 1); + } + + 
#[test] + fn registry_record_error() { + let r = MetricsRegistry::new(); + r.record_error("database"); + r.record_error("database"); + let snap = r.errors_total.snapshot(); + assert_eq!(snap[r#"kind="database""#], 2); + } + + #[test] + fn registry_record_cache() { + let r = MetricsRegistry::new(); + r.record_cache_hit("redis"); + r.record_cache_miss("redis"); + assert_eq!(r.cache_hits_total.snapshot()[r#"cache="redis""#], 1); + assert_eq!(r.cache_misses_total.snapshot()[r#"cache="redis""#], 1); + } + + #[test] + fn registry_record_file_upload() { + let r = MetricsRegistry::new(); + r.record_file_upload("application/wasm", 1024); + let snap = r.file_upload_bytes_total.snapshot(); + assert_eq!(snap[r#"mime="application/wasm""#], 1024); + } + + #[test] + fn render_contains_expected_metric_names() { + let r = MetricsRegistry::new(); + r.record_request("POST", "/upload", 201, 42.0); + r.record_error("redis"); + let output = r.render(); + assert!(output.contains("http_requests_total")); + assert!(output.contains("http_request_duration_ms_bucket")); + assert!(output.contains("errors_total")); + assert!(output.contains("process_start_time_seconds")); + } + + #[test] + fn render_prometheus_text_format() { + let r = MetricsRegistry::new(); + r.record_request("GET", "/api/v1/status", 200, 10.0); + let output = r.render(); + // Must contain HELP and TYPE lines + assert!(output.contains("# HELP http_requests_total")); + assert!(output.contains("# TYPE http_requests_total counter")); + // Must contain the labelled counter line + assert!(output.contains(r#"method="GET""#)); + } + + #[tokio::test] + async fn metrics_handler_returns_prometheus_text() { + let metrics = Arc::new(MetricsRegistry::new()); + let app = axum::Router::new() + .route("/metrics", axum::routing::get(metrics_handler)) + .with_state(metrics.clone()); + + let response = app + .oneshot(Request::builder().uri("/metrics").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(response.status(), 
StatusCode::OK); + assert_eq!(response.headers()["content-type"], "text/plain; version=0.0.4; charset=utf-8"); + let body = axum::body::to_bytes(response.into_body()).await.unwrap(); + let body = std::str::from_utf8(&body).unwrap(); + assert!(body.contains("process_start_time_seconds")); + } + + #[test] + fn request_timer_records_on_drop() { + let r = MetricsRegistry::new(); + { + let _t = RequestTimer::new(&r, "GET", "/test", 200); + // timer drops here + } + let snap = r.http_requests_total.snapshot(); + assert_eq!(snap.values().sum::(), 1); + } +} diff --git a/backend/src/services/mod.rs b/backend/src/services/mod.rs index c9ffa9b..24483bc 100644 --- a/backend/src/services/mod.rs +++ b/backend/src/services/mod.rs @@ -3,7 +3,7 @@ pub mod alerts; pub mod error_recovery; pub mod feature_flags; pub mod log_aggregator; -pub mod log_alerts; +pub mod metrics; pub mod sys_metrics; pub mod business_metrics; pub mod tracing; diff --git a/backend/src/utils/errors.rs b/backend/src/utils/errors.rs new file mode 100644 index 0000000..1ef0811 --- /dev/null +++ b/backend/src/utils/errors.rs @@ -0,0 +1,369 @@ +//! Custom error type hierarchy for the Crucible backend. +//! +//! Provides domain-specific error types that compose into [`AppError`] for +//! HTTP responses, while preserving rich context for logging and tracing. + +use axum::{http::StatusCode, response::{IntoResponse, Response}, Json}; +use serde::Serialize; +use serde_json::json; +use thiserror::Error; +use tracing::error; + +/// Result type alias for backend services and handlers. +pub type Result = std::result::Result; + +// --------------------------------------------------------------------------- +// Domain error types +// --------------------------------------------------------------------------- + +/// Errors arising from file upload and validation operations. 
+#[derive(Debug, Error)] +pub enum FileError { + #[error("File too large: {size} bytes exceeds limit of {limit} bytes")] + TooLarge { size: u64, limit: u64 }, + + #[error("Unsupported MIME type: {0}")] + UnsupportedMimeType(String), + + #[error("Invalid file name: {0}")] + InvalidFileName(String), + + #[error("Malformed file content: {0}")] + MalformedContent(String), +} + +/// Errors arising from database operations. +#[derive(Debug, Error)] +pub enum DatabaseError { + #[error("Record not found: {0}")] + NotFound(String), + + #[error("Unique constraint violation: {0}")] + UniqueViolation(String), + + #[error("Foreign key violation: {0}")] + ForeignKeyViolation(String), + + #[error("Connection error: {0}")] + Connection(String), + + #[error("Query error: {0}")] + Query(#[from] sqlx::Error), +} + +/// Errors arising from cache / Redis operations. +#[derive(Debug, Error)] +pub enum CacheError { + #[error("Cache miss for key: {0}")] + Miss(String), + + #[error("Serialization error: {0}")] + Serialization(String), + + #[error("Redis error: {0}")] + Redis(#[from] redis::RedisError), +} + +/// Errors arising from authentication and authorisation. +#[derive(Debug, Error)] +pub enum AuthError { + #[error("Missing credentials")] + MissingCredentials, + + #[error("Invalid token")] + InvalidToken, + + #[error("Token expired")] + TokenExpired, + + #[error("Insufficient permissions: required {required}, got {actual}")] + InsufficientPermissions { required: String, actual: String }, +} + +/// Errors arising from external service calls (e.g. Stellar network). 
+#[derive(Debug, Error)] +pub enum ExternalServiceError { + #[error("Request timeout after {timeout_ms}ms")] + Timeout { timeout_ms: u64 }, + + #[error("Service unavailable: {0}")] + Unavailable(String), + + #[error("Unexpected response: {0}")] + UnexpectedResponse(String), +} + +// --------------------------------------------------------------------------- +// Top-level application error +// --------------------------------------------------------------------------- + +/// Structured error response body returned to API clients. +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + pub code: String, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option, +} + +/// Unified application error that maps all domain errors to HTTP responses. +#[derive(Debug, Error)] +pub enum AppError { + // --- 400 Bad Request --- + #[error("Bad request: {0}")] + BadRequest(String), + + // --- 422 Validation --- + #[error("Validation error: {0}")] + Validation(String), + + // --- 401 Unauthorized --- + #[error("Unauthorized: {0}")] + Unauthorized(String), + + // --- 403 Forbidden --- + #[error("Forbidden: {0}")] + Forbidden(String), + + // --- 404 Not Found --- + #[error("Not found: {0}")] + NotFound(String), + + // --- 409 Conflict --- + #[error("Conflict: {0}")] + Conflict(String), + + // --- 413 Payload Too Large --- + #[error("Payload too large: {0}")] + PayloadTooLarge(String), + + // --- 415 Unsupported Media Type --- + #[error("Unsupported media type: {0}")] + UnsupportedMediaType(String), + + // --- 500 Database --- + #[error("Database error: {0}")] + Database(#[from] sqlx::Error), + + // --- 500 Redis --- + #[error("Redis error: {0}")] + Redis(#[from] redis::RedisError), + + // --- 500 Internal --- + #[error("Internal error: {0}")] + Internal(String), + + // --- Domain errors (mapped to appropriate HTTP codes) --- + #[error(transparent)] + File(#[from] FileError), + + #[error(transparent)] + Auth(#[from] AuthError), + + 
#[error(transparent)] + Cache(#[from] CacheError), + + #[error(transparent)] + ExternalService(#[from] ExternalServiceError), +} + +impl AppError { + fn status_and_code(&self) -> (StatusCode, &'static str) { + match self { + AppError::BadRequest(_) => (StatusCode::BAD_REQUEST, "bad_request"), + AppError::Validation(_) => (StatusCode::UNPROCESSABLE_ENTITY, "validation_error"), + AppError::Unauthorized(_) => (StatusCode::UNAUTHORIZED, "unauthorized"), + AppError::Forbidden(_) => (StatusCode::FORBIDDEN, "forbidden"), + AppError::NotFound(_) => (StatusCode::NOT_FOUND, "not_found"), + AppError::Conflict(_) => (StatusCode::CONFLICT, "conflict"), + AppError::PayloadTooLarge(_) => (StatusCode::PAYLOAD_TOO_LARGE, "payload_too_large"), + AppError::UnsupportedMediaType(_) => { + (StatusCode::UNSUPPORTED_MEDIA_TYPE, "unsupported_media_type") + } + AppError::Database(e) => { + error!("Database error: {e:?}"); + (StatusCode::INTERNAL_SERVER_ERROR, "database_error") + } + AppError::Redis(e) => { + error!("Redis error: {e:?}"); + (StatusCode::INTERNAL_SERVER_ERROR, "redis_error") + } + AppError::Internal(msg) => { + error!("Internal error: {msg}"); + (StatusCode::INTERNAL_SERVER_ERROR, "internal_error") + } + AppError::File(e) => match e { + FileError::TooLarge { .. } => (StatusCode::PAYLOAD_TOO_LARGE, "file_too_large"), + FileError::UnsupportedMimeType(_) => { + (StatusCode::UNSUPPORTED_MEDIA_TYPE, "unsupported_mime_type") + } + FileError::InvalidFileName(_) | FileError::MalformedContent(_) => { + (StatusCode::BAD_REQUEST, "invalid_file") + } + }, + AppError::Auth(e) => match e { + AuthError::MissingCredentials | AuthError::InvalidToken | AuthError::TokenExpired => { + (StatusCode::UNAUTHORIZED, "auth_error") + } + AuthError::InsufficientPermissions { .. 
} => (StatusCode::FORBIDDEN, "forbidden"), + }, + AppError::Cache(e) => { + error!("Cache error: {e:?}"); + (StatusCode::INTERNAL_SERVER_ERROR, "cache_error") + } + AppError::ExternalService(e) => { + error!("External service error: {e:?}"); + (StatusCode::BAD_GATEWAY, "external_service_error") + } + } + } +} + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + let (status, code) = self.status_and_code(); + let message = self.to_string(); + ( + status, + Json(json!({ "code": code, "message": message })), + ) + .into_response() + } +} + +// --------------------------------------------------------------------------- +// Conversions from domain errors to AppError +// --------------------------------------------------------------------------- + +impl From for AppError { + fn from(e: DatabaseError) -> Self { + match e { + DatabaseError::NotFound(msg) => AppError::NotFound(msg), + DatabaseError::UniqueViolation(msg) => AppError::Conflict(msg), + DatabaseError::ForeignKeyViolation(msg) => AppError::BadRequest(msg), + DatabaseError::Connection(msg) => AppError::Internal(msg), + DatabaseError::Query(sqlx::Error::PoolTimedOut) => { + AppError::Internal("database pool timed out".into()) + } + DatabaseError::Query(e) => AppError::Database(e), + } + } +} + +impl From for AppError { + fn from(err: sqlx::Error) -> Self { + AppError::Database(err) + } +} + +impl From for AppError { + fn from(err: redis::RedisError) -> Self { + AppError::Redis(err) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn file_error_too_large_display() { + let e = FileError::TooLarge { size: 20_000_000, limit: 10_000_000 }; + assert!(e.to_string().contains("20000000")); + } + + #[test] + fn file_error_unsupported_mime() { + let e = FileError::UnsupportedMimeType("application/exe".into()); + 
assert!(e.to_string().contains("application/exe")); + } + + #[test] + fn auth_error_insufficient_permissions_display() { + let e = AuthError::InsufficientPermissions { + required: "admin".into(), + actual: "user".into(), + }; + assert!(e.to_string().contains("admin")); + } + + #[test] + fn app_error_not_found_status() { + let e = AppError::NotFound("contract".into()); + let (status, code) = e.status_and_code(); + assert_eq!(status, StatusCode::NOT_FOUND); + assert_eq!(code, "not_found"); + } + + #[test] + fn app_error_validation_status() { + let e = AppError::Validation("field required".into()); + let (status, code) = e.status_and_code(); + assert_eq!(status, StatusCode::UNPROCESSABLE_ENTITY); + assert_eq!(code, "validation_error"); + } + + #[test] + fn app_error_from_file_too_large() { + let e = AppError::File(FileError::TooLarge { size: 1, limit: 0 }); + let (status, code) = e.status_and_code(); + assert_eq!(status, StatusCode::PAYLOAD_TOO_LARGE); + assert_eq!(code, "file_too_large"); + } + + #[test] + fn app_error_from_auth_forbidden() { + let e = AppError::Auth(AuthError::InsufficientPermissions { + required: "admin".into(), + actual: "user".into(), + }); + let (status, code) = e.status_and_code(); + assert_eq!(status, StatusCode::FORBIDDEN); + assert_eq!(code, "forbidden"); + } + + #[test] + fn database_error_not_found_converts() { + let e: AppError = DatabaseError::NotFound("user 42".into()).into(); + assert!(matches!(e, AppError::NotFound(_))); + } + + #[test] + fn database_error_unique_violation_converts() { + let e: AppError = DatabaseError::UniqueViolation("email".into()).into(); + assert!(matches!(e, AppError::Conflict(_))); + } + + #[test] + fn sqlx_error_converts_to_app_error() { + let err = sqlx::Error::RowNotFound; + let e: AppError = err.into(); + assert!(matches!(e, AppError::Database(_))); + } + + #[tokio::test] + async fn app_error_into_response_renders_json() { + let e = AppError::NotFound("contract".into()); + let response = e.into_response(); 
/// Default maximum file size: 10 MiB.
pub const DEFAULT_MAX_SIZE: u64 = 10 * 1024 * 1024;

/// Allowed MIME types for contract-related uploads.
///
/// Entries are lowercase and carry no parameters, which is the form the
/// validator compares against.
pub const ALLOWED_MIME_TYPES: &[&str] = &[
    "application/wasm",
    "application/octet-stream",
    "application/json",
    "text/plain",
    "text/x-rust",
];

/// Configuration for file upload validation.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Maximum allowed file size in bytes.
    pub max_size: u64,
    /// Set of permitted MIME types (lowercase, without parameters).
    pub allowed_mime_types: HashSet<String>,
}

impl Default for ValidationConfig {
    /// Builds a config with [`DEFAULT_MAX_SIZE`] and [`ALLOWED_MIME_TYPES`].
    fn default() -> Self {
        Self {
            max_size: DEFAULT_MAX_SIZE,
            allowed_mime_types: ALLOWED_MIME_TYPES.iter().map(|&s| s.to_owned()).collect(),
        }
    }
}

impl ValidationConfig {
    /// Create a config with a custom size limit and the default MIME allow-list.
    pub fn with_max_size(max_size: u64) -> Self {
        Self { max_size, ..Default::default() }
    }

    /// Add an extra allowed MIME type.
    ///
    /// The value is normalised exactly as `validate_upload` normalises the
    /// declared type before its allow-list lookup: anything after the first
    /// `;` is dropped, surrounding whitespace is trimmed and the result is
    /// lowercased. Without this, an entry such as `"Image/PNG"` could never
    /// match. Entries that normalise to the empty string are ignored so that
    /// an upload with no declared type is not allowed by accident.
    pub fn allow_mime(mut self, mime: impl Into<String>) -> Self {
        let raw = mime.into();
        let normalized = raw.split(';').next().unwrap_or("").trim().to_lowercase();
        if !normalized.is_empty() {
            self.allowed_mime_types.insert(normalized);
        }
        self
    }
}
MIME type check (declared) + let normalized_mime = declared_mime.split(';').next().unwrap_or("").trim().to_lowercase(); + if !config.allowed_mime_types.contains(&normalized_mime) { + warn!(mime = %normalized_mime, "Unsupported MIME type"); + return Err(FileError::UnsupportedMimeType(normalized_mime)); + } + + // 4. Magic-byte verification + let detected = detect_mime(&bytes); + if let Some(detected_mime) = detected { + if detected_mime != normalized_mime + && !is_compatible_mime(&normalized_mime, detected_mime) + { + warn!( + declared = %normalized_mime, + detected = %detected_mime, + "MIME type mismatch between declared and detected" + ); + return Err(FileError::MalformedContent(format!( + "declared MIME '{normalized_mime}' does not match detected '{detected_mime}'" + ))); + } + } + + debug!(file_name = %safe_name, size, mime = %normalized_mime, "File validation passed"); + + Ok(ValidatedFile { + file_name: safe_name, + mime_type: normalized_mime, + size, + bytes, + }) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Sanitise a file name: reject path traversal, null bytes, and empty names. +fn sanitize_file_name(name: &str) -> Result { + if name.is_empty() { + return Err(FileError::InvalidFileName("file name is empty".into())); + } + if name.contains('\0') { + return Err(FileError::InvalidFileName("file name contains null byte".into())); + } + // Strip any directory component — keep only the final segment. + let base = name + .replace('\\', "/") + .split('/') + .filter(|s| !s.is_empty() && *s != "..") + .last() + .unwrap_or("") + .to_string(); + + if base.is_empty() || base == ".." { + return Err(FileError::InvalidFileName(format!("unsafe file name: {name}"))); + } + Ok(base) +} + +/// Detect MIME type from magic bytes. 
/// Detect MIME type from magic bytes.
///
/// Checks, in order:
/// 1. the WebAssembly magic `\0asm` → `application/wasm`;
/// 2. after an optional UTF-8 byte-order mark and leading ASCII whitespace, a
///    first byte of `{` or `[` in a payload that is valid UTF-8 →
///    `application/json`;
/// 3. content accepted by [`is_valid_text`] → `text/plain` (this includes the
///    empty payload);
/// 4. anything else → `application/octet-stream`, meaning "not recognised".
///
/// The JSON check is a first-byte heuristic, not a parse. The UTF-8
/// requirement only stops binary data that happens to start with a bracket
/// from being labelled JSON.
///
/// The current implementation always returns `Some`; the `Option` is kept for
/// interface compatibility.
fn detect_mime(bytes: &[u8]) -> Option<&'static str> {
    // WebAssembly magic: \0asm
    if bytes.starts_with(&[0x00, 0x61, 0x73, 0x6d]) {
        return Some("application/wasm");
    }

    // Ignore a UTF-8 byte-order mark when looking for the first JSON token.
    let data = if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) {
        &bytes[3..]
    } else {
        bytes
    };

    let first_significant = data.iter().copied().find(|b| !b.is_ascii_whitespace());
    if matches!(first_significant, Some(b'{') | Some(b'['))
        && std::str::from_utf8(data).is_ok()
    {
        return Some("application/json");
    }

    if is_valid_text(bytes) {
        return Some("text/plain");
    }

    Some("application/octet-stream")
}

/// Returns true when `bytes` is valid UTF-8 whose only control characters are
/// line feed, carriage return and tab.
fn is_valid_text(bytes: &[u8]) -> bool {
    match std::str::from_utf8(bytes) {
        Ok(text) => text
            .chars()
            .all(|c| matches!(c, '\n' | '\r' | '\t') || !c.is_control()),
        Err(_) => false,
    }
}

/// Returns true when the declared and detected MIME types are compatible.
///
/// Meant for the case where the two differ; it is not an equality check.
///
/// - A declared `application/octet-stream` makes no claim about the content,
///   so it is compatible with anything.
/// - A detected `application/octet-stream` means the sniffer recognised
///   nothing. That is acceptable for declared types the sniffer has no
///   signature for (for example an extra allow-listed `image/png`, or text in
///   a non-UTF-8 encoding). It is not acceptable for `application/wasm` or
///   `application/json`, which have a positive signature: failing to find it
///   means the content is not what was declared.
/// - Any declared `text/*` accepts any detected `text/*`, and also detected
///   `application/json`: JSON detection is only a first-byte heuristic, and
///   plain text may legitimately start with `[` or `{`.
fn is_compatible_mime(declared: &str, detected: &str) -> bool {
    const UNRECOGNISED: &str = "application/octet-stream";

    if declared == UNRECOGNISED {
        return true;
    }
    if detected == UNRECOGNISED {
        return !matches!(declared, "application/wasm" | "application/json");
    }
    declared.starts_with("text/")
        && (detected.starts_with("text/") || detected == "application/json")
}
let bytes = vec![0u8; 11 * 1024 * 1024]; + let result = validate_upload("big.wasm", "application/wasm", bytes, &cfg()); + assert!(matches!(result, Err(FileError::TooLarge { .. }))); + } + + #[test] + fn rejects_unsupported_mime() { + let bytes = b"data".to_vec(); + let result = validate_upload("file.exe", "application/exe", bytes, &cfg()); + assert!(matches!(result, Err(FileError::UnsupportedMimeType(_)))); + } + + #[test] + fn rejects_path_traversal() { + let bytes = b"data".to_vec(); + let result = validate_upload("../../etc/passwd", "text/plain", bytes, &cfg()); + // Should either sanitise or reject + match result { + Ok(f) => assert!(!f.file_name.contains("..")), + Err(FileError::InvalidFileName(_)) => {} + Err(e) => panic!("unexpected error: {e}"), + } + } + + #[test] + fn rejects_empty_file_name() { + let bytes = b"data".to_vec(); + let result = validate_upload("", "text/plain", bytes, &cfg()); + assert!(matches!(result, Err(FileError::InvalidFileName(_)))); + } + + #[test] + fn strips_directory_prefix() { + let bytes = b"hello".to_vec(); + let result = validate_upload("uploads/contract.txt", "text/plain", bytes, &cfg()); + assert!(result.is_ok()); + assert_eq!(result.unwrap().file_name, "contract.txt"); + } + + #[test] + fn rejects_mime_mismatch() { + // Declare JSON but send WASM magic bytes + let bytes = vec![0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]; + let result = validate_upload("contract.json", "application/json", bytes, &cfg()); + assert!(matches!(result, Err(FileError::MalformedContent(_)))); + } + + #[test] + fn octet_stream_is_compatible_with_any() { + let bytes = vec![0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]; + let result = validate_upload("contract.wasm", "application/octet-stream", bytes, &cfg()); + assert!(result.is_ok()); + } + + #[test] + fn custom_config_allows_extra_mime() { + let cfg = ValidationConfig::default().allow_mime("image/png"); + let bytes = b"PNG data".to_vec(); + // Will fail magic check but pass MIME allow-list 
check + let result = validate_upload("logo.png", "image/png", bytes, &cfg); + // text/plain detected for ASCII bytes — compatible via text/* rule? No. + // Just verify the MIME allow-list step passes (error is MalformedContent, not UnsupportedMimeType) + assert!(!matches!(result, Err(FileError::UnsupportedMimeType(_)))); + } + + #[test] + fn sanitize_null_byte_rejected() { + let result = sanitize_file_name("file\0name.txt"); + assert!(matches!(result, Err(FileError::InvalidFileName(_)))); + } + + #[test] + fn detect_mime_wasm() { + let bytes = vec![0x00, 0x61, 0x73, 0x6d]; + assert_eq!(detect_mime(&bytes), Some("application/wasm")); + } + + #[test] + fn detect_mime_json() { + let bytes = b"{\"key\":\"val\"}".to_vec(); + assert_eq!(detect_mime(&bytes), Some("application/json")); + } +} diff --git a/backend/src/utils/mod.rs b/backend/src/utils/mod.rs index ee92e54..2ec2a2e 100644 --- a/backend/src/utils/mod.rs +++ b/backend/src/utils/mod.rs @@ -1,3 +1,5 @@ +pub mod errors; +pub mod file_validation; pub mod json_schema; pub mod serialization; pub mod xdr;