Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,43 @@ impl MistralRs {
.map_err(|_| MistralRsError::EnginePoisoned)?;
unloaded.insert(resolved_model_id.to_string(), unloaded_state);

let device = engine_instance.config.device.clone();

// Wait for the engine's asynchronous worker loop to finish, so that it has dropped its inner variables.
let _ = engine_instance.engine_handler.join();

// We MUST bind the CUDA context to this thread BEFORE dropping the state!
// The tensors are dropped right here, on the HTTP thread; cudarc requires the CUDA context
// to be bound to the current thread, otherwise cuMemFreeAsync silently fails and the memory is not released.
#[cfg(feature = "cuda")]
let _ctx_guard = {
if let candle_core::Device::Cuda(dev) = &device {
dev.cuda_stream().context().bind_to_thread().ok()
} else {
None
}
};

drop(engine_instance.reboot_state);
let _ = device.synchronize();

// Manually trim the device's default CUDA memory pool down to 0 retained bytes, so the memory freed above is released now.
#[cfg(feature = "cuda")]
if let candle_core::Device::Cuda(dev) = &device {
unsafe {
use candle_core::cuda::cudarc::driver::sys;
if let Ok(_ctx) = dev.cuda_stream().context().bind_to_thread() {
let mut dev_id = 0;
if sys::cuCtxGetDevice(&mut dev_id) == sys::CUresult::CUDA_SUCCESS {
let mut pool: sys::CUmemoryPool = std::ptr::null_mut();
if sys::cuDeviceGetDefaultMemPool(&mut pool, dev_id) == sys::CUresult::CUDA_SUCCESS {
sys::cuMemPoolTrimTo(pool, 0);
}
}
}
}
}

// Update default if needed
let mut default_lock = self
.default_engine_id
Expand Down
12 changes: 6 additions & 6 deletions mistralrs-server-core/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ pub async fn re_isq(
/// Request for model operations (unload, reload, status)
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ModelOperationRequest {
#[schema(example = "my-model")]
pub model_id: String,
#[schema(example = "default")]
pub model_id: Option<String>,
}

/// Model status enum
Expand All @@ -157,7 +157,7 @@ pub enum ModelStatus {
/// Response for model status operations
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ModelStatusResponse {
#[schema(example = "my-model")]
#[schema(example = "default")]
pub model_id: String,
pub status: ModelStatus,
/// Error message when status indicates an error condition
Expand All @@ -179,7 +179,7 @@ pub async fn unload_model(
State(state): ExtractedMistralRsState,
Json(request): Json<ModelOperationRequest>,
) -> Json<ModelStatusResponse> {
let model_id = request.model_id;
let model_id = request.model_id.unwrap_or_else(|| "default".to_string());
match state.unload_model(&model_id) {
Ok(()) => Json(ModelStatusResponse {
model_id,
Expand Down Expand Up @@ -216,7 +216,7 @@ pub async fn reload_model(
State(state): ExtractedMistralRsState,
Json(request): Json<ModelOperationRequest>,
) -> Json<ModelStatusResponse> {
let model_id = request.model_id;
let model_id = request.model_id.unwrap_or_else(|| "default".to_string());
match state.reload_model(&model_id).await {
Ok(()) => Json(ModelStatusResponse {
model_id,
Expand Down Expand Up @@ -256,7 +256,7 @@ pub async fn get_model_status(
State(state): ExtractedMistralRsState,
Json(request): Json<ModelOperationRequest>,
) -> Json<ModelStatusResponse> {
let model_id = request.model_id;
let model_id = request.model_id.unwrap_or_else(|| "default".to_string());
match state.get_model_status(&model_id) {
Ok(Some(core_status)) => {
let status = match core_status {
Expand Down
9 changes: 9 additions & 0 deletions mistralrs-server-core/src/mistralrs_server_router_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ fn init_router(
.route("/v1/models", get(models))
.route("/v1/models/unload", post(unload_model))
.route("/v1/models/reload", post(reload_model))
// Aliases for vLLM & SGLang feature parity
.route("/v1/sleep", post(unload_model))
.route("/sleep", post(unload_model))
.route("/v1/wake_up", post(reload_model))
.route("/wake_up", post(reload_model))
.route("/v1/release_memory_occupation", post(unload_model))
.route("/release_memory_occupation", post(unload_model))
.route("/v1/resume_memory_occupation", post(reload_model))
.route("/resume_memory_occupation", post(reload_model))
.route("/v1/models/status", post(get_model_status))
.route("/v1/models/tune", post(tune_model))
.route("/v1/system/info", get(system_info))
Expand Down