diff --git a/.gitignore b/.gitignore index 4f9f2e3..4c66556 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target /logs .vscode +.cursor diff --git a/dev-env/init-scripts/init-schema.sql b/dev-env/init-scripts/init-schema.sql index a4f43bd..2bf7819 100644 --- a/dev-env/init-scripts/init-schema.sql +++ b/dev-env/init-scripts/init-schema.sql @@ -160,3 +160,19 @@ CREATE INDEX IF NOT EXISTS idx_worker_events_worker_year ON worker_events(worker CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_week ON worker_events(worker_id, action, bucket_week); CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_month ON worker_events(worker_id, action, bucket_month); CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_year ON worker_events(worker_id, action, bucket_year); + +-- Batched rate-limit violations (aggregated per flush window + gateway). +-- "details" stores a JSON object with per-client counters. +CREATE TABLE IF NOT EXISTS rate_limit_violations ( + id BIGSERIAL PRIMARY KEY, + gateway_name VARCHAR(255) NOT NULL, + window_start TIMESTAMP WITHOUT TIME ZONE NOT NULL, + window_end TIMESTAMP WITHOUT TIME ZONE NOT NULL, + total_count BIGINT NOT NULL CHECK (total_count >= 0), + details JSONB NOT NULL, + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'UTC') +); +CREATE INDEX IF NOT EXISTS idx_rate_limit_violations_gateway_window + ON rate_limit_violations(gateway_name, window_start DESC); +CREATE INDEX IF NOT EXISTS idx_rate_limit_violations_window + ON rate_limit_violations(window_start DESC); diff --git a/src/db/event_recorder.rs b/src/db/event_recorder.rs index b6b14bd..dfdb25d 100644 --- a/src/db/event_recorder.rs +++ b/src/db/event_recorder.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; @@ -7,12 +8,22 @@ use tokio_util::sync::CancellationToken; use tracing::error; use uuid::Uuid; -use crate::db::{ActivityEventRow, EventSink, EventSinkHandle, WorkerEventRow}; +use crate::db::{ + ActivityEventRow, EventSink, EventSinkHandle, RateLimitViolationBatchRow, WorkerEventRow, +}; #[derive(Clone)] enum EventRow { Activity(ActivityEventRow), Worker(WorkerEventRow), + RateLimitViolation(RateLimitViolationRow), +} + +#[derive(Clone)] +struct RateLimitViolationRow { + gateway_name: String, + client_id: String, + created_at: chrono::DateTime, } #[derive(Clone)] @@ -94,6 +105,19 @@ impl EventRecorder { self.enqueue(EventRow::Worker(row)); } + /// Records a single rate-limit denial event for the provided client key. + /// + /// The recorder stores these events in memory and periodically flushes them as + /// aggregated batch rows so we avoid writing one database row per denial. + pub fn record_rate_limit_violation(&self, client_id: &str) { + let row = RateLimitViolationRow { + gateway_name: self.gateway_name.to_string(), + client_id: client_id.to_string(), + created_at: Utc::now(), + }; + self.enqueue(EventRow::RateLimitViolation(row)); + } + fn spawn_flusher(&self, flush_interval: Duration, shutdown: CancellationToken) { let sink = Arc::clone(&self.sink); let queue = Arc::clone(&self.queue); @@ -148,10 +172,12 @@ impl EventRecorder { ) -> anyhow::Result<()> { let mut activity_rows: Vec = Vec::new(); let mut worker_rows: Vec = Vec::new(); + let mut rate_limit_rows: Vec = Vec::new(); while let Some(entry) = queue.pop() { match &**entry { EventRow::Activity(row) => activity_rows.push(row.clone()), EventRow::Worker(row) => worker_rows.push(row.clone()), + EventRow::RateLimitViolation(row) => rate_limit_rows.push(row.clone()), } } @@ -171,6 +197,48 @@ impl EventRecorder { Self::enqueue_with_limit(queue, capacity, dropped, EventRow::Worker(row)); } } + + if !rate_limit_rows.is_empty() { + let mut details = BTreeMap::::new(); + let mut window_start = rate_limit_rows[0].created_at; + let mut window_end = rate_limit_rows[0].created_at; + for row in &rate_limit_rows { + *details.entry(row.client_id.clone()).or_insert(0) += 1; + if row.created_at < window_start { + window_start = row.created_at; + } + if row.created_at > window_end { + window_end = row.created_at; + } + } + + let mut details_pairs: Vec<(String, i64)> = details.into_iter().collect(); + details_pairs.sort_by(|a, b| a.0.cmp(&b.0)); + let details_map: serde_json::Map = details_pairs + .into_iter() + .map(|(k, v)| (k, serde_json::Value::from(v))) + .collect(); + let batch_row = RateLimitViolationBatchRow { + gateway_name: rate_limit_rows[0].gateway_name.clone(), + window_start, + window_end, + total_count: rate_limit_rows.len() as i64, + details: serde_json::Value::Object(details_map), + created_at: Utc::now(), + }; + let batch_rows = vec![batch_row]; + if let Err(e) = sink.record_rate_limit_violation_batches(&batch_rows).await { + error!(error = ?e, "Failed to flush rate-limit violation aggregates"); + for row in rate_limit_rows { + Self::enqueue_with_limit( + queue, + capacity, + dropped, + EventRow::RateLimitViolation(row), + ); + } + } + } Ok(()) } @@ -180,3 +248,43 @@ impl EventRecorder { let _ = Self::flush_once(&self.sink, &self.queue, self.capacity, &self.dropped).await; } } + +#[cfg(all(test, feature = "test-support"))] +mod tests { + use super::EventRecorder; + use crate::db::{EventSinkHandle, InMemoryEventSink}; + use std::sync::Arc; + use std::time::Duration; + use tokio_util::sync::CancellationToken; + + #[tokio::test] + async fn rate_limit_violations_are_aggregated_per_client() { + let sink = InMemoryEventSink::default(); + let shutdown = CancellationToken::new(); + let recorder = EventRecorder::new( + Arc::new(EventSinkHandle::InMemory(sink.clone())), + Arc::from("gw-test"), + Duration::from_secs(60), + 1024, + shutdown.clone(), + ); + + recorder.record_rate_limit_violation("user:alice"); + recorder.record_rate_limit_violation("user:alice"); + recorder.record_rate_limit_violation("ip:10"); + recorder.flush_once_for_test().await; + + let rows = sink.rate_limit_rows().await; + assert_eq!(rows.len(), 1); + let row = &rows[0]; + assert_eq!(row.gateway_name, "gw-test"); + assert_eq!(row.total_count, 3); + assert_eq!( + row.details, + serde_json::json!({ + "ip:10": 1, + "user:alice": 2 + }) + ); + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs index 756f241..e061fbf 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -55,6 +55,16 @@ pub struct WorkerEventRow { pub created_at: DateTime, } +#[derive(Clone)] +pub struct RateLimitViolationBatchRow { + pub gateway_name: String, + pub window_start: DateTime, + pub window_end: DateTime, + pub total_count: i64, + pub details: serde_json::Value, + pub created_at: DateTime, +} + pub struct DatabaseBuilder { sslcert_path: Option, sslkey_path: Option, @@ -71,6 +81,10 @@ pub struct DatabaseBuilder { pub trait EventSink: Send + Sync { async fn record_activity_events_batch(&self, rows: &[ActivityEventRow]) -> Result<()>; async fn record_worker_events_batch(&self, rows: &[WorkerEventRow]) -> Result<()>; + async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()>; } #[cfg(feature = "test-support")] @@ -78,6 +92,7 @@ pub trait EventSink: Send + Sync { pub struct InMemoryEventSink { activity: Arc>>, worker: Arc>>, + rate_limit: Arc>>, } #[cfg(feature = "test-support")] @@ -91,6 +106,11 @@ impl InMemoryEventSink { pub async fn worker_rows(&self) -> Vec { self.worker.lock().await.clone() } + + #[allow(dead_code)] + pub async fn rate_limit_rows(&self) -> Vec { + self.rate_limit.lock().await.clone() + } } #[cfg(feature = "test-support")] @@ -107,6 +127,15 @@ impl EventSink for InMemoryEventSink { guard.extend(rows.iter().cloned()); Ok(()) } + + async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()> { + let mut guard = self.rate_limit.lock().await; + guard.extend(rows.iter().cloned()); + Ok(()) + } } #[derive(Clone)] @@ -137,6 +166,18 @@ impl EventSink for EventSinkHandle { EventSinkHandle::Noop => Ok(()), } } + + async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()> { + match self { + EventSinkHandle::Database(db) => db.record_rate_limit_violation_batches(rows).await, + #[cfg(feature = "test-support")] + EventSinkHandle::InMemory(sink) => sink.record_rate_limit_violation_batches(rows).await, + EventSinkHandle::Noop => Ok(()), + } + } } #[async_trait] @@ -148,4 +189,11 @@ impl EventSink for Database { async fn record_worker_events_batch(&self, rows: &[WorkerEventRow]) -> Result<()> { Database::record_worker_events_batch(self, rows).await } + + async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()> { + Database::record_rate_limit_violation_batches(self, rows).await + } } diff --git a/src/db/repository.rs b/src/db/repository.rs index cae590d..f65308b 100644 --- a/src/db/repository.rs +++ b/src/db/repository.rs @@ -5,7 +5,7 @@ use futures_util::SinkExt; use tokio_postgres::types::ToSql; use super::connection::StmtKey; -use super::{ActivityEventRow, Database, WorkerEventRow}; +use super::{ActivityEventRow, Database, RateLimitViolationBatchRow, WorkerEventRow}; impl Database { pub(super) const Q_SERVER_TIME_UTC: &'static str = r#" @@ -126,6 +126,16 @@ task_kind, \ reason, \ gateway_name, \ created_at\ +) FROM STDIN WITH (FORMAT text)"; + + pub(super) const COPY_RATE_LIMIT_VIOLATIONS: &'static str = "\ +COPY rate_limit_violations (\ +gateway_name, \ +window_start, \ +window_end, \ +total_count, \ +details, \ +created_at\ ) FROM STDIN WITH (FORMAT text)"; pub async fn fetch_all_user_key_hashes( @@ -382,6 +392,69 @@ created_at\ } } } + + pub async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()> { + if rows.is_empty() { + return Ok(()); + } + let client = self.load_client().await?; + client.batch_execute("BEGIN").await?; + let result = async { + for chunk in rows.chunks(self.events_copy_batch_size) { + let mut buf = Vec::with_capacity(chunk.len() * 256); + for row in chunk { + append_copy_field(&mut buf, Some(row.gateway_name.as_str())); + buf.push(b'\t'); + let window_start = row + .window_start + .naive_utc() + .format("%Y-%m-%d %H:%M:%S%.f") + .to_string(); + append_copy_field(&mut buf, Some(window_start.as_str())); + buf.push(b'\t'); + let window_end = row + .window_end + .naive_utc() + .format("%Y-%m-%d %H:%M:%S%.f") + .to_string(); + append_copy_field(&mut buf, Some(window_end.as_str())); + buf.push(b'\t'); + append_copy_field(&mut buf, Some(row.total_count.to_string().as_str())); + buf.push(b'\t'); + let details = serde_json::to_string(&row.details)?; + append_copy_field(&mut buf, Some(details.as_str())); + buf.push(b'\t'); + let created_at = row + .created_at + .naive_utc() + .format("%Y-%m-%d %H:%M:%S%.f") + .to_string(); + append_copy_field(&mut buf, Some(created_at.as_str())); + buf.push(b'\n'); + } + let sink = client.copy_in(Self::COPY_RATE_LIMIT_VIOLATIONS).await?; + let mut sink = std::pin::pin!(sink); + sink.as_mut().send(Bytes::from(buf)).await?; + sink.as_mut().finish().await?; + } + Ok::<(), anyhow::Error>(()) + } + .await; + + match result { + Ok(()) => { + client.batch_execute("COMMIT").await?; + Ok(()) + } + Err(err) => { + let _ = client.batch_execute("ROLLBACK").await; + Err(err) + } + } + } } fn append_copy_field(buf: &mut Vec, value: Option<&str>) { diff --git a/src/http3/rate_limits.rs b/src/http3/rate_limits.rs index e884f41..fd70000 100644 --- a/src/http3/rate_limits.rs +++ b/src/http3/rate_limits.rs @@ -142,6 +142,48 @@ impl RateLimitContext { } } +/// Builds the identity key used for batched rate-limit violation aggregation. +/// +/// Priority: +/// 1. company id +/// 2. user id +/// 3. decimal source IP +/// 4. unknown +fn violation_client_key(ctx: &RateLimitContext) -> String { + if let Some(company) = ctx.company.as_ref() { + return format!("company:{}", company.id); + } + if let Some(user_id) = ctx.user_id { + return format!("user:{user_id}"); + } + if let Some(ip) = ctx.decimal_ip.as_ref() { + return format!("ip:{ip}"); + } + "unknown".to_string() +} + +fn violation_client_key_for_subject(subject: Subject, id: u128) -> String { + match subject { + Subject::Company => format!("company:{}", Uuid::from_u128(id)), + Subject::User => format!("user:{}", Uuid::from_u128(id)), + Subject::GenericGlobal => "generic:global".to_string(), + Subject::GenericIp => format!("generic:ip:{id}"), + } +} + +fn maybe_record_local_violation(state: &HttpState, depot: &Depot, res: &Response) { + if res.status_code != Some(salvo::http::StatusCode::TOO_MANY_REQUESTS) { + return; + } + if let Ok(ctx) = depot.obtain::() { + state + .gateway_state() + .record_rate_limit_violation(violation_client_key(ctx).as_str()); + } else { + state.gateway_state().record_rate_limit_violation("unknown"); + } +} + #[handler] pub async fn prepare_rate_limit_context( depot: &mut Depot, @@ -313,6 +355,7 @@ pub async fn basic_rate_limit( .basic_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -329,6 +372,7 @@ pub async fn update_key_rate_limit( .update_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -345,6 +389,7 @@ pub async fn unauthorized_only_rate_limit( .unauthorized_only_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -361,6 +406,7 @@ pub async fn read_rate_limit( .read_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -377,6 +423,7 @@ pub async fn result_rate_limit( .result_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -393,6 +440,7 @@ pub async fn load_rate_limit( .load_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -409,6 +457,7 @@ pub async fn leader_rate_limit( .leader_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -425,6 +474,7 @@ pub async fn metric_rate_limit( .metric_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -441,6 +491,7 @@ pub async fn status_rate_limit( .status_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -521,6 +572,7 @@ async fn check_subject_limit( ) .await { + gs.record_rate_limit_violation(violation_client_key_for_subject(subject, id).as_str()); return Err(ServerError::TooManyRequests(error_msg.to_string())); } diff --git a/src/raft/gateway_state.rs b/src/raft/gateway_state.rs index 9ddeb82..ca6e8cc 100644 --- a/src/raft/gateway_state.rs +++ b/src/raft/gateway_state.rs @@ -478,6 +478,12 @@ impl GatewayState { ); } + /// Records a single rate-limit violation for later batched persistence. + pub fn record_rate_limit_violation(&self, client_id: &str) { + self.internal + .event_recorder + .record_rate_limit_violation(client_id); + } pub async fn submit_rate_limit_mutations_with_request_id( &self, request_id: u128, diff --git a/tests/client_http_api/add_task.rs b/tests/client_http_api/add_task.rs index e8db471..d196c82 100644 --- a/tests/client_http_api/add_task.rs +++ b/tests/client_http_api/add_task.rs @@ -30,7 +30,6 @@ fn pop_seed_for_task(h: &crate::support::TestHarness, task_id: Uuid) -> i32 { assert_eq!(task.id, task_id, "seed check popped unexpected task"); task.seed } - #[tokio::test] async fn add_task_json_success() { let h = build_harness().await; diff --git a/tests/client_http_api/mod.rs b/tests/client_http_api/mod.rs index 563be70..8991ca0 100644 --- a/tests/client_http_api/mod.rs +++ b/tests/client_http_api/mod.rs @@ -6,4 +6,5 @@ mod get_result; mod get_status; mod get_tasks; mod misc; +mod rate_limits; mod support; diff --git a/tests/client_http_api/rate_limits.rs b/tests/client_http_api/rate_limits.rs new file mode 100644 index 0000000..c51c7a1 --- /dev/null +++ b/tests/client_http_api/rate_limits.rs @@ -0,0 +1,71 @@ +use std::sync::Arc; +use std::time::Duration; + +use http::StatusCode; +use salvo::prelude::*; +use salvo::test::TestClient; +use tokio_util::sync::CancellationToken; + +use gateway::db::{EventRecorder, EventSinkHandle}; +use gateway::test_support::{ + basic_rate_limit, build_shared_harness_core, ensure_test_crypto_provider, + load_test_single_node_config, +}; + +use crate::support::read_response; + +#[handler] +async fn ok_handler() -> &'static str { + "ok" +} + +#[tokio::test] +async fn basic_rate_limit_returns_429() { + ensure_test_crypto_provider(); + let (mut config, _path) = load_test_single_node_config(); + config.http.basic_rate_limit = 1; + let config = Arc::new(config); + + let config_file = tempfile::Builder::new() + .suffix(".toml") + .tempfile() + .expect("temp config file"); + let config_toml = toml::to_string(config.as_ref()).expect("serialize test config"); + std::fs::write(config_file.path(), config_toml).expect("write temp config"); + let config_path = config_file.path().to_path_buf(); + + let shutdown = CancellationToken::new(); + let event_recorder = EventRecorder::new( + Arc::new(EventSinkHandle::Noop), + Arc::from(config.network.name.as_str()), + Duration::from_secs(30), + config.db.events_queue_capacity.max(1), + shutdown.clone(), + ); + let core = build_shared_harness_core(config.clone(), config_path, event_recorder, true).await; + let state = core.state; + + let router = Router::new().hoop(affix_state::inject(state)).push( + Router::with_path("/rl-probe") + .hoop(basic_rate_limit) + .get(ok_handler), + ); + let service = Service::new(router); + + let first = TestClient::get("http://localhost/rl-probe") + .send(&service) + .await; + let (first_status, _headers, _body) = read_response(first).await; + assert_eq!(first_status, StatusCode::OK); + + let second = TestClient::get("http://localhost/rl-probe") + .send(&service) + .await; + let (second_status, _headers, body) = read_response(second).await; + assert_eq!( + second_status, + StatusCode::TOO_MANY_REQUESTS, + "second probe body: {}", + String::from_utf8_lossy(&body) + ); +} diff --git a/tests/event_tracker/support.rs b/tests/event_tracker/support.rs index a55aab7..2d7fbb7 100644 --- a/tests/event_tracker/support.rs +++ b/tests/event_tracker/support.rs @@ -65,6 +65,7 @@ pub(crate) fn current_timestamp() -> u64 { current_timestamp_secs() } +#[allow(clippy::too_many_arguments)] pub(crate) fn multipart_add_result( task_id: Uuid, worker_hotkey: &Hotkey,