Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions crates/rmlx-server/src/anthropic/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ pub(super) async fn generate_blocking(
u64, // queue_depth at admission
u64, // queue_wait_ms
)>,
// Rolling ring-buffer for TTFT samples shared with AppState.
// Written here so non-streaming requests populate the same ring as streaming.
ttft_store: &crate::openai::state::TtftStore,
) -> Response {
let input_token_count = req.prompt_tokens.len() as u32;
// Capture stop sequences before `req` is moved into the generator.
Expand Down Expand Up @@ -103,6 +106,20 @@ pub(super) async fn generate_blocking(
if output_tokens == 0 {
let ttft_ms = request_start.elapsed().as_millis() as u64;
ttft_ms_blocking = Some(ttft_ms);
tracing::info!(model_id, ttft_ms, "generate_blocking (anthropic): TTFT");
// Append to the rolling ring-buffer so non-streaming requests
// populate the same ring as streaming requests.
{
use crate::openai::{TtftSample, TTFT_RING_CAPACITY};
let mut ring = ttft_store.lock();
if ring.len() >= TTFT_RING_CAPACITY {
ring.pop_front();
}
ring.push_back(TtftSample {
model_id: model_id.to_owned(),
ttft_ms,
});
}
tracing::debug!(
model_id,
phase = ?crate::engine::Phase::Decode,
Expand Down
1 change: 1 addition & 0 deletions crates/rmlx-server/src/anthropic/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ pub(crate) async fn messages(
.admission_controller
.as_ref()
.map(|ctrl| (ctrl.clone(), admitted_depth, admitted_wait_ms)),
&state.ttft_store,
)
.instrument(req_span)
.await
Expand Down
1 change: 1 addition & 0 deletions crates/rmlx-server/src/openai/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,7 @@ pub(crate) async fn chat_completions(
.admission_controller
.as_ref()
.map(|ctrl| (ctrl.clone(), admitted_depth, admitted_wait_ms)),
&state.ttft_store,
)
.instrument(req_span)
.await
Expand Down
22 changes: 21 additions & 1 deletion crates/rmlx-server/src/openai/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
//! - `extract_top_level_json_value` — strip markdown fences from json_object output.
//! - `try_extract_at` — attempt to extract a single JSON value at a given offset.

#[cfg(test)]
#[path = "generate_tests.rs"]
mod generate_tests;

use std::sync::Arc;
use std::time::Instant;

Expand All @@ -26,7 +30,7 @@ use super::response::{
select_finish_reason, to_response_tool_call, ChatCompletionChunk, ChatCompletionsResponse,
ChatLogprobContent, ChatLogprobs, Choice, DeltaContent, ResponseMessage, StreamChoice, Usage,
};
use super::state::{ApiErrorCounters, AppState, TtftSample, TTFT_RING_CAPACITY};
use super::state::{ApiErrorCounters, AppState, TtftSample, TtftStore, TTFT_RING_CAPACITY};
use super::streaming::{handle_streaming_token, StreamState};

// ── JSON extraction helpers ───────────────────────────────────────────────────
Expand Down Expand Up @@ -275,6 +279,9 @@ pub(super) async fn generate_blocking(
u64, // queue_depth at admission
u64, // queue_wait_ms
)>,
// Rolling ring-buffer for TTFT samples shared with AppState.
// Written here so non-streaming requests populate the same ring as streaming.
ttft_store: &TtftStore,
) -> Response {
let prompt_token_count = req.prompt_tokens.len() as u32;
// Capture the stop sequences before `req` is moved into the
Expand Down Expand Up @@ -330,6 +337,19 @@ pub(super) async fn generate_blocking(
if completion_tokens == 0 {
let ttft_ms = request_start.elapsed().as_millis() as u64;
ttft_ms_blocking = Some(ttft_ms);
tracing::info!(model_id, ttft_ms, "generate_blocking: TTFT");
// Append to the rolling ring-buffer so non-streaming requests
// populate the same ring as streaming requests.
{
let mut ring = ttft_store.lock();
if ring.len() >= TTFT_RING_CAPACITY {
ring.pop_front();
}
ring.push_back(TtftSample {
model_id: model_id.to_owned(),
ttft_ms,
});
}
// phase transition Prefill -> Decode at the natural
// TTFT boundary (first OK token). Same timestamp as the
// existing TTFT capture — no second `Instant::now()`.
Expand Down
263 changes: 263 additions & 0 deletions crates/rmlx-server/src/openai/generate_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
//! Unit tests for `generate_blocking` and `generate_streaming` behaviour
//! that does not require a real model or HTTP server.
//!
//! Convention: `#[cfg(test)] #[path = "generate_tests.rs"] mod generate_tests;`
//! in `generate.rs` wires this file. No inline test blocks elsewhere.

#![allow(
clippy::unwrap_used,
reason = "test-only: panics surface the root cause clearly"
)]
#![allow(
clippy::expect_used,
reason = "test-only: panics surface the root cause clearly"
)]

use std::pin::Pin;
use std::sync::Arc;
use std::time::Instant;

use futures::stream::{self, Stream};
use rmlx_core::Error;

use crate::engine::{GenerationRequest, GenerationToken, Generator};
use crate::openai::state::{ApiErrorCounters, TtftStore};
use crate::openai::TTFT_RING_CAPACITY;

// ── Minimal stub generator ────────────────────────────────────────────────────

/// Emits a fixed sequence of tokens then signals done.
struct FixedTokenGenerator {
pieces: Vec<&'static str>,
}

impl Generator for FixedTokenGenerator {
fn generate(
&self,
_req: GenerationRequest,
) -> Pin<Box<dyn Stream<Item = rmlx_core::Result<GenerationToken>> + Send>> {
let pieces = self.pieces.clone();
let n = pieces.len();
let tokens: Vec<rmlx_core::Result<GenerationToken>> = pieces
.into_iter()
.enumerate()
.map(|(i, piece)| {
let done = i + 1 == n;
Ok(GenerationToken {
token_id: i as u32,
piece: piece.to_owned(),
done,
finish_reason: if done { Some("stop".to_owned()) } else { None },
is_thinking: false,
logprobs: None,
})
})
.collect();
Box::pin(stream::iter(tokens))
}
}

/// Emits a single error item — exercises the early-error path.
struct ErrorGenerator;

impl Generator for ErrorGenerator {
fn generate(
&self,
_req: GenerationRequest,
) -> Pin<Box<dyn Stream<Item = rmlx_core::Result<GenerationToken>> + Send>> {
Box::pin(stream::once(async {
Err(Error::Other("stub engine error".to_owned()))
}))
}
}

// ── Helper: minimal GenerationRequest ────────────────────────────────────────

fn minimal_request(model_id: &str) -> GenerationRequest {
GenerationRequest {
model_id: model_id.to_owned(),
prompt_tokens: vec![1, 2, 3],
max_tokens: 16,
sampling: crate::engine::types::SamplingParams::default(),
stop: vec![],
stream: false,
system: None,
session_id: None,
effective_prompt_cache_slots: None,
metrics_drainer: None,
itl_store: None,
event_recorder: None,
tools: None,
tool_choice: None,
response_format: None,
constraint: None,
is_thinking_handle: None,
thinking_budget: None,
thinking_end_token_id: None,
enable_thinking: None,
emit_tool_markers: false,
thinking_start_token: None,
thinking_end_token: None,
gpu_admission: None,
kv_quant_override: None,
max_ctx_override: None,
images: vec![],
audio_b64: vec![],
}
}

// ── Tests ─────────────────────────────────────────────────────────────────────

/// Non-streaming `generate_blocking` must populate `ttft_store` with exactly
/// one entry per completed request, using the correct `model_id`.
#[tokio::test]
async fn blocking_ttft_ring_populated() {
let ttft_store = TtftStore::default();
let tokens_in = Arc::new(std::sync::atomic::AtomicU64::new(0));
let tokens_out = Arc::new(std::sync::atomic::AtomicU64::new(0));
let error_counts = ApiErrorCounters::new();
let requests_completed = Arc::new(std::sync::atomic::AtomicU64::new(0));
let requests_failed = Arc::new(std::sync::atomic::AtomicU64::new(0));

let generator: Arc<dyn Generator> = Arc::new(FixedTokenGenerator {
pieces: vec!["hello", " world"],
});

let _ = super::generate_blocking(
generator,
minimal_request("test-model"),
None,
"test-model",
None,
false,
false,
Instant::now(),
None,
0,
&tokens_in,
&tokens_out,
"req-001",
&error_counts,
&requests_completed,
&requests_failed,
None,
false,
None,
&ttft_store,
)
.await;

let ring = ttft_store.lock();
assert_eq!(ring.len(), 1, "one completed request → one TTFT sample");
assert_eq!(
ring[0].model_id, "test-model",
"TTFT sample must carry the correct model_id"
);
assert!(
ring[0].ttft_ms < 5_000,
"TTFT must be a plausible wall-clock value"
);
}

/// When the engine returns an error on the first token, `ttft_store` must
/// remain empty (no TTFT is recorded for failed requests).
#[tokio::test]
async fn blocking_ttft_ring_empty_on_engine_error() {
let ttft_store = TtftStore::default();
let tokens_in = Arc::new(std::sync::atomic::AtomicU64::new(0));
let tokens_out = Arc::new(std::sync::atomic::AtomicU64::new(0));
let error_counts = ApiErrorCounters::new();
let requests_completed = Arc::new(std::sync::atomic::AtomicU64::new(0));
let requests_failed = Arc::new(std::sync::atomic::AtomicU64::new(0));

let generator: Arc<dyn Generator> = Arc::new(ErrorGenerator);

let _ = super::generate_blocking(
generator,
minimal_request("test-model"),
None,
"test-model",
None,
false,
false,
Instant::now(),
None,
0,
&tokens_in,
&tokens_out,
"req-002",
&error_counts,
&requests_completed,
&requests_failed,
None,
false,
None,
&ttft_store,
)
.await;

let ring = ttft_store.lock();
assert_eq!(
ring.len(),
0,
"engine error on first token → no TTFT sample (first token never arrived)"
);
}

/// TTFT ring evicts the oldest entry once `TTFT_RING_CAPACITY` is reached.
#[tokio::test]
async fn blocking_ttft_ring_respects_capacity() {
let ttft_store = TtftStore::default();
let error_counts = ApiErrorCounters::new();
let requests_completed = Arc::new(std::sync::atomic::AtomicU64::new(0));
let requests_failed = Arc::new(std::sync::atomic::AtomicU64::new(0));

// Run CAPACITY + 2 requests to verify oldest entries are evicted.
for i in 0..(TTFT_RING_CAPACITY + 2) {
let generator: Arc<dyn Generator> = Arc::new(FixedTokenGenerator {
pieces: vec!["tok"],
});
let tokens_in = Arc::new(std::sync::atomic::AtomicU64::new(0));
let tokens_out = Arc::new(std::sync::atomic::AtomicU64::new(0));
let model = format!("model-{i}");
let _ = super::generate_blocking(
generator,
minimal_request(&model),
None,
&model,
None,
false,
false,
Instant::now(),
None,
0,
&tokens_in,
&tokens_out,
"req-cap",
&error_counts,
&requests_completed,
&requests_failed,
None,
false,
None,
&ttft_store,
)
.await;
}

let ring = ttft_store.lock();
assert_eq!(
ring.len(),
TTFT_RING_CAPACITY,
"ring must not exceed TTFT_RING_CAPACITY after overflow"
);
// The two oldest entries (model-0, model-1) should have been evicted.
assert!(
ring.iter().all(|s| s.model_id != "model-0"),
"model-0 must have been evicted"
);
assert!(
ring.iter().all(|s| s.model_id != "model-1"),
"model-1 must have been evicted"
);
}
Loading
Loading