diff --git a/src/server/api.rs b/src/server/api.rs
index bd5cc8a..017e5f5 100644
--- a/src/server/api.rs
+++ b/src/server/api.rs
@@ -55,6 +55,9 @@ pub struct EmbedRequest {
     pub text: String,
 }
+const MAX_BATCH_SIZE: usize = 100;
+const MAX_TEXT_LENGTH: usize = 8192; // Common token limit for many models
+
 #[derive(Debug, Deserialize)]
 pub struct EmbedBatchRequest {
     pub texts: Vec<String>,
 }
@@ -303,6 +306,16 @@ async fn embed(
     State(state): State<AppState>,
     Json(req): Json<EmbedRequest>,
 ) -> impl IntoResponse {
+    if req.text.len() > MAX_TEXT_LENGTH {
+        return (
+            StatusCode::BAD_REQUEST,
+            Json(serde_json::json!({
+                "error": format!("Text exceeds length limit of {}", MAX_TEXT_LENGTH)
+            })),
+        )
+            .into_response();
+    }
+
     let engine = match state.embedding_engine.lock() {
         Ok(e) => e,
         Err(e) => {
@@ -339,6 +352,28 @@ async fn embed_batch(
     State(state): State<AppState>,
     Json(req): Json<EmbedBatchRequest>,
 ) -> impl IntoResponse {
+    if req.texts.len() > MAX_BATCH_SIZE {
+        return (
+            StatusCode::BAD_REQUEST,
+            Json(serde_json::json!({
+                "error": format!("Batch size exceeds limit of {}", MAX_BATCH_SIZE)
+            })),
+        )
+            .into_response();
+    }
+
+    for (i, text) in req.texts.iter().enumerate() {
+        if text.len() > MAX_TEXT_LENGTH {
+            return (
+                StatusCode::BAD_REQUEST,
+                Json(serde_json::json!({
+                    "error": format!("Text at index {} exceeds length limit of {}", i, MAX_TEXT_LENGTH)
+                })),
+            )
+                .into_response();
+        }
+    }
+
     let engine = match state.embedding_engine.lock() {
         Ok(e) => e,
         Err(e) => {