From 383013704b150bc5d6eb4500cd92acd68a3d7ece Mon Sep 17 00:00:00 2001 From: Raman Kudaktsin Date: Fri, 27 Feb 2026 11:07:54 +0300 Subject: [PATCH 1/9] feat: model params property --- src/api/mod.rs | 2 + src/api/request.rs | 1 + src/api/response.rs | 2 + src/common/queue.rs | 3 + src/config.rs | 11 +++ src/config_runtime.rs | 8 +- src/http3/handlers/task.rs | 35 +++++++- src/task/tests.rs | 5 ++ tests/client_http_api/add_task.rs | 126 +++++++++++++++++++++++++--- tests/client_http_api/get_result.rs | 29 +++++++ tests/client_http_api/get_tasks.rs | 6 ++ tests/client_http_api/support.rs | 6 +- tests/event_tracker/add_result.rs | 1 + tests/event_tracker/worker.rs | 3 + 14 files changed, 224 insertions(+), 14 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index e76d295..299e44d 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -28,6 +28,8 @@ pub struct Task { #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, pub seed: i32, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_params: Option, } pub trait HasUuid { diff --git a/src/api/request.rs b/src/api/request.rs index ef82434..3c2513f 100644 --- a/src/api/request.rs +++ b/src/api/request.rs @@ -11,6 +11,7 @@ pub struct AddTaskRequest { pub seed: Option, pub prompt: Option, pub model: Option, + pub model_params: Option, } #[derive(Debug, Clone, Deserialize)] diff --git a/src/api/response.rs b/src/api/response.rs index 7ba5e21..d2dfc06 100644 --- a/src/api/response.rs +++ b/src/api/response.rs @@ -125,6 +125,7 @@ mod tests { image: None, model: Some("404-3dgs".to_string()), seed: 0, + model_params: None, }; let resp = GetTasksResponse { @@ -151,6 +152,7 @@ mod tests { image: Some(Bytes::from(vec![1u8, 2u8, 3u8])), model: Some("404-mesh".to_string()), seed: 0, + model_params: None, }; let resp = GetTasksResponse { diff --git a/src/common/queue.rs b/src/common/queue.rs index 46fdd78..0ced88e 100644 --- a/src/common/queue.rs +++ b/src/common/queue.rs @@ -226,6 +226,7 @@ mod 
tests { image: None, model: None, seed: 0, + model_params: None, } } @@ -285,6 +286,7 @@ mod tests { image: None, model: Some("404-3dgs".to_string()), seed: 0, + model_params: None, }; let task_b = Task { id: Uuid::new_v4(), @@ -292,6 +294,7 @@ mod tests { image: None, model: Some("404-mesh".to_string()), seed: 0, + model_params: None, }; queue.push(task_a.clone()); diff --git a/src/config.rs b/src/config.rs index 4b5e585..1d54a83 100644 --- a/src/config.rs +++ b/src/config.rs @@ -156,6 +156,12 @@ pub struct PromptConfig { pub allowed_pattern: String, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ModelParamsConfig { + #[serde(default = "default_model_params_max_len")] + pub max_len: usize, +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ImageConfig { pub max_width: u32, @@ -216,6 +222,10 @@ fn default_max_rate_limit_deltas_per_batch() -> usize { 16_384 } +fn default_model_params_max_len() -> usize { + 1024 +} + fn default_snapshot_dir() -> String { "data/snapshots".to_string() } @@ -256,6 +266,7 @@ pub struct NodeConfig { pub rclient: RClientConfig, pub http: HTTPConfig, pub prompt: PromptConfig, + pub model_params: ModelParamsConfig, pub image: ImageConfig, pub db: DbConfig, pub cert: Certificate, diff --git a/src/config_runtime.rs b/src/config_runtime.rs index e47b737..7caf40b 100644 --- a/src/config_runtime.rs +++ b/src/config_runtime.rs @@ -13,7 +13,9 @@ use std::time::{Duration, UNIX_EPOCH}; use tokio::sync::Notify; use tracing::{info, warn}; -use crate::config::{HTTPConfig, ImageConfig, NodeConfig, PromptConfig, read_config_from_path}; +use crate::config::{ + HTTPConfig, ImageConfig, ModelParamsConfig, NodeConfig, PromptConfig, read_config_from_path, +}; use crate::http3::rate_limits::{RateLimitService, RateLimiters}; use crate::http3::upload_limiter::ImageUploadLimiter; use crate::http3::whitelist::{ @@ -52,6 +54,10 @@ impl RuntimeConfigView { &self.snapshot.raw.image } + pub fn model_params(&self) -> &ModelParamsConfig { + 
&self.snapshot.raw.model_params + } + pub fn prompt_regex(&self) -> &Regex { &self.snapshot.prompt_regex } diff --git a/src/http3/handlers/task.rs b/src/http3/handlers/task.rs index 39baba8..acd394f 100644 --- a/src/http3/handlers/task.rs +++ b/src/http3/handlers/task.rs @@ -44,6 +44,7 @@ struct AddTaskMultipartData { prompt: Option, image: Option, model: Option, + model_params: Option, } struct ValidatedAddTask { @@ -51,6 +52,7 @@ struct ValidatedAddTask { prompt: Option, image: Option, model: Option, + model_params: Option, task_kind: TaskKind, } @@ -65,17 +67,19 @@ async fn parse_add_task_multipart( let cfg = state.config(); let image_cfg = cfg.image(); let prompt_cfg = cfg.prompt(); + let model_params_cfg = cfg.model_params(); let upload_limiter = cfg.image_upload_limiter(); let mut image_permit: Option = None; let constraints = Constraints::new() - .allowed_fields(vec!["seed", "prompt", "image", "model"]) + .allowed_fields(vec!["seed", "prompt", "image", "model", "model_params"]) .size_limit( SizeLimit::new() .for_field("image", image_cfg.max_size_bytes as u64) .for_field("prompt", prompt_cfg.max_len as u64) .for_field("model", 64) - .for_field("seed", 11), + .for_field("seed", 11) + .for_field("model_params", model_params_cfg.max_len as u64), ); let mut multipart = Multipart::with_constraints(byte_stream, boundary, constraints); @@ -83,6 +87,7 @@ async fn parse_add_task_multipart( let mut seed = None; let mut image = None; let mut model = None; + let mut model_params = None; while let Some(field) = multipart .next_field() @@ -116,6 +121,9 @@ async fn parse_add_task_multipart( "seed" => { seed = Some(parse_seed_text(&read_text_field(field, "seed").await?)?); } + "model_params" => { + model_params = Some(read_text_field(field, "model_params").await?); + } _ => continue, } } @@ -125,6 +133,7 @@ async fn parse_add_task_multipart( image, model, seed, + model_params, }) } @@ -152,6 +161,7 @@ async fn parse_add_task_request( image: None, model: add_task.model, seed: 
add_task.seed.map(|seed| seed.into_i32()), + model_params: add_task.model_params, }) } } @@ -208,6 +218,24 @@ fn validate_add_task_input( } } + if let Some(model_params) = &add_task.model_params { + let model_params_cfg = cfg.model_params(); + if model_params.len() > model_params_cfg.max_len { + return Err(ServerError::BadRequest(format!( + "Model params is too long: maximum length is {} characters (got {})", + model_params_cfg.max_len, + model_params.len() + ))); + } + + if let Err(err) = serde_json::from_str::(model_params) { + return Err(ServerError::BadRequest(format!( + "Model params must be valid JSON: {}", + err + ))); + } + } + let validated_image = if let Some(image_data) = add_task.image { let image_cfg = cfg.image(); Some(validate_image(image_data, image_cfg)?) @@ -221,6 +249,7 @@ fn validate_add_task_input( image: validated_image, model: add_task.model, task_kind, + model_params: add_task.model_params, }) } @@ -241,6 +270,7 @@ pub async fn add_task_handler( let user_seed = validated.seed.filter(|seed| *seed != -1); // Draw from all u32 values except u32::MAX, after cast that excludes -1 sentinel. 
let seed = user_seed.unwrap_or_else(|| rand::random_range(0..u32::MAX) as i32); + let model_params = validated.model_params; let state = depot.require::()?.clone(); let cfg = state.config(); @@ -285,6 +315,7 @@ pub async fn add_task_handler( image: validated.image.as_ref().map(|img| img.data.clone()), model: Some(resolved_model.model), seed, + model_params, }; queue.push(task.clone()); diff --git a/src/task/tests.rs b/src/task/tests.rs index 9ab6b7d..2b596dc 100644 --- a/src/task/tests.rs +++ b/src/task/tests.rs @@ -183,6 +183,7 @@ async fn image_task_persists_until_all_assignments_complete() { image: Some(image.clone()), model: Some("404-mesh".to_string()), seed: 0, + model_params: None, }; task_manager.add_task(task).await; @@ -265,6 +266,7 @@ async fn model_persists_until_result_retrieval() { image: None, model: Some("404-3dgs".to_string()), seed: 0, + model_params: None, }; task_manager.add_task(task).await; @@ -327,6 +329,7 @@ async fn tasks_in_progress_gauge_tracks_assignments() { image: None, model: None, seed: 0, + model_params: None, }; task_manager.add_task(task).await; task_manager @@ -415,6 +418,7 @@ async fn tasks_in_progress_handles_multiple_random_assignments() { image: None, model: None, seed: 0, + model_params: None, }; task_manager.add_task(task).await; @@ -486,6 +490,7 @@ async fn test_timeout_increments_metric() { image: Some(Bytes::from_static(b"")), model: None, seed: 0, + model_params: None, }; task_manager.add_task(task).await; diff --git a/tests/client_http_api/add_task.rs b/tests/client_http_api/add_task.rs index fd01905..148ce6c 100644 --- a/tests/client_http_api/add_task.rs +++ b/tests/client_http_api/add_task.rs @@ -16,6 +16,111 @@ async fn add_task_json_success() { let _task_id = add_task_prompt(&h, "robot", None).await; } +#[tokio::test] +async fn add_task_rejects_invalid_model_params_json() { + let h = build_harness().await; + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), 
true) + .json(&serde_json::json!({ + "prompt": "robot", + "model_params": "not-json" + })) + .send(&h.service) + .await; + let (status, _headers, _body) = read_response(res).await; + assert_eq!(status, StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn add_task_multipart_rejects_invalid_model_params() { + let h = build_harness().await; + let image = tiny_png_bytes(); + let (boundary, body) = + multipart_body(Some("robot"), Some(&image), None, None, Some("not-json")); + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .add_header( + "content-type", + format!("multipart/form-data; boundary={}", boundary), + true, + ) + .body(body) + .send(&h.service) + .await; + let (status, _headers, _body) = read_response(res).await; + assert_eq!(status, StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn add_task_rejects_too_long_model_params_json() { + let h = build_harness().await; + let max_len = h.config.model_params.max_len; + let too_long = "a".repeat(max_len + 1); + + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .json(&serde_json::json!({ + "prompt": "robot", + "model_params": too_long + })) + .send(&h.service) + .await; + let (status, _headers, _body) = read_response(res).await; + assert_eq!(status, StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn add_task_multipart_rejects_too_long_model_params() { + let h = build_harness().await; + let max_len = h.config.model_params.max_len; + let too_long = "a".repeat(max_len + 1); + let image = tiny_png_bytes(); + let (boundary, body) = multipart_body(Some("robot"), Some(&image), None, None, Some(&too_long)); + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .add_header( + "content-type", + format!("multipart/form-data; boundary={}", boundary), + true, + ) + .body(body) + .send(&h.service) + .await; + let (status, 
_headers, _body) = read_response(res).await; + assert_eq!(status, StatusCode::BAD_REQUEST); +} + +#[tokio::test] +async fn add_task_accepts_valid_model_params_json() { + let h = build_harness().await; + let params = r#"{"temperature":0.5}"#; + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .json(&serde_json::json!({ + "prompt": "robot", + "model_params": params + })) + .send(&h.service) + .await; + let (status, _headers, _body) = read_response(res).await; + assert_eq!(status, StatusCode::OK); +} + +#[tokio::test] +async fn add_task_accepts_null_model_params_json() { + let h = build_harness().await; + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .json(&serde_json::json!({ + "prompt": "robot", + "model_params": serde_json::Value::Null + })) + .send(&h.service) + .await; + let (status, _headers, _body) = read_response(res).await; + assert_eq!(status, StatusCode::OK); +} + #[tokio::test] async fn add_task_origin_header_success() { let h = build_harness().await; @@ -52,7 +157,7 @@ async fn add_task_origin_header_success() { async fn add_task_image_multipart_success() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(None, Some(&image), None, Some("1")); + let (boundary, body) = multipart_body(None, Some(&image), None, Some("1"), None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -76,7 +181,7 @@ async fn add_task_image_multipart_success() { async fn add_task_without_seed_is_ok() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(None, Some(&image), None, None); + let (boundary, body) = multipart_body(None, Some(&image), None, None, None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) 
.add_header( @@ -100,7 +205,7 @@ async fn add_task_without_seed_is_ok() { async fn add_task_negative_seed_is_ok() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(None, Some(&image), None, Some("-1")); + let (boundary, body) = multipart_body(None, Some(&image), None, Some("-1"), None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -128,7 +233,7 @@ async fn add_task_negative_seed_is_ok() { async fn add_task_min_i32_seed_is_ok() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(None, Some(&image), None, Some("-2147483648")); + let (boundary, body) = multipart_body(None, Some(&image), None, Some("-2147483648"), None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -152,7 +257,7 @@ async fn add_task_min_i32_seed_is_ok() { async fn add_task_high_u32_seed_converts_to_signed_for_multipart() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(None, Some(&image), None, Some("2147483648")); + let (boundary, body) = multipart_body(None, Some(&image), None, Some("2147483648"), None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -201,7 +306,7 @@ async fn add_task_high_u32_seed_converts_to_signed_for_json() { async fn add_task_max_u32_seed_randomizes_for_multipart() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(None, Some(&image), None, Some("4294967295")); + let (boundary, body) = multipart_body(None, Some(&image), None, Some("4294967295"), None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -271,7 +376,7 @@ async fn 
add_task_json_negative_seed_randomizes() { async fn add_task_invalid_seed_returns_bad_request() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(None, Some(&image), None, Some("not-a-number")); + let (boundary, body) = multipart_body(None, Some(&image), None, Some("not-a-number"), None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -290,7 +395,7 @@ async fn add_task_invalid_seed_returns_bad_request() { async fn add_task_out_of_range_seed_returns_bad_request() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(None, Some(&image), None, Some("4294967296")); + let (boundary, body) = multipart_body(None, Some(&image), None, Some("4294967296"), None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -308,7 +413,7 @@ async fn add_task_out_of_range_seed_returns_bad_request() { #[tokio::test] async fn add_task_rejects_invalid_image_data() { let h = build_harness().await; - let (boundary, body) = multipart_body(None, Some(b"not-an-image"), None, None); + let (boundary, body) = multipart_body(None, Some(b"not-an-image"), None, None, None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -394,7 +499,7 @@ async fn add_task_rejects_missing_prompt_and_image() { async fn add_task_rejects_prompt_and_image() { let h = build_harness().await; let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(Some("robot"), Some(&image), None, None); + let (boundary, body) = multipart_body(Some("robot"), Some(&image), None, None, None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -508,6 +613,7 @@ async fn add_task_rejects_when_queue_full() { image: None, model: 
None, seed: 0, + model_params: None, }); } let res = TestClient::post("http://localhost/add_task") diff --git a/tests/client_http_api/get_result.rs b/tests/client_http_api/get_result.rs index 66f5ebe..af8a288 100644 --- a/tests/client_http_api/get_result.rs +++ b/tests/client_http_api/get_result.rs @@ -261,6 +261,7 @@ async fn get_result_invalid_model_config() { image: None, model: Some("missing-model".to_string()), seed: 0, + model_params: Some(r#"{"preset":"default"}"#.to_string()), }) .await; @@ -275,6 +276,34 @@ async fn get_result_invalid_model_config() { assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); } +#[tokio::test] +async fn get_result_allows_none_model_params() { + let h = build_harness().await; + let task_id = Uuid::new_v4(); + h.task_manager + .add_task(Task { + id: task_id, + prompt: Some(Arc::new("robot".to_string())), + image: None, + model: Some("404-3dgs".to_string()), + seed: 0, + model_params: None, + }) + .await; + + let worker = hotkey_from_seed(HOTKEY_SEED_BASE + 18); + let payload = b"spz".to_vec(); + add_success_result(&h.task_manager, task_id, worker, payload.clone()).await; + + let res = TestClient::get(format!("http://localhost/get_result?id={task_id}")) + .add_header("x-api-key", h.api_key.to_string(), true) + .send(&h.service) + .await; + let (status, _headers, body) = read_response(res).await; + assert_eq!(status, StatusCode::OK); + assert_eq!(body, payload); +} + #[tokio::test] async fn get_result_model_glb_success() { let h = build_harness().await; diff --git a/tests/client_http_api/get_tasks.rs b/tests/client_http_api/get_tasks.rs index 7815195..703aca1 100644 --- a/tests/client_http_api/get_tasks.rs +++ b/tests/client_http_api/get_tasks.rs @@ -81,6 +81,7 @@ async fn get_tasks_requested_count_zero_returns_empty() { image: None, model: None, seed: 1, + model_params: None, }; h.task_manager.add_task(task.clone()).await; h.task_queue.push(task); @@ -117,6 +118,7 @@ async fn get_tasks_requested_count_large_returns_available() { 
image: None, model: None, seed: 1, + model_params: Some(r#"{"quality":"high"}"#.to_string()), }; h.task_manager.add_task(task.clone()).await; h.task_queue.push(task); @@ -152,6 +154,7 @@ async fn get_tasks_single_model_filters_results() { image: None, model: Some("404-3dgs".to_string()), seed: 0, + model_params: Some(r#"{"preset":"fast"}"#.to_string()), }; let task_b = Task { id: Uuid::new_v4(), @@ -159,6 +162,7 @@ async fn get_tasks_single_model_filters_results() { image: None, model: Some("404-mesh".to_string()), seed: 0, + model_params: Some(r#"{"preset":"high_quality"}"#.to_string()), }; h.task_manager.add_task(task_a.clone()).await; h.task_manager.add_task(task_b.clone()).await; @@ -201,6 +205,7 @@ async fn get_tasks_multiple_models_returns_all_matches() { image: None, model: Some("404-3dgs".to_string()), seed: 0, + model_params: Some(r#"{"preset":"fast"}"#.to_string()), }; let task_b = Task { id: Uuid::new_v4(), @@ -208,6 +213,7 @@ async fn get_tasks_multiple_models_returns_all_matches() { image: None, model: Some("404-mesh".to_string()), seed: 0, + model_params: Some(r#"{"preset":"high_quality"}"#.to_string()), }; h.task_manager.add_task(task_a.clone()).await; h.task_manager.add_task(task_b.clone()).await; diff --git a/tests/client_http_api/support.rs b/tests/client_http_api/support.rs index f76bef8..f8fd401 100644 --- a/tests/client_http_api/support.rs +++ b/tests/client_http_api/support.rs @@ -109,6 +109,7 @@ pub(crate) fn multipart_body( image: Option<&[u8]>, model: Option<&str>, seed: Option<&str>, + model_params: Option<&str>, ) -> (String, Vec) { let mut fields = Vec::<(&str, &str)>::new(); if let Some(p) = prompt { @@ -120,6 +121,9 @@ pub(crate) fn multipart_body( if let Some(s) = seed { fields.push(("seed", s)); } + if let Some(mp) = model_params { + fields.push(("model_params", mp)); + } let mut files = Vec::new(); if let Some(img) = image { @@ -295,7 +299,7 @@ pub(crate) async fn add_task_prompt(h: &TestHarness, prompt: &str, model: Option } 
pub(crate) async fn add_task_image(h: &TestHarness, image: &[u8], model: Option<&str>) -> Uuid { - let (boundary, body) = multipart_body(None, Some(image), model, None); + let (boundary, body) = multipart_body(None, Some(image), model, None, None); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( diff --git a/tests/event_tracker/add_result.rs b/tests/event_tracker/add_result.rs index 17d6845..32139d6 100644 --- a/tests/event_tracker/add_result.rs +++ b/tests/event_tracker/add_result.rs @@ -218,6 +218,7 @@ async fn add_result_not_assigned_returns_unauthorized() { image: None, model: None, seed: 0, + model_params: Some(r#"{"preset":"default"}"#.to_string()), }) .await; diff --git a/tests/event_tracker/worker.rs b/tests/event_tracker/worker.rs index 9a5930e..9ace3af 100644 --- a/tests/event_tracker/worker.rs +++ b/tests/event_tracker/worker.rs @@ -28,6 +28,7 @@ async fn records_worker_events() { image: None, model: None, seed: 0, + model_params: Some(r#"{"preset":"default"}"#.to_string()), }; h.task_manager.add_task(task.clone()).await; h.task_queue.push(task); @@ -103,6 +104,7 @@ async fn records_worker_failure_event() { image: None, model: None, seed: 0, + model_params: Some(r#"{"preset":"default"}"#.to_string()), }; h.task_manager.add_task(task.clone()).await; h.task_queue.push(task); @@ -185,6 +187,7 @@ async fn records_worker_timeout_event() { image: None, model: None, seed: 0, + model_params: Some(r#"{"preset":"default"}"#.to_string()), }; task_manager.add_task(task).await; let worker = Hotkey::from_bytes(&[3u8; 32]); From 22695de1224739c8342e1bbf0cda77c0409bfd08 Mon Sep 17 00:00:00 2001 From: Raman Kudaktsin Date: Fri, 27 Feb 2026 11:20:52 +0300 Subject: [PATCH 2/9] fix: models params test fix --- dev-env/config/config1.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dev-env/config/config1.toml b/dev-env/config/config1.toml index f99767c..fa2f2f6 100644 --- 
a/dev-env/config/config1.toml +++ b/dev-env/config/config1.toml @@ -131,6 +131,11 @@ min_len = 3 max_len = 384 allowed_pattern = "^[A-Za-z0-9 .,'():;/?!+%-]+$" +[model_params] +# Maximum length (in bytes) of the `model_params` JSON string accepted by the API. +# This should stay in sync with `ModelParamsConfig::max_len` defaults. +max_len = 1024 + [image] max_width = 1024 max_height = 1024 From d1210a9ee65f08acd825181574b5fdeda3e2f1ed Mon Sep 17 00:00:00 2001 From: Raman Kudaktsin Date: Fri, 27 Feb 2026 11:50:36 +0300 Subject: [PATCH 3/9] fix: model params in test fix --- dev-env/config/config-single.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dev-env/config/config-single.toml b/dev-env/config/config-single.toml index c79654b..88ce946 100644 --- a/dev-env/config/config-single.toml +++ b/dev-env/config/config-single.toml @@ -131,6 +131,9 @@ min_len = 3 max_len = 384 allowed_pattern = "^[A-Za-z0-9 .,'():;/?!+%-]+$" +[model_params] +max_len = 1024 + [image] max_width = 1024 max_height = 1024 From c9515c048ef33bf41a4801062e948bb9104a36a7 Mon Sep 17 00:00:00 2001 From: Denis Avvakumov Date: Fri, 27 Feb 2026 14:43:03 +0200 Subject: [PATCH 4/9] Make model_params a JSON object --- dev-env/config/config1.toml | 2 +- dev-env/config/config2.toml | 3 + dev-env/config/config3.toml | 3 + src/api/mod.rs | 5 +- src/api/request.rs | 3 +- src/config.rs | 7 ++ src/http3/handlers/task.rs | 88 +++++++++++++------ tests/client_http_api/add_task.rs | 132 ++++++++++++++++++++++++---- tests/client_http_api/get_result.rs | 4 +- tests/client_http_api/get_tasks.rs | 16 ++-- tests/event_tracker/add_result.rs | 4 +- tests/event_tracker/worker.rs | 12 ++- 12 files changed, 220 insertions(+), 59 deletions(-) diff --git a/dev-env/config/config1.toml b/dev-env/config/config1.toml index fa2f2f6..1934c61 100644 --- a/dev-env/config/config1.toml +++ b/dev-env/config/config1.toml @@ -132,7 +132,7 @@ max_len = 384 allowed_pattern = "^[A-Za-z0-9 .,'():;/?!+%-]+$" [model_params] -# Maximum 
length (in bytes) of the `model_params` JSON string accepted by the API. +# Maximum serialized length (in bytes) of the `model_params` JSON object accepted by the API. # This should stay in sync with `ModelParamsConfig::max_len` defaults. max_len = 1024 diff --git a/dev-env/config/config2.toml b/dev-env/config/config2.toml index f23fc79..74c3b4c 100644 --- a/dev-env/config/config2.toml +++ b/dev-env/config/config2.toml @@ -132,6 +132,9 @@ min_len = 3 max_len = 384 allowed_pattern = "^[A-Za-z0-9 .,'():;/?!+%-]+$" +[model_params] +max_len = 1024 + [image] max_width = 1024 max_height = 1024 diff --git a/dev-env/config/config3.toml b/dev-env/config/config3.toml index e66ddd6..f1392a0 100644 --- a/dev-env/config/config3.toml +++ b/dev-env/config/config3.toml @@ -131,6 +131,9 @@ min_len = 3 max_len = 384 allowed_pattern = "^[A-Za-z0-9 .,'():;/?!+%-]+$" +[model_params] +max_len = 1024 + [image] max_width = 1024 max_height = 1024 diff --git a/src/api/mod.rs b/src/api/mod.rs index 299e44d..ae423a7 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,8 +9,11 @@ use std::sync::Arc; use crate::common::image::serialize_image_base64; use bytes::Bytes; use serde::Serialize; +use serde_json::{Map, Value}; use uuid::Uuid; +pub type ModelParams = Map; + #[derive(Debug, Clone, Serialize)] pub struct Task { // The unique ID of the task (UUIDv4). 
@@ -29,7 +32,7 @@ pub struct Task { pub model: Option, pub seed: i32, #[serde(skip_serializing_if = "Option::is_none")] - pub model_params: Option, + pub model_params: Option, } pub trait HasUuid { diff --git a/src/api/request.rs b/src/api/request.rs index 3c2513f..a7b7c78 100644 --- a/src/api/request.rs +++ b/src/api/request.rs @@ -1,4 +1,5 @@ use serde::Deserialize; +use serde_json::Value; use std::sync::Arc; use std::time::Instant; use uuid::Uuid; @@ -11,7 +12,7 @@ pub struct AddTaskRequest { pub seed: Option, pub prompt: Option, pub model: Option, - pub model_params: Option, + pub model_params: Option, } #[derive(Debug, Clone, Deserialize)] diff --git a/src/config.rs b/src/config.rs index 1d54a83..0fe7d38 100644 --- a/src/config.rs +++ b/src/config.rs @@ -226,6 +226,12 @@ fn default_model_params_max_len() -> usize { 1024 } +fn default_model_params_config() -> ModelParamsConfig { + ModelParamsConfig { + max_len: default_model_params_max_len(), + } +} + fn default_snapshot_dir() -> String { "data/snapshots".to_string() } @@ -266,6 +272,7 @@ pub struct NodeConfig { pub rclient: RClientConfig, pub http: HTTPConfig, pub prompt: PromptConfig, + #[serde(default = "default_model_params_config")] pub model_params: ModelParamsConfig, pub image: ImageConfig, pub db: DbConfig, diff --git a/src/http3/handlers/task.rs b/src/http3/handlers/task.rs index acd394f..b556a00 100644 --- a/src/http3/handlers/task.rs +++ b/src/http3/handlers/task.rs @@ -9,9 +9,9 @@ use tokio::sync::OwnedSemaphorePermit; use tracing::info; use uuid::Uuid; -use crate::api::Task; use crate::api::request::{AddTaskRequest, GetTasksRequest}; use crate::api::response::{GetTasksResponse, LoadResponse}; +use crate::api::{ModelParams, Task}; use crate::common::image::validate_image; use crate::http3::depot_ext::DepotExt; use crate::http3::error::ServerError; @@ -44,7 +44,7 @@ struct AddTaskMultipartData { prompt: Option, image: Option, model: Option, - model_params: Option, + model_params: Option, } struct 
ValidatedAddTask { @@ -52,7 +52,7 @@ struct ValidatedAddTask { prompt: Option, image: Option, model: Option, - model_params: Option, + model_params: Option, task_kind: TaskKind, } @@ -67,7 +67,6 @@ async fn parse_add_task_multipart( let cfg = state.config(); let image_cfg = cfg.image(); let prompt_cfg = cfg.prompt(); - let model_params_cfg = cfg.model_params(); let upload_limiter = cfg.image_upload_limiter(); let mut image_permit: Option = None; @@ -78,8 +77,7 @@ async fn parse_add_task_multipart( .for_field("image", image_cfg.max_size_bytes as u64) .for_field("prompt", prompt_cfg.max_len as u64) .for_field("model", 64) - .for_field("seed", 11) - .for_field("model_params", model_params_cfg.max_len as u64), + .for_field("seed", 11), ); let mut multipart = Multipart::with_constraints(byte_stream, boundary, constraints); @@ -122,7 +120,14 @@ async fn parse_add_task_multipart( seed = Some(parse_seed_text(&read_text_field(field, "seed").await?)?); } "model_params" => { - model_params = Some(read_text_field(field, "model_params").await?); + let raw_model_params = read_text_field(field, "model_params").await?; + let parsed_model_params = serde_json::from_str::( + &raw_model_params, + ) + .map_err(|err| { + ServerError::BadRequest(format!("Model params must be valid JSON: {}", err)) + })?; + model_params = Some(parsed_model_params); } _ => continue, } @@ -152,10 +157,21 @@ async fn parse_add_task_request( let boundary = parse_boundary(&content_type)?; parse_add_task_multipart(depot, req, boundary).await } else { - let add_task: AddTaskRequest = req - .parse_json::() + let add_task_json = req + .parse_json::() .await .map_err(|e| ServerError::BadRequest(e.to_string()))?; + if add_task_json + .as_object() + .and_then(|payload| payload.get("model_params")) + .is_some_and(|value| value.is_null()) + { + return Err(ServerError::BadRequest( + "Model params must be a JSON object".into(), + )); + } + let add_task: AddTaskRequest = serde_json::from_value(add_task_json) + 
.map_err(|e| ServerError::BadRequest(e.to_string()))?; Ok(AddTaskMultipartData { prompt: add_task.prompt, image: None, @@ -183,8 +199,16 @@ fn validate_add_task_input( ) -> Result { let state = depot.require::()?.clone(); let cfg = state.config(); - let has_prompt = add_task.prompt.as_ref().is_some_and(|p| !p.is_empty()); - let has_image = add_task.image.as_ref().is_some_and(|b| !b.is_empty()); + let AddTaskMultipartData { + seed, + prompt, + image, + model, + model_params, + } = add_task; + + let has_prompt = prompt.as_ref().is_some_and(|p| !p.is_empty()); + let has_image = image.as_ref().is_some_and(|b| !b.is_empty()); if has_prompt == has_image { return Err(ServerError::BadRequest(if has_prompt { "Cannot provide both prompt and image. Choose one.".into() @@ -199,7 +223,7 @@ fn validate_add_task_input( TaskKind::TextTo3D }; - if let Some(prompt) = &add_task.prompt { + if let Some(prompt) = &prompt { let prompt_cfg = cfg.prompt(); let len = prompt.chars().count(); if len < prompt_cfg.min_len { @@ -218,25 +242,31 @@ fn validate_add_task_input( } } - if let Some(model_params) = &add_task.model_params { + let validated_model_params = if let Some(model_params_value) = model_params { + let model_params = model_params_value + .as_object() + .cloned() + .ok_or_else(|| ServerError::BadRequest("Model params must be a JSON object".into()))?; let model_params_cfg = cfg.model_params(); - if model_params.len() > model_params_cfg.max_len { - return Err(ServerError::BadRequest(format!( - "Model params is too long: maximum length is {} characters (got {})", - model_params_cfg.max_len, - model_params.len() - ))); - } + let serialized_len = serde_json::to_vec(&model_params) + .map_err(|err| { + ServerError::BadRequest(format!("Model params must be valid JSON: {}", err)) + })? 
+ .len(); - if let Err(err) = serde_json::from_str::(model_params) { + if serialized_len > model_params_cfg.max_len { return Err(ServerError::BadRequest(format!( - "Model params must be valid JSON: {}", - err + "Model params is too long: maximum length is {} bytes (got {})", + model_params_cfg.max_len, serialized_len ))); } - } - let validated_image = if let Some(image_data) = add_task.image { + Some(model_params) + } else { + None + }; + + let validated_image = if let Some(image_data) = image { let image_cfg = cfg.image(); Some(validate_image(image_data, image_cfg)?) } else { @@ -244,12 +274,12 @@ fn validate_add_task_input( }; Ok(ValidatedAddTask { - seed: add_task.seed, - prompt: add_task.prompt, + seed, + prompt, image: validated_image, - model: add_task.model, + model, task_kind, - model_params: add_task.model_params, + model_params: validated_model_params, }) } diff --git a/tests/client_http_api/add_task.rs b/tests/client_http_api/add_task.rs index 148ce6c..ece6655 100644 --- a/tests/client_http_api/add_task.rs +++ b/tests/client_http_api/add_task.rs @@ -10,6 +10,17 @@ use crate::support::{ add_task_prompt, build_harness, multipart_body, read_response, tiny_png_bytes, }; +fn model_params_json_object_text_with_total_len(total_len: usize) -> String { + // {"p":""} -> fixed overhead is 8 bytes/chars. 
+ assert!(total_len >= 8, "total_len must be at least 8"); + format!(r#"{{"p":"{}"}}"#, "a".repeat(total_len - 8)) +} + +fn model_params_json_object_with_total_len(total_len: usize) -> serde_json::Value { + serde_json::from_str(&model_params_json_object_text_with_total_len(total_len)) + .expect("model params object") +} + #[tokio::test] async fn add_task_json_success() { let h = build_harness().await; @@ -17,26 +28,60 @@ async fn add_task_json_success() { } #[tokio::test] -async fn add_task_rejects_invalid_model_params_json() { +async fn add_task_rejects_non_object_model_params_json() { let h = build_harness().await; let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .json(&serde_json::json!({ "prompt": "robot", - "model_params": "not-json" + "model_params": "not-an-object" })) .send(&h.service) .await; - let (status, _headers, _body) = read_response(res).await; + let (status, _headers, body) = read_response(res).await; assert_eq!(status, StatusCode::BAD_REQUEST); + let body_text = String::from_utf8_lossy(&body); + assert!(body_text.contains("Model params must be a JSON object")); +} + +#[tokio::test] +async fn add_task_rejects_null_model_params_json() { + let h = build_harness().await; + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .json(&serde_json::json!({ + "prompt": "robot", + "model_params": serde_json::Value::Null + })) + .send(&h.service) + .await; + let (status, _headers, body) = read_response(res).await; + assert_eq!(status, StatusCode::BAD_REQUEST); + let body_text = String::from_utf8_lossy(&body); + assert!(body_text.contains("Model params must be a JSON object")); +} + +#[tokio::test] +async fn add_task_rejects_array_model_params_json() { + let h = build_harness().await; + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .json(&serde_json::json!({ + "prompt": 
"robot", + "model_params": [1, 2, 3] + })) + .send(&h.service) + .await; + let (status, _headers, body) = read_response(res).await; + assert_eq!(status, StatusCode::BAD_REQUEST); + let body_text = String::from_utf8_lossy(&body); + assert!(body_text.contains("Model params must be a JSON object")); } #[tokio::test] async fn add_task_multipart_rejects_invalid_model_params() { let h = build_harness().await; - let image = tiny_png_bytes(); - let (boundary, body) = - multipart_body(Some("robot"), Some(&image), None, None, Some("not-json")); + let (boundary, body) = multipart_body(Some("robot"), None, None, None, Some("not-json")); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -47,15 +92,37 @@ async fn add_task_multipart_rejects_invalid_model_params() { .body(body) .send(&h.service) .await; - let (status, _headers, _body) = read_response(res).await; + let (status, _headers, body) = read_response(res).await; assert_eq!(status, StatusCode::BAD_REQUEST); + let body_text = String::from_utf8_lossy(&body); + assert!(body_text.contains("Model params must be valid JSON")); +} + +#[tokio::test] +async fn add_task_multipart_rejects_non_object_model_params() { + let h = build_harness().await; + let (boundary, body) = multipart_body(Some("robot"), None, None, None, Some("[1,2,3]")); + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .add_header( + "content-type", + format!("multipart/form-data; boundary={}", boundary), + true, + ) + .body(body) + .send(&h.service) + .await; + let (status, _headers, body) = read_response(res).await; + assert_eq!(status, StatusCode::BAD_REQUEST); + let body_text = String::from_utf8_lossy(&body); + assert!(body_text.contains("Model params must be a JSON object")); } #[tokio::test] async fn add_task_rejects_too_long_model_params_json() { let h = build_harness().await; let max_len = h.config.model_params.max_len; 
- let too_long = "a".repeat(max_len + 1); + let too_long = model_params_json_object_with_total_len(max_len + 1); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) @@ -65,17 +132,18 @@ async fn add_task_rejects_too_long_model_params_json() { })) .send(&h.service) .await; - let (status, _headers, _body) = read_response(res).await; + let (status, _headers, body) = read_response(res).await; assert_eq!(status, StatusCode::BAD_REQUEST); + let body_text = String::from_utf8_lossy(&body); + assert!(body_text.contains("Model params is too long")); } #[tokio::test] async fn add_task_multipart_rejects_too_long_model_params() { let h = build_harness().await; let max_len = h.config.model_params.max_len; - let too_long = "a".repeat(max_len + 1); - let image = tiny_png_bytes(); - let (boundary, body) = multipart_body(Some("robot"), Some(&image), None, None, Some(&too_long)); + let too_long = model_params_json_object_text_with_total_len(max_len + 1); + let (boundary, body) = multipart_body(Some("robot"), None, None, None, Some(&too_long)); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .add_header( @@ -86,14 +154,23 @@ async fn add_task_multipart_rejects_too_long_model_params() { .body(body) .send(&h.service) .await; - let (status, _headers, _body) = read_response(res).await; + let (status, _headers, body) = read_response(res).await; assert_eq!(status, StatusCode::BAD_REQUEST); + let body_text = String::from_utf8_lossy(&body); + let body_lower = body_text.to_ascii_lowercase(); + assert!( + body_text.contains("Model params is too long") + || (body_lower.contains("model_params") + && (body_lower.contains("exceed") || body_lower.contains("size"))), + "unexpected multipart over-limit error: {body_text}" + ); } #[tokio::test] -async fn add_task_accepts_valid_model_params_json() { +async fn add_task_accepts_model_params_json_at_max_len() { let h = 
build_harness().await; - let params = r#"{"temperature":0.5}"#; + let params = model_params_json_object_with_total_len(h.config.model_params.max_len); + let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .json(&serde_json::json!({ @@ -107,13 +184,34 @@ async fn add_task_accepts_valid_model_params_json() { } #[tokio::test] -async fn add_task_accepts_null_model_params_json() { +async fn add_task_accepts_model_params_multipart_at_max_len() { let h = build_harness().await; + let params = model_params_json_object_text_with_total_len(h.config.model_params.max_len); + let (boundary, body) = multipart_body(Some("robot"), None, None, None, Some(&params)); + + let res = TestClient::post("http://localhost/add_task") + .add_header("x-api-key", h.api_key.to_string(), true) + .add_header( + "content-type", + format!("multipart/form-data; boundary={}", boundary), + true, + ) + .body(body) + .send(&h.service) + .await; + let (status, _headers, _body) = read_response(res).await; + assert_eq!(status, StatusCode::OK); +} + +#[tokio::test] +async fn add_task_accepts_valid_model_params_json() { + let h = build_harness().await; + let params = serde_json::json!({"temperature": 0.5}); let res = TestClient::post("http://localhost/add_task") .add_header("x-api-key", h.api_key.to_string(), true) .json(&serde_json::json!({ "prompt": "robot", - "model_params": serde_json::Value::Null + "model_params": params })) .send(&h.service) .await; diff --git a/tests/client_http_api/get_result.rs b/tests/client_http_api/get_result.rs index af8a288..84779fe 100644 --- a/tests/client_http_api/get_result.rs +++ b/tests/client_http_api/get_result.rs @@ -261,7 +261,9 @@ async fn get_result_invalid_model_config() { image: None, model: Some("missing-model".to_string()), seed: 0, - model_params: Some(r#"{"preset":"default"}"#.to_string()), + model_params: Some( + serde_json::from_str(r#"{"preset":"default"}"#).expect("model params object"), + ), }) .await;
diff --git a/tests/client_http_api/get_tasks.rs b/tests/client_http_api/get_tasks.rs index 703aca1..6208c5f 100644 --- a/tests/client_http_api/get_tasks.rs +++ b/tests/client_http_api/get_tasks.rs @@ -118,7 +118,7 @@ async fn get_tasks_requested_count_large_returns_available() { image: None, model: None, seed: 1, - model_params: Some(r#"{"quality":"high"}"#.to_string()), + model_params: Some(serde_json::from_str(r#"{"quality":"high"}"#).expect("object")), }; h.task_manager.add_task(task.clone()).await; h.task_queue.push(task); @@ -143,6 +143,12 @@ async fn get_tasks_requested_count_large_returns_available() { .and_then(|v| v.as_array()) .expect("tasks"); assert_eq!(tasks.len(), 1); + assert!( + tasks[0] + .get("model_params") + .and_then(|value| value.as_object()) + .is_some() + ); } #[tokio::test] @@ -154,7 +160,7 @@ async fn get_tasks_single_model_filters_results() { image: None, model: Some("404-3dgs".to_string()), seed: 0, - model_params: Some(r#"{"preset":"fast"}"#.to_string()), + model_params: Some(serde_json::from_str(r#"{"preset":"fast"}"#).expect("object")), }; let task_b = Task { id: Uuid::new_v4(), @@ -162,7 +168,7 @@ async fn get_tasks_single_model_filters_results() { image: None, model: Some("404-mesh".to_string()), seed: 0, - model_params: Some(r#"{"preset":"high_quality"}"#.to_string()), + model_params: Some(serde_json::from_str(r#"{"preset":"high_quality"}"#).expect("object")), }; h.task_manager.add_task(task_a.clone()).await; h.task_manager.add_task(task_b.clone()).await; @@ -205,7 +211,7 @@ async fn get_tasks_multiple_models_returns_all_matches() { image: None, model: Some("404-3dgs".to_string()), seed: 0, - model_params: Some(r#"{"preset":"fast"}"#.to_string()), + model_params: Some(serde_json::from_str(r#"{"preset":"fast"}"#).expect("object")), }; let task_b = Task { id: Uuid::new_v4(), @@ -213,7 +219,7 @@ async fn get_tasks_multiple_models_returns_all_matches() { image: None, model: Some("404-mesh".to_string()), seed: 0, - model_params: 
Some(r#"{"preset":"high_quality"}"#.to_string()), + model_params: Some(serde_json::from_str(r#"{"preset":"high_quality"}"#).expect("object")), }; h.task_manager.add_task(task_a.clone()).await; h.task_manager.add_task(task_b.clone()).await; diff --git a/tests/event_tracker/add_result.rs b/tests/event_tracker/add_result.rs index 32139d6..68ae976 100644 --- a/tests/event_tracker/add_result.rs +++ b/tests/event_tracker/add_result.rs @@ -218,7 +218,9 @@ async fn add_result_not_assigned_returns_unauthorized() { image: None, model: None, seed: 0, - model_params: Some(r#"{"preset":"default"}"#.to_string()), + model_params: Some( + serde_json::from_str(r#"{"preset":"default"}"#).expect("model params object"), + ), }) .await; diff --git a/tests/event_tracker/worker.rs b/tests/event_tracker/worker.rs index 9ace3af..23bca99 100644 --- a/tests/event_tracker/worker.rs +++ b/tests/event_tracker/worker.rs @@ -28,7 +28,9 @@ async fn records_worker_events() { image: None, model: None, seed: 0, - model_params: Some(r#"{"preset":"default"}"#.to_string()), + model_params: Some( + serde_json::from_str(r#"{"preset":"default"}"#).expect("model params object"), + ), }; h.task_manager.add_task(task.clone()).await; h.task_queue.push(task); @@ -104,7 +106,9 @@ async fn records_worker_failure_event() { image: None, model: None, seed: 0, - model_params: Some(r#"{"preset":"default"}"#.to_string()), + model_params: Some( + serde_json::from_str(r#"{"preset":"default"}"#).expect("model params object"), + ), }; h.task_manager.add_task(task.clone()).await; h.task_queue.push(task); @@ -187,7 +191,9 @@ async fn records_worker_timeout_event() { image: None, model: None, seed: 0, - model_params: Some(r#"{"preset":"default"}"#.to_string()), + model_params: Some( + serde_json::from_str(r#"{"preset":"default"}"#).expect("model params object"), + ), }; task_manager.add_task(task).await; let worker = Hotkey::from_bytes(&[3u8; 32]); From 3058d9de34a25e39581159d637c5e6cab4d11f70 Mon Sep 17 00:00:00 2001 From: 
Denis Avvakumov Date: Fri, 27 Feb 2026 14:56:38 +0200 Subject: [PATCH 5/9] Enforce multipart size limit for model_params field --- src/http3/handlers/task.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/http3/handlers/task.rs b/src/http3/handlers/task.rs index b556a00..74b5590 100644 --- a/src/http3/handlers/task.rs +++ b/src/http3/handlers/task.rs @@ -67,6 +67,7 @@ async fn parse_add_task_multipart( let cfg = state.config(); let image_cfg = cfg.image(); let prompt_cfg = cfg.prompt(); + let model_params_cfg = cfg.model_params(); let upload_limiter = cfg.image_upload_limiter(); let mut image_permit: Option = None; @@ -77,7 +78,8 @@ async fn parse_add_task_multipart( .for_field("image", image_cfg.max_size_bytes as u64) .for_field("prompt", prompt_cfg.max_len as u64) .for_field("model", 64) - .for_field("seed", 11), + .for_field("seed", 11) + .for_field("model_params", model_params_cfg.max_len as u64), ); let mut multipart = Multipart::with_constraints(byte_stream, boundary, constraints); From d3c74e7dbe32e4ddc7368614c1429ce76fd46aa4 Mon Sep 17 00:00:00 2001 From: Raman Kudaktsin Date: Tue, 10 Mar 2026 13:15:48 +0100 Subject: [PATCH 6/9] feat: add batched rate-limit persistence pipeline Persist rate-limit violations as periodic aggregates with per-client JSON details to avoid one-row-per-violation writes and reduce DB pressure. 
Made-with: Cursor --- dev-env/init-scripts/init-schema.sql | 16 ++++ src/db/event_recorder.rs | 110 ++++++++++++++++++++++++++- src/db/mod.rs | 48 ++++++++++++ src/db/repository.rs | 75 +++++++++++++++++- 4 files changed, 247 insertions(+), 2 deletions(-) diff --git a/dev-env/init-scripts/init-schema.sql b/dev-env/init-scripts/init-schema.sql index a4f43bd..2bf7819 100644 --- a/dev-env/init-scripts/init-schema.sql +++ b/dev-env/init-scripts/init-schema.sql @@ -160,3 +160,19 @@ CREATE INDEX IF NOT EXISTS idx_worker_events_worker_year ON worker_events(worker CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_week ON worker_events(worker_id, action, bucket_week); CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_month ON worker_events(worker_id, action, bucket_month); CREATE INDEX IF NOT EXISTS idx_worker_events_worker_action_year ON worker_events(worker_id, action, bucket_year); + +-- Batched rate-limit violations (aggregated per flush window + gateway). +-- "details" stores a JSON object with per-client counters. 
+CREATE TABLE IF NOT EXISTS rate_limit_violations ( + id BIGSERIAL PRIMARY KEY, + gateway_name VARCHAR(255) NOT NULL, + window_start TIMESTAMP WITHOUT TIME ZONE NOT NULL, + window_end TIMESTAMP WITHOUT TIME ZONE NOT NULL, + total_count BIGINT NOT NULL CHECK (total_count >= 0), + details JSONB NOT NULL, + created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT (NOW() AT TIME ZONE 'UTC') +); +CREATE INDEX IF NOT EXISTS idx_rate_limit_violations_gateway_window + ON rate_limit_violations(gateway_name, window_start DESC); +CREATE INDEX IF NOT EXISTS idx_rate_limit_violations_window + ON rate_limit_violations(window_start DESC); diff --git a/src/db/event_recorder.rs b/src/db/event_recorder.rs index b6b14bd..dfdb25d 100644 --- a/src/db/event_recorder.rs +++ b/src/db/event_recorder.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; @@ -7,12 +8,22 @@ use tokio_util::sync::CancellationToken; use tracing::error; use uuid::Uuid; -use crate::db::{ActivityEventRow, EventSink, EventSinkHandle, WorkerEventRow}; +use crate::db::{ + ActivityEventRow, EventSink, EventSinkHandle, RateLimitViolationBatchRow, WorkerEventRow, +}; #[derive(Clone)] enum EventRow { Activity(ActivityEventRow), Worker(WorkerEventRow), + RateLimitViolation(RateLimitViolationRow), +} + +#[derive(Clone)] +struct RateLimitViolationRow { + gateway_name: String, + client_id: String, + created_at: chrono::DateTime, } #[derive(Clone)] @@ -94,6 +105,19 @@ impl EventRecorder { self.enqueue(EventRow::Worker(row)); } + /// Records a single rate-limit denial event for the provided client key. + /// + /// The recorder stores these events in memory and periodically flushes them as + /// aggregated batch rows so we avoid writing one database row per denial. 
+ pub fn record_rate_limit_violation(&self, client_id: &str) { + let row = RateLimitViolationRow { + gateway_name: self.gateway_name.to_string(), + client_id: client_id.to_string(), + created_at: Utc::now(), + }; + self.enqueue(EventRow::RateLimitViolation(row)); + } + fn spawn_flusher(&self, flush_interval: Duration, shutdown: CancellationToken) { let sink = Arc::clone(&self.sink); let queue = Arc::clone(&self.queue); @@ -148,10 +172,12 @@ impl EventRecorder { ) -> anyhow::Result<()> { let mut activity_rows: Vec = Vec::new(); let mut worker_rows: Vec = Vec::new(); + let mut rate_limit_rows: Vec = Vec::new(); while let Some(entry) = queue.pop() { match &**entry { EventRow::Activity(row) => activity_rows.push(row.clone()), EventRow::Worker(row) => worker_rows.push(row.clone()), + EventRow::RateLimitViolation(row) => rate_limit_rows.push(row.clone()), } } @@ -171,6 +197,48 @@ impl EventRecorder { Self::enqueue_with_limit(queue, capacity, dropped, EventRow::Worker(row)); } } + + if !rate_limit_rows.is_empty() { + let mut details = BTreeMap::::new(); + let mut window_start = rate_limit_rows[0].created_at; + let mut window_end = rate_limit_rows[0].created_at; + for row in &rate_limit_rows { + *details.entry(row.client_id.clone()).or_insert(0) += 1; + if row.created_at < window_start { + window_start = row.created_at; + } + if row.created_at > window_end { + window_end = row.created_at; + } + } + + let mut details_pairs: Vec<(String, i64)> = details.into_iter().collect(); + details_pairs.sort_by(|a, b| a.0.cmp(&b.0)); + let details_map: serde_json::Map = details_pairs + .into_iter() + .map(|(k, v)| (k, serde_json::Value::from(v))) + .collect(); + let batch_row = RateLimitViolationBatchRow { + gateway_name: rate_limit_rows[0].gateway_name.clone(), + window_start, + window_end, + total_count: rate_limit_rows.len() as i64, + details: serde_json::Value::Object(details_map), + created_at: Utc::now(), + }; + let batch_rows = vec![batch_row]; + if let Err(e) = 
sink.record_rate_limit_violation_batches(&batch_rows).await { + error!(error = ?e, "Failed to flush rate-limit violation aggregates"); + for row in rate_limit_rows { + Self::enqueue_with_limit( + queue, + capacity, + dropped, + EventRow::RateLimitViolation(row), + ); + } + } + } Ok(()) } @@ -180,3 +248,43 @@ impl EventRecorder { let _ = Self::flush_once(&self.sink, &self.queue, self.capacity, &self.dropped).await; } } + +#[cfg(all(test, feature = "test-support"))] +mod tests { + use super::EventRecorder; + use crate::db::{EventSinkHandle, InMemoryEventSink}; + use std::sync::Arc; + use std::time::Duration; + use tokio_util::sync::CancellationToken; + + #[tokio::test] + async fn rate_limit_violations_are_aggregated_per_client() { + let sink = InMemoryEventSink::default(); + let shutdown = CancellationToken::new(); + let recorder = EventRecorder::new( + Arc::new(EventSinkHandle::InMemory(sink.clone())), + Arc::from("gw-test"), + Duration::from_secs(60), + 1024, + shutdown.clone(), + ); + + recorder.record_rate_limit_violation("user:alice"); + recorder.record_rate_limit_violation("user:alice"); + recorder.record_rate_limit_violation("ip:10"); + recorder.flush_once_for_test().await; + + let rows = sink.rate_limit_rows().await; + assert_eq!(rows.len(), 1); + let row = &rows[0]; + assert_eq!(row.gateway_name, "gw-test"); + assert_eq!(row.total_count, 3); + assert_eq!( + row.details, + serde_json::json!({ + "ip:10": 1, + "user:alice": 2 + }) + ); + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs index 756f241..e061fbf 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -55,6 +55,16 @@ pub struct WorkerEventRow { pub created_at: DateTime, } +#[derive(Clone)] +pub struct RateLimitViolationBatchRow { + pub gateway_name: String, + pub window_start: DateTime, + pub window_end: DateTime, + pub total_count: i64, + pub details: serde_json::Value, + pub created_at: DateTime, +} + pub struct DatabaseBuilder { sslcert_path: Option, sslkey_path: Option, @@ -71,6 +81,10 @@ pub struct 
DatabaseBuilder { pub trait EventSink: Send + Sync { async fn record_activity_events_batch(&self, rows: &[ActivityEventRow]) -> Result<()>; async fn record_worker_events_batch(&self, rows: &[WorkerEventRow]) -> Result<()>; + async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()>; } #[cfg(feature = "test-support")] @@ -78,6 +92,7 @@ pub trait EventSink: Send + Sync { pub struct InMemoryEventSink { activity: Arc>>, worker: Arc>>, + rate_limit: Arc>>, } #[cfg(feature = "test-support")] @@ -91,6 +106,11 @@ impl InMemoryEventSink { pub async fn worker_rows(&self) -> Vec { self.worker.lock().await.clone() } + + #[allow(dead_code)] + pub async fn rate_limit_rows(&self) -> Vec { + self.rate_limit.lock().await.clone() + } } #[cfg(feature = "test-support")] @@ -107,6 +127,15 @@ impl EventSink for InMemoryEventSink { guard.extend(rows.iter().cloned()); Ok(()) } + + async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()> { + let mut guard = self.rate_limit.lock().await; + guard.extend(rows.iter().cloned()); + Ok(()) + } } #[derive(Clone)] @@ -137,6 +166,18 @@ impl EventSink for EventSinkHandle { EventSinkHandle::Noop => Ok(()), } } + + async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()> { + match self { + EventSinkHandle::Database(db) => db.record_rate_limit_violation_batches(rows).await, + #[cfg(feature = "test-support")] + EventSinkHandle::InMemory(sink) => sink.record_rate_limit_violation_batches(rows).await, + EventSinkHandle::Noop => Ok(()), + } + } } #[async_trait] @@ -148,4 +189,11 @@ impl EventSink for Database { async fn record_worker_events_batch(&self, rows: &[WorkerEventRow]) -> Result<()> { Database::record_worker_events_batch(self, rows).await } + + async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()> { + 
Database::record_rate_limit_violation_batches(self, rows).await + } } diff --git a/src/db/repository.rs b/src/db/repository.rs index cae590d..f65308b 100644 --- a/src/db/repository.rs +++ b/src/db/repository.rs @@ -5,7 +5,7 @@ use futures_util::SinkExt; use tokio_postgres::types::ToSql; use super::connection::StmtKey; -use super::{ActivityEventRow, Database, WorkerEventRow}; +use super::{ActivityEventRow, Database, RateLimitViolationBatchRow, WorkerEventRow}; impl Database { pub(super) const Q_SERVER_TIME_UTC: &'static str = r#" @@ -126,6 +126,16 @@ task_kind, \ reason, \ gateway_name, \ created_at\ +) FROM STDIN WITH (FORMAT text)"; + + pub(super) const COPY_RATE_LIMIT_VIOLATIONS: &'static str = "\ +COPY rate_limit_violations (\ +gateway_name, \ +window_start, \ +window_end, \ +total_count, \ +details, \ +created_at\ ) FROM STDIN WITH (FORMAT text)"; pub async fn fetch_all_user_key_hashes( @@ -382,6 +392,69 @@ created_at\ } } } + + pub async fn record_rate_limit_violation_batches( + &self, + rows: &[RateLimitViolationBatchRow], + ) -> Result<()> { + if rows.is_empty() { + return Ok(()); + } + let client = self.load_client().await?; + client.batch_execute("BEGIN").await?; + let result = async { + for chunk in rows.chunks(self.events_copy_batch_size) { + let mut buf = Vec::with_capacity(chunk.len() * 256); + for row in chunk { + append_copy_field(&mut buf, Some(row.gateway_name.as_str())); + buf.push(b'\t'); + let window_start = row + .window_start + .naive_utc() + .format("%Y-%m-%d %H:%M:%S%.f") + .to_string(); + append_copy_field(&mut buf, Some(window_start.as_str())); + buf.push(b'\t'); + let window_end = row + .window_end + .naive_utc() + .format("%Y-%m-%d %H:%M:%S%.f") + .to_string(); + append_copy_field(&mut buf, Some(window_end.as_str())); + buf.push(b'\t'); + append_copy_field(&mut buf, Some(row.total_count.to_string().as_str())); + buf.push(b'\t'); + let details = serde_json::to_string(&row.details)?; + append_copy_field(&mut buf, Some(details.as_str())); + 
buf.push(b'\t'); + let created_at = row + .created_at + .naive_utc() + .format("%Y-%m-%d %H:%M:%S%.f") + .to_string(); + append_copy_field(&mut buf, Some(created_at.as_str())); + buf.push(b'\n'); + } + let sink = client.copy_in(Self::COPY_RATE_LIMIT_VIOLATIONS).await?; + let mut sink = std::pin::pin!(sink); + sink.as_mut().send(Bytes::from(buf)).await?; + sink.as_mut().finish().await?; + } + Ok::<(), anyhow::Error>(()) + } + .await; + + match result { + Ok(()) => { + client.batch_execute("COMMIT").await?; + Ok(()) + } + Err(err) => { + let _ = client.batch_execute("ROLLBACK").await; + Err(err) + } + } + } } fn append_copy_field(buf: &mut Vec, value: Option<&str>) { From 24c3d08f1bf8bb7f0d21521ed9387e79914b32ca Mon Sep 17 00:00:00 2001 From: Raman Kudaktsin Date: Tue, 10 Mar 2026 13:16:10 +0100 Subject: [PATCH 7/9] feat: track local and distributed rate-limit denials Capture 429 events from middleware and distributed subject checks so every denial contributes to the batched violation aggregates. Made-with: Cursor --- src/http3/rate_limits.rs | 52 +++++++++++++++++++++++++++++++++++++++ src/raft/gateway_state.rs | 7 ++++++ 2 files changed, 59 insertions(+) diff --git a/src/http3/rate_limits.rs b/src/http3/rate_limits.rs index 17322cd..0b2b0d3 100644 --- a/src/http3/rate_limits.rs +++ b/src/http3/rate_limits.rs @@ -107,6 +107,46 @@ impl RateLimitContext { } } +/// Builds the identity key used for batched rate-limit violation aggregation. +/// +/// Priority: +/// 1. company id +/// 2. user id +/// 3. decimal source IP +/// 4. 
unknown +fn violation_client_key(ctx: &RateLimitContext) -> String { + if let Some(company) = ctx.company.as_ref() { + return format!("company:{}", company.id); + } + if let Some(user_id) = ctx.user_id { + return format!("user:{user_id}"); + } + if let Some(ip) = ctx.decimal_ip.as_ref() { + return format!("ip:{ip}"); + } + "unknown".to_string() +} + +fn violation_client_key_for_subject(subject: Subject, id: u128) -> String { + match subject { + Subject::Company => format!("company:{}", Uuid::from_u128(id)), + Subject::User => format!("user:{}", Uuid::from_u128(id)), + } +} + +fn maybe_record_local_violation(state: &HttpState, depot: &Depot, res: &Response) { + if res.status_code != Some(salvo::http::StatusCode::TOO_MANY_REQUESTS) { + return; + } + if let Ok(ctx) = depot.obtain::() { + state + .gateway_state() + .record_rate_limit_violation(violation_client_key(ctx).as_str()); + } else { + state.gateway_state().record_rate_limit_violation("unknown"); + } +} + #[handler] pub async fn prepare_rate_limit_context( depot: &mut Depot, @@ -343,6 +383,7 @@ pub async fn basic_rate_limit( .basic_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -359,6 +400,7 @@ pub async fn update_key_rate_limit( .update_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -375,6 +417,7 @@ pub async fn unauthorized_only_rate_limit( .unauthorized_only_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -391,6 +434,7 @@ pub async fn generic_global_rate_limit( .generic_global_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -407,6 +451,7 @@ pub async fn generic_per_ip_rate_limit( .generic_per_ip_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -423,6 +468,7 @@ pub async fn read_rate_limit( 
.read_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -439,6 +485,7 @@ pub async fn result_rate_limit( .result_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -455,6 +502,7 @@ pub async fn load_rate_limit( .load_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -471,6 +519,7 @@ pub async fn leader_rate_limit( .leader_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -487,6 +536,7 @@ pub async fn metric_rate_limit( .metric_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -503,6 +553,7 @@ pub async fn status_rate_limit( .status_limiter .handle(req, depot, res, ctrl) .await; + maybe_record_local_violation(&state, depot, res); Ok(()) } @@ -563,6 +614,7 @@ async fn enforce_subject( ) .await { + gs.record_rate_limit_violation(violation_client_key_for_subject(subject, id).as_str()); return Err(ServerError::TooManyRequests(error_msg.to_string())); } diff --git a/src/raft/gateway_state.rs b/src/raft/gateway_state.rs index 6c5ed89..bff0a5d 100644 --- a/src/raft/gateway_state.rs +++ b/src/raft/gateway_state.rs @@ -452,6 +452,13 @@ impl GatewayState { ); } + /// Records a single rate-limit violation for later batched persistence. + pub fn record_rate_limit_violation(&self, client_id: &str) { + self.internal + .event_recorder + .record_rate_limit_violation(client_id); + } + pub async fn submit_rate_limit_deltas( &self, deltas: RateLimitDeltaBatch, From 8df78e0b39e8b5af278990e5fa52c872050582c2 Mon Sep 17 00:00:00 2001 From: Raman Kudaktsin Date: Tue, 10 Mar 2026 13:16:26 +0100 Subject: [PATCH 8/9] test: add coverage for rate-limit violation flow Add integration coverage for distributed user limiter 429 behavior and unit coverage for per-client violation aggregation output. 
Made-with: Cursor --- tests/client_http_api/mod.rs | 1 + tests/client_http_api/rate_limits.rs | 80 ++++++++++++++++++++++++++++ tests/event_tracker/support.rs | 1 + 3 files changed, 82 insertions(+) create mode 100644 tests/client_http_api/rate_limits.rs diff --git a/tests/client_http_api/mod.rs b/tests/client_http_api/mod.rs index 563be70..8991ca0 100644 --- a/tests/client_http_api/mod.rs +++ b/tests/client_http_api/mod.rs @@ -6,4 +6,5 @@ mod get_result; mod get_status; mod get_tasks; mod misc; +mod rate_limits; mod support; diff --git a/tests/client_http_api/rate_limits.rs b/tests/client_http_api/rate_limits.rs new file mode 100644 index 0000000..0f5431b --- /dev/null +++ b/tests/client_http_api/rate_limits.rs @@ -0,0 +1,80 @@ +use std::sync::Arc; +use std::time::Duration; + +use http::StatusCode; +use salvo::prelude::*; +use salvo::test::TestClient; +use tokio_util::sync::CancellationToken; +use uuid::Uuid; + +use gateway::db::{EventRecorder, EventSinkHandle}; +use gateway::test_support::{ + RateLimitContext, build_shared_harness_core, enforce_rate_limit, ensure_test_crypto_provider, + load_test_single_node_config, +}; + +use crate::support::read_response; + +#[handler] +async fn ok_handler() -> &'static str { + "ok" +} + +#[tokio::test] +async fn distributed_user_rate_limit_returns_429() { + ensure_test_crypto_provider(); + let (mut config, _path) = load_test_single_node_config(); + config + .http + .add_task_authenticated_per_user_hourly_rate_limit = 1; + let config = Arc::new(config); + + let config_file = tempfile::Builder::new() + .suffix(".toml") + .tempfile() + .expect("temp config file"); + let config_toml = toml::to_string(config.as_ref()).expect("serialize test config"); + std::fs::write(config_file.path(), config_toml).expect("write temp config"); + let config_path = config_file.path().to_path_buf(); + + let shutdown = CancellationToken::new(); + let event_recorder = EventRecorder::new( + Arc::new(EventSinkHandle::Noop), + 
Arc::from(config.network.name.as_str()), + Duration::from_secs(30), + config.db.events_queue_capacity.max(1), + shutdown.clone(), + ); + let core = build_shared_harness_core(config.clone(), config_path, event_recorder, true).await; + let state = core.state; + + let router = Router::new().hoop(affix_state::inject(state)).push( + Router::with_path("/rl-probe") + .hoop(affix_state::inject(RateLimitContext { + user_id: Some(Uuid::new_v4()), + has_valid_api_key: true, + key_is_uuid: true, + ..RateLimitContext::default() + })) + .hoop(enforce_rate_limit) + .get(ok_handler), + ); + let service = Service::new(router); + + let first = TestClient::get("http://localhost/rl-probe") + .send(&service) + .await; + let (first_status, _headers, _body) = read_response(first).await; + assert_eq!(first_status, StatusCode::OK); + + let second = TestClient::get("http://localhost/rl-probe") + .send(&service) + .await; + let (second_status, _headers, body) = read_response(second).await; + assert_eq!( + second_status, + StatusCode::TOO_MANY_REQUESTS, + "second probe body: {}", + String::from_utf8_lossy(&body) + ); +} diff --git a/tests/event_tracker/support.rs b/tests/event_tracker/support.rs index a55aab7..2d7fbb7 100644 --- a/tests/event_tracker/support.rs +++ b/tests/event_tracker/support.rs @@ -65,6 +65,7 @@ pub(crate) fn current_timestamp() -> u64 { current_timestamp_secs() } +#[allow(clippy::too_many_arguments)] pub(crate) fn multipart_add_result( task_id: Uuid, worker_hotkey: &Hotkey, From 0195e8419e8af0f971714e940dea7bb9fb8cd956 Mon Sep 17 00:00:00 2001 From: Raman Kudaktsin Date: Tue, 10 Mar 2026 14:13:32 +0100 Subject: [PATCH 9/9] fix: resolve post-merge rate-limit check issues Ignore Cursor workspace files and fix rate-limit handler/test updates needed after syncing with main so local checks run against the new branch state. 
Made-with: Cursor --- .gitignore | 1 + src/http3/rate_limits.rs | 3 +++ tests/client_http_api/rate_limits.rs | 17 ++++------------- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index 4f9f2e3..4c66556 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target /logs .vscode +.cursor diff --git a/src/http3/rate_limits.rs b/src/http3/rate_limits.rs index 3f82fe9..fd70000 100644 --- a/src/http3/rate_limits.rs +++ b/src/http3/rate_limits.rs @@ -166,6 +166,8 @@ fn violation_client_key_for_subject(subject: Subject, id: u128) -> String { match subject { Subject::Company => format!("company:{}", Uuid::from_u128(id)), Subject::User => format!("user:{}", Uuid::from_u128(id)), + Subject::GenericGlobal => "generic:global".to_string(), + Subject::GenericIp => format!("generic:ip:{id}"), } } @@ -391,6 +393,7 @@ pub async fn unauthorized_only_rate_limit( Ok(()) } +#[handler] pub async fn read_rate_limit( depot: &mut Depot, req: &mut Request, diff --git a/tests/client_http_api/rate_limits.rs b/tests/client_http_api/rate_limits.rs index 0f5431b..c51c7a1 100644 --- a/tests/client_http_api/rate_limits.rs +++ b/tests/client_http_api/rate_limits.rs @@ -5,11 +5,10 @@ use http::StatusCode; use salvo::prelude::*; use salvo::test::TestClient; use tokio_util::sync::CancellationToken; -use uuid::Uuid; use gateway::db::{EventRecorder, EventSinkHandle}; use gateway::test_support::{ - RateLimitContext, build_shared_harness_core, enforce_rate_limit, ensure_test_crypto_provider, + basic_rate_limit, build_shared_harness_core, ensure_test_crypto_provider, load_test_single_node_config, }; @@ -21,12 +20,10 @@ async fn ok_handler() -> &'static str { } #[tokio::test] -async fn distributed_user_rate_limit_returns_429() { +async fn basic_rate_limit_returns_429() { ensure_test_crypto_provider(); let (mut config, _path) = load_test_single_node_config(); - config - .http - .add_task_authenticated_per_user_hourly_rate_limit = 1; + config.http.basic_rate_limit = 
1; let config = Arc::new(config); let config_file = tempfile::Builder::new() @@ -50,13 +47,7 @@ async fn distributed_user_rate_limit_returns_429() { let router = Router::new().hoop(affix_state::inject(state)).push( Router::with_path("/rl-probe") - .hoop(affix_state::inject(RateLimitContext { - user_id: Some(Uuid::new_v4()), - has_valid_api_key: true, - key_is_uuid: true, - ..RateLimitContext::default() - })) - .hoop(enforce_rate_limit) + .hoop(basic_rate_limit) .get(ok_handler), ); let service = Service::new(router);