Skip to content
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/target
/logs
.vscode
.cursor
16 changes: 16 additions & 0 deletions dev-env/init-scripts/init-schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,19 @@ CREATE INDEX IF NOT EXISTS idx_worker_events_worker_year ON worker_events(worker
CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_week ON worker_events(worker_id, action, bucket_week);
CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_month ON worker_events(worker_id, action, bucket_month);
CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_year ON worker_events(worker_id, action, bucket_year);

-- Batched rate-limit violations (aggregated per flush window + gateway).
-- "details" stores a JSON object with per-client counters.
CREATE TABLE IF NOT EXISTS rate_limit_violations (
id BIGSERIAL PRIMARY KEY,
gateway_name VARCHAR(255) NOT NULL,
window_start TIMESTAMP WITHOUT TIME ZONE NOT NULL,
window_end TIMESTAMP WITHOUT TIME ZONE NOT NULL,
total_count BIGINT NOT NULL CHECK (total_count >= 0),
details JSONB NOT NULL,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'UTC')
);
CREATE INDEX IF NOT EXISTS idx_rate_limit_violations_gateway_window
ON rate_limit_violations(gateway_name, window_start DESC);
CREATE INDEX IF NOT EXISTS idx_rate_limit_violations_window
ON rate_limit_violations(window_start DESC);
110 changes: 109 additions & 1 deletion src/db/event_recorder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
Expand All @@ -7,12 +8,22 @@ use tokio_util::sync::CancellationToken;
use tracing::error;
use uuid::Uuid;

use crate::db::{ActivityEventRow, EventSink, EventSinkHandle, WorkerEventRow};
use crate::db::{
ActivityEventRow, EventSink, EventSinkHandle, RateLimitViolationBatchRow, WorkerEventRow,
};

#[derive(Clone)]
enum EventRow {
Activity(ActivityEventRow),
Worker(WorkerEventRow),
RateLimitViolation(RateLimitViolationRow),
}

#[derive(Clone)]
struct RateLimitViolationRow {
gateway_name: String,
client_id: String,
created_at: chrono::DateTime<chrono::Utc>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -94,6 +105,19 @@ impl<S: EventSink + 'static> EventRecorder<S> {
self.enqueue(EventRow::Worker(row));
}

/// Records a single rate-limit denial event for the provided client key.
///
/// The recorder stores these events in memory and periodically flushes them as
/// aggregated batch rows so we avoid writing one database row per denial.
pub fn record_rate_limit_violation(&self, client_id: &str) {
let row = RateLimitViolationRow {
gateway_name: self.gateway_name.to_string(),
client_id: client_id.to_string(),
created_at: Utc::now(),
};
self.enqueue(EventRow::RateLimitViolation(row));
}

fn spawn_flusher(&self, flush_interval: Duration, shutdown: CancellationToken) {
let sink = Arc::clone(&self.sink);
let queue = Arc::clone(&self.queue);
Expand Down Expand Up @@ -148,10 +172,12 @@ impl<S: EventSink + 'static> EventRecorder<S> {
) -> anyhow::Result<()> {
let mut activity_rows: Vec<ActivityEventRow> = Vec::new();
let mut worker_rows: Vec<WorkerEventRow> = Vec::new();
let mut rate_limit_rows: Vec<RateLimitViolationRow> = Vec::new();
while let Some(entry) = queue.pop() {
match &**entry {
EventRow::Activity(row) => activity_rows.push(row.clone()),
EventRow::Worker(row) => worker_rows.push(row.clone()),
EventRow::RateLimitViolation(row) => rate_limit_rows.push(row.clone()),
}
}

Expand All @@ -171,6 +197,48 @@ impl<S: EventSink + 'static> EventRecorder<S> {
Self::enqueue_with_limit(queue, capacity, dropped, EventRow::Worker(row));
}
}

if !rate_limit_rows.is_empty() {
let mut details = BTreeMap::<String, i64>::new();
let mut window_start = rate_limit_rows[0].created_at;
let mut window_end = rate_limit_rows[0].created_at;
for row in &rate_limit_rows {
*details.entry(row.client_id.clone()).or_insert(0) += 1;
if row.created_at < window_start {
window_start = row.created_at;
}
if row.created_at > window_end {
window_end = row.created_at;
}
}

let mut details_pairs: Vec<(String, i64)> = details.into_iter().collect();
details_pairs.sort_by(|a, b| a.0.cmp(&b.0));
let details_map: serde_json::Map<String, serde_json::Value> = details_pairs
.into_iter()
.map(|(k, v)| (k, serde_json::Value::from(v)))
.collect();
let batch_row = RateLimitViolationBatchRow {
gateway_name: rate_limit_rows[0].gateway_name.clone(),
window_start,
window_end,
total_count: rate_limit_rows.len() as i64,
details: serde_json::Value::Object(details_map),
created_at: Utc::now(),
};
let batch_rows = vec![batch_row];
if let Err(e) = sink.record_rate_limit_violation_batches(&batch_rows).await {
error!(error = ?e, "Failed to flush rate-limit violation aggregates");
for row in rate_limit_rows {
Self::enqueue_with_limit(
queue,
capacity,
dropped,
EventRow::RateLimitViolation(row),
);
}
}
}
Ok(())
}

Expand All @@ -180,3 +248,43 @@ impl<S: EventSink + 'static> EventRecorder<S> {
let _ = Self::flush_once(&self.sink, &self.queue, self.capacity, &self.dropped).await;
}
}

#[cfg(all(test, feature = "test-support"))]
mod tests {
use super::EventRecorder;
use crate::db::{EventSinkHandle, InMemoryEventSink};
use std::sync::Arc;
use std::time::Duration;
use tokio_util::sync::CancellationToken;

#[tokio::test]
async fn rate_limit_violations_are_aggregated_per_client() {
let sink = InMemoryEventSink::default();
let shutdown = CancellationToken::new();
let recorder = EventRecorder::new(
Arc::new(EventSinkHandle::InMemory(sink.clone())),
Arc::from("gw-test"),
Duration::from_secs(60),
1024,
shutdown.clone(),
);

recorder.record_rate_limit_violation("user:alice");
recorder.record_rate_limit_violation("user:alice");
recorder.record_rate_limit_violation("ip:10");
recorder.flush_once_for_test().await;

let rows = sink.rate_limit_rows().await;
assert_eq!(rows.len(), 1);
let row = &rows[0];
assert_eq!(row.gateway_name, "gw-test");
assert_eq!(row.total_count, 3);
assert_eq!(
row.details,
serde_json::json!({
"ip:10": 1,
"user:alice": 2
})
);
}
}
48 changes: 48 additions & 0 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ pub struct WorkerEventRow {
pub created_at: DateTime<Utc>,
}

#[derive(Clone)]
pub struct RateLimitViolationBatchRow {
pub gateway_name: String,
pub window_start: DateTime<Utc>,
pub window_end: DateTime<Utc>,
pub total_count: i64,
pub details: serde_json::Value,
pub created_at: DateTime<Utc>,
}

pub struct DatabaseBuilder {
sslcert_path: Option<String>,
sslkey_path: Option<String>,
Expand All @@ -71,13 +81,18 @@ pub struct DatabaseBuilder {
pub trait EventSink: Send + Sync {
async fn record_activity_events_batch(&self, rows: &[ActivityEventRow]) -> Result<()>;
async fn record_worker_events_batch(&self, rows: &[WorkerEventRow]) -> Result<()>;
async fn record_rate_limit_violation_batches(
&self,
rows: &[RateLimitViolationBatchRow],
) -> Result<()>;
}

#[cfg(feature = "test-support")]
#[derive(Clone, Default)]
pub struct InMemoryEventSink {
activity: Arc<Mutex<Vec<ActivityEventRow>>>,
worker: Arc<Mutex<Vec<WorkerEventRow>>>,
rate_limit: Arc<Mutex<Vec<RateLimitViolationBatchRow>>>,
}

#[cfg(feature = "test-support")]
Expand All @@ -91,6 +106,11 @@ impl InMemoryEventSink {
pub async fn worker_rows(&self) -> Vec<WorkerEventRow> {
self.worker.lock().await.clone()
}

#[allow(dead_code)]
pub async fn rate_limit_rows(&self) -> Vec<RateLimitViolationBatchRow> {
self.rate_limit.lock().await.clone()
}
}

#[cfg(feature = "test-support")]
Expand All @@ -107,6 +127,15 @@ impl EventSink for InMemoryEventSink {
guard.extend(rows.iter().cloned());
Ok(())
}

async fn record_rate_limit_violation_batches(
&self,
rows: &[RateLimitViolationBatchRow],
) -> Result<()> {
let mut guard = self.rate_limit.lock().await;
guard.extend(rows.iter().cloned());
Ok(())
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -137,6 +166,18 @@ impl EventSink for EventSinkHandle {
EventSinkHandle::Noop => Ok(()),
}
}

async fn record_rate_limit_violation_batches(
&self,
rows: &[RateLimitViolationBatchRow],
) -> Result<()> {
match self {
EventSinkHandle::Database(db) => db.record_rate_limit_violation_batches(rows).await,
#[cfg(feature = "test-support")]
EventSinkHandle::InMemory(sink) => sink.record_rate_limit_violation_batches(rows).await,
EventSinkHandle::Noop => Ok(()),
}
}
}

#[async_trait]
Expand All @@ -148,4 +189,11 @@ impl EventSink for Database {
async fn record_worker_events_batch(&self, rows: &[WorkerEventRow]) -> Result<()> {
Database::record_worker_events_batch(self, rows).await
}

async fn record_rate_limit_violation_batches(
&self,
rows: &[RateLimitViolationBatchRow],
) -> Result<()> {
Database::record_rate_limit_violation_batches(self, rows).await
}
}
75 changes: 74 additions & 1 deletion src/db/repository.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use futures_util::SinkExt;
use tokio_postgres::types::ToSql;

use super::connection::StmtKey;
use super::{ActivityEventRow, Database, WorkerEventRow};
use super::{ActivityEventRow, Database, RateLimitViolationBatchRow, WorkerEventRow};

impl Database {
pub(super) const Q_SERVER_TIME_UTC: &'static str = r#"
Expand Down Expand Up @@ -126,6 +126,16 @@ task_kind, \
reason, \
gateway_name, \
created_at\
) FROM STDIN WITH (FORMAT text)";

pub(super) const COPY_RATE_LIMIT_VIOLATIONS: &'static str = "\
COPY rate_limit_violations (\
gateway_name, \
window_start, \
window_end, \
total_count, \
details, \
created_at\
) FROM STDIN WITH (FORMAT text)";

pub async fn fetch_all_user_key_hashes(
Expand Down Expand Up @@ -382,6 +392,69 @@ created_at\
}
}
}

pub async fn record_rate_limit_violation_batches(
&self,
rows: &[RateLimitViolationBatchRow],
) -> Result<()> {
if rows.is_empty() {
return Ok(());
}
let client = self.load_client().await?;
client.batch_execute("BEGIN").await?;
let result = async {
for chunk in rows.chunks(self.events_copy_batch_size) {
let mut buf = Vec::with_capacity(chunk.len() * 256);
for row in chunk {
append_copy_field(&mut buf, Some(row.gateway_name.as_str()));
buf.push(b'\t');
let window_start = row
.window_start
.naive_utc()
.format("%Y-%m-%d %H:%M:%S%.f")
.to_string();
append_copy_field(&mut buf, Some(window_start.as_str()));
buf.push(b'\t');
let window_end = row
.window_end
.naive_utc()
.format("%Y-%m-%d %H:%M:%S%.f")
.to_string();
append_copy_field(&mut buf, Some(window_end.as_str()));
buf.push(b'\t');
append_copy_field(&mut buf, Some(row.total_count.to_string().as_str()));
buf.push(b'\t');
let details = serde_json::to_string(&row.details)?;
append_copy_field(&mut buf, Some(details.as_str()));
buf.push(b'\t');
let created_at = row
.created_at
.naive_utc()
.format("%Y-%m-%d %H:%M:%S%.f")
.to_string();
append_copy_field(&mut buf, Some(created_at.as_str()));
buf.push(b'\n');
}
let sink = client.copy_in(Self::COPY_RATE_LIMIT_VIOLATIONS).await?;
let mut sink = std::pin::pin!(sink);
sink.as_mut().send(Bytes::from(buf)).await?;
sink.as_mut().finish().await?;
}
Ok::<(), anyhow::Error>(())
}
.await;

match result {
Ok(()) => {
client.batch_execute("COMMIT").await?;
Ok(())
}
Err(err) => {
let _ = client.batch_execute("ROLLBACK").await;
Err(err)
}
}
}
}

fn append_copy_field(buf: &mut Vec<u8>, value: Option<&str>) {
Expand Down
Loading