Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mistralrs-core/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ impl Engine {
&mut scheduled.completion,
self.pipeline,
'lp,
self.prefix_cacher
self.prefix_cacher,
&self.scheduler
);

self.logger.add_tokens_processed(scheduled.completion.len());
Expand Down Expand Up @@ -474,7 +475,8 @@ impl Engine {
&mut scheduled.prompt,
self.pipeline,
'lp,
self.prefix_cacher
self.prefix_cacher,
&self.scheduler
);

let total_processed_tokens: usize = scheduled
Expand Down Expand Up @@ -697,7 +699,8 @@ impl Engine {
&mut guards_mut,
self.pipeline,
'lp,
self.prefix_cacher
self.prefix_cacher,
&self.scheduler
);

let total_processed_tokens: usize = guards_mut
Expand Down
1 change: 1 addition & 0 deletions mistralrs-core/src/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,7 @@ impl Sequence {
*self.state.read().unwrap(),
SequenceState::FinishedAborted
| SequenceState::FinishedIgnored
| SequenceState::Error
| SequenceState::Done(_)
)
}
Expand Down
130 changes: 123 additions & 7 deletions mistralrs-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ macro_rules! handle_seq_error_stateaware_ok {
#[doc(hidden)]
#[macro_export]
macro_rules! handle_pipeline_forward_error {
($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr, $scheduler:expr) => {
match $fallible {
Ok(v) => v,
Err(e) => {
Expand Down Expand Up @@ -174,13 +174,12 @@ macro_rules! handle_pipeline_forward_error {
usage: group.get_usage(),
};

seq.responder()
let _ = seq.responder()
.send(Response::ModelError(
e.to_string(),
partial_completion_response
))
.await
.unwrap();
.await;
} else {
let partial_completion_response = CompletionResponse {
id: seq.id().to_string(),
Expand All @@ -192,13 +191,12 @@ macro_rules! handle_pipeline_forward_error {
usage: group.get_usage(),
};

seq.responder()
let _ = seq.responder()
.send(Response::CompletionModelError(
e.to_string(),
partial_completion_response
))
.await
.unwrap();
.await;
}
}
for seq in $seq_slice.iter_mut() {
Expand All @@ -213,6 +211,124 @@ macro_rules! handle_pipeline_forward_error {
p.set_none_cache($seq_slice, true, true, false);
get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();

// Free KV blocks for errored sequences. The call site MUST drop its scheduler
// guard before invoking this macro: passing `&self.scheduler` while
// `let mut scheduler = get_mut_arcmutex!(self.scheduler)` is still in scope
// causes a spin-lock deadlock, because get_mut_arcmutex! loops on try_lock(),
// which can never succeed while the same thread still holds the guard.
get_mut_arcmutex!($scheduler).free_finished_sequence_groups();

continue $label;
}
}
};
// 6-argument form for DefaultScheduler call sites.
// The DefaultScheduler match arm holds the scheduler MutexGuard for its entire
// duration, so passing &self.scheduler and calling get_mut_arcmutex! inside the
// macro would spin-deadlock. DefaultScheduler normally calls
// free_finished_sequence_groups at the bottom of the engine loop anyway, so
// omitting that call here is correct.
($stage: tt, $fallible:expr, $seq_slice:expr, $pipeline:expr, $label:tt, $prefix_cacher:expr) => {
match $fallible {
Ok(v) => v,
Err(e) => {
#[cfg(feature = "metal")]
{
let err_str = e.to_string();
if err_str.contains("Insufficient Permission")
|| err_str.contains("BackgroundExecutionNotPermitted")
{
tracing::warn!(
"Metal GPU background error detected (iOS app likely in background). \
Pausing 1s before retry..."
);
{
let p = get_mut_arcmutex!($pipeline);
p.set_none_cache($seq_slice, true, true, false);
}
get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
continue $label;
}
}
let (tokenizer, pipeline_name) = {
let pipeline = get_mut_arcmutex!($pipeline);
let pipeline_name = pipeline.name();
let tokenizer = pipeline.tokenizer();
(tokenizer, pipeline_name)
};
use $crate::response::Response;
use $crate::sequence::SequenceState;
use $crate::response::SYSTEM_FINGERPRINT;
use tracing::error;
error!("{} - Model failed with error: {:?}", $stage, &e);
for seq in $seq_slice.iter_mut() {
let start = seq.prompt_tokens().min(seq.get_toks().len());
let res = match &tokenizer {
Some(tok) => match tok.decode(&seq.get_toks()[start..], false) {
Ok(t) => t,
Err(_) => "".to_string(),
},
None => "".to_string(),
};
if seq.get_mut_group().is_chat {
let choice = Choice {
finish_reason: "error".to_string(),
index: seq.get_response_index(),
message: ResponseMessage {
content: Some(res),
role: "assistant".to_string(),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
};
seq.add_choice_to_group(choice);
} else {
let choice = CompletionChoice {
finish_reason: "error".to_string(),
index: seq.get_response_index(),
text: res,
logprobs: None,
};
seq.add_completion_choice_to_group(choice);
}
}
for seq in $seq_slice.iter_mut() {
let group = seq.get_mut_group();
if group.is_chat {
let partial_completion_response = ChatCompletionResponse {
id: seq.id().to_string(),
choices: group.get_choices().to_vec(),
created: seq.creation_time(),
model: pipeline_name.clone(),
system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
object: "chat.completion".to_string(),
usage: group.get_usage(),
};
let _ = seq.responder()
.send(Response::ModelError(e.to_string(), partial_completion_response))
.await;
} else {
let partial_completion_response = CompletionResponse {
id: seq.id().to_string(),
choices: group.get_completion_choices().to_vec(),
created: seq.creation_time(),
model: pipeline_name.clone(),
system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
object: "text_completion".to_string(),
usage: group.get_usage(),
};
let _ = seq.responder()
.send(Response::CompletionModelError(e.to_string(), partial_completion_response))
.await;
}
}
for seq in $seq_slice.iter_mut() {
seq.set_state(SequenceState::Error);
}
let p = get_mut_arcmutex!($pipeline);
p.set_none_cache($seq_slice, true, true, false);
get_mut_arcmutex!($prefix_cacher).evict_all_caches().unwrap();
continue $label;
}
}
Expand Down
Loading