From 4ef4ebaa8055d0f2b753b1d787933507bc083fc8 Mon Sep 17 00:00:00 2001 From: "Damian K. Kowalczyk" Date: Tue, 10 Feb 2026 22:38:23 -0800 Subject: [PATCH 1/8] managed-identity-support --- CMakeLists.txt | 1 + src/CMakeLists.txt | 4 ++ src/filesystem/implementations/as.h | 66 +++++++++++++++++++++++++++-- 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ab55276b8..57d96654a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -286,6 +286,7 @@ if(NOT TRITON_CORE_HEADERS_ONLY) -Dazure-storage-blobs-cpp_DIR:PATH=${TRITON_THIRD_PARTY_INSTALL_PREFIX}/azure-sdk/share/azure-storage-blobs-cpp -Dazure-storage-common-cpp_DIR:PATH=${TRITON_THIRD_PARTY_INSTALL_PREFIX}/azure-sdk/share/azure-storage-common-cpp -Dazure-core-cpp_DIR:PATH=${TRITON_THIRD_PARTY_INSTALL_PREFIX}/azure-sdk/share/azure-core-cpp + -Dazure-identity-cpp_DIR:PATH=${TRITON_THIRD_PARTY_INSTALL_PREFIX}/azure-sdk/share/azure-identity-cpp ) endif() # TRITON_ENABLE_AZURE_STORAGE if(${TRITON_ENABLE_METRICS}) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6a0967036..dd3a0a426 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -103,7 +103,9 @@ endif() # if(${TRITON_ENABLE_AZURE_STORAGE}) find_package(azure-storage-blobs-cpp CONFIG REQUIRED) + find_package(azure-identity-cpp CONFIG REQUIRED) message(STATUS "Using Azure storage blobs ${azure-storage-blobs-cpp_VERSION}") + message(STATUS "Using Azure identity ${azure-identity-cpp_VERSION}") endif() configure_file(libtritonserver.ldscript libtritonserver.ldscript COPYONLY) @@ -327,6 +329,7 @@ if(${TRITON_ENABLE_AZURE_STORAGE}) target_include_directories( triton-core PRIVATE $ + PRIVATE $ ) endif() # TRITON_ENABLE_AZURE_STORAGE @@ -500,6 +503,7 @@ if(${TRITON_ENABLE_AZURE_STORAGE}) triton-core PRIVATE Azure::azure-storage-blobs + Azure::azure-identity ) endif() # TRITON_ENABLE_AZURE_STORAGE diff --git a/src/filesystem/implementations/as.h b/src/filesystem/implementations/as.h index fc449475a..3c1f57bdf 100644 --- a/src/filesystem/implementations/as.h +++ b/src/filesystem/implementations/as.h @@ -25,6 +25,8 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #pragma once +#include +#include #include #include @@ -39,9 +41,18 @@ namespace as = Azure::Storage; namespace asb = Azure::Storage::Blobs; const std::string AS_URL_PATTERN = "as://([^/]+)/([^/?]+)(?:/([^?]*))?(\\?.*)?"; +/// Supported authentication modes for Azure Storage access. +/// - "key": Shared Key (account key) — default, backwards-compatible. +/// - "managed_identity": Azure Managed Identity (system- or user-assigned). +/// - "default": Azure DefaultAzureCredential chain +/// (environment → managed identity → CLI → etc.). struct ASCredential { std::string account_str_; std::string account_key_; + /// Authentication type: "key" (default), "managed_identity", or "default". + std::string auth_type_; + /// Optional client ID for user-assigned Managed Identity. + std::string client_id_; ASCredential(); // from env var ASCredential(triton::common::TritonJson::Value& cred_json); @@ -54,17 +65,33 @@ ASCredential::ASCredential() }; const char* account_str = std::getenv("AZURE_STORAGE_ACCOUNT"); const char* account_key = std::getenv("AZURE_STORAGE_KEY"); + const char* auth_type = std::getenv("AZURE_STORAGE_AUTH_TYPE"); + const char* client_id = std::getenv("AZURE_STORAGE_CLIENT_ID"); account_str_ = to_str(account_str); account_key_ = to_str(account_key); + auth_type_ = to_str(auth_type); + client_id_ = to_str(client_id); + + // When no explicit auth type is set, infer from available credentials: + // if an account key is present use "key", otherwise remain empty (which + // the filesystem constructor treats as "key" for backwards compatibility). + if (auth_type_.empty() && !account_key_.empty()) { + auth_type_ = "key"; + } } ASCredential::ASCredential(triton::common::TritonJson::Value& cred_json) { - triton::common::TritonJson::Value account_str_json, account_key_json; + triton::common::TritonJson::Value account_str_json, account_key_json, + auth_type_json, client_id_json; if (cred_json.Find("account_str", &account_str_json)) account_str_json.AsString(&account_str_); if (cred_json.Find("account_key", &account_key_json)) account_key_json.AsString(&account_key_); + if (cred_json.Find("auth_type", &auth_type_json)) + auth_type_json.AsString(&auth_type_); + if (cred_json.Find("client_id", &client_id_json)) + client_id_json.AsString(&client_id_); } class ASFileSystem : public FileSystem { @@ -152,12 +179,45 @@ ASFileSystem::ASFileSystem(const std::string& path, const ASCredential& as_cred) std::string service_url( "https://" + account_name + ".blob.core.windows.net"); - if (!as_cred.account_key_.empty()) { - // Shared Key + if (as_cred.auth_type_ == "managed_identity") { + // Azure Managed Identity authentication (system- or user-assigned). + // Token caching and refresh are handled by the Azure Identity SDK. + LOG_VERBOSE(1) << "Using Azure Managed Identity authentication for " + << account_name; + std::shared_ptr token_cred; + if (!as_cred.client_id_.empty()) { + // User-assigned Managed Identity: specify the client ID. + Azure::Identity::ManagedIdentityCredentialOptions mi_opts; + mi_opts.ClientId = as_cred.client_id_; + token_cred = + std::make_shared( + mi_opts); + LOG_VERBOSE(1) << "Using user-assigned Managed Identity with client ID " + << as_cred.client_id_; + } else { + // System-assigned Managed Identity. + token_cred = + std::make_shared(); + LOG_VERBOSE(1) << "Using system-assigned Managed Identity"; + } + client_ = std::make_shared( + service_url, token_cred); + } else if (as_cred.auth_type_ == "default") { + // DefaultAzureCredential chains multiple credential sources: + // environment variables → managed identity → Azure CLI → etc. + LOG_VERBOSE(1) << "Using Azure DefaultAzureCredential for " + << account_name; + auto token_cred = + std::make_shared(); + client_ = std::make_shared( + service_url, token_cred); + } else if (!as_cred.account_key_.empty()) { + // Shared Key authentication (backwards-compatible default). auto cred = std::make_shared( account_name, as_cred.account_key_); client_ = std::make_shared(service_url, cred); } else { + // Anonymous access (no credential provided). client_ = std::make_shared(service_url); } } From de6bf5266b366eb1848e99b58313b5c92b1ab379 Mon Sep 17 00:00:00 2001 From: J Wyman Date: Fri, 6 Mar 2026 13:46:09 -0500 Subject: [PATCH 2/8] Apply suggestions from code review Co-authored-by: Yingge He <157551214+yinggeh@users.noreply.github.com> --- CMakeLists.txt | 2 +- src/CMakeLists.txt | 5 +++-- src/filesystem/implementations/as.h | 16 ++++++++-------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 57d96654a..a9e9be9a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dd3a0a426..45b4759bd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -103,8 +103,9 @@ endif() # if(${TRITON_ENABLE_AZURE_STORAGE}) find_package(azure-storage-blobs-cpp CONFIG REQUIRED) - find_package(azure-identity-cpp CONFIG REQUIRED) message(STATUS "Using Azure storage blobs ${azure-storage-blobs-cpp_VERSION}") + + find_package(azure-identity-cpp CONFIG REQUIRED) message(STATUS "Using Azure identity ${azure-identity-cpp_VERSION}") endif() diff --git a/src/filesystem/implementations/as.h b/src/filesystem/implementations/as.h index 3c1f57bdf..8ae429fa8 100644 --- a/src/filesystem/implementations/as.h +++ b/src/filesystem/implementations/as.h @@ -1,4 +1,4 @@ -// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -48,9 +48,9 @@ const std::string AS_URL_PATTERN = "as://([^/]+)/([^/?]+)(?:/([^?]*))?(\\?.*)?"; /// (environment → managed identity → CLI → etc.). struct ASCredential { std::string account_str_; - std::string account_key_; /// Authentication type: "key" (default), "managed_identity", or "default". std::string auth_type_; + std::string account_key_; /// Optional client ID for user-assigned Managed Identity. std::string client_id_; @@ -64,12 +64,12 @@ ASCredential::ASCredential() return (s != nullptr ? std::string(s) : ""); }; const char* account_str = std::getenv("AZURE_STORAGE_ACCOUNT"); - const char* account_key = std::getenv("AZURE_STORAGE_KEY"); const char* auth_type = std::getenv("AZURE_STORAGE_AUTH_TYPE"); + const char* account_key = std::getenv("AZURE_STORAGE_KEY"); const char* client_id = std::getenv("AZURE_STORAGE_CLIENT_ID"); account_str_ = to_str(account_str); - account_key_ = to_str(account_key); auth_type_ = to_str(auth_type); + account_key_ = to_str(account_key); client_id_ = to_str(client_id); // When no explicit auth type is set, infer from available credentials: @@ -86,10 +86,10 @@ ASCredential::ASCredential(triton::common::TritonJson::Value& cred_json) auth_type_json, client_id_json; if (cred_json.Find("account_str", &account_str_json)) account_str_json.AsString(&account_str_); - if (cred_json.Find("account_key", &account_key_json)) - account_key_json.AsString(&account_key_); if (cred_json.Find("auth_type", &auth_type_json)) auth_type_json.AsString(&auth_type_); + if (cred_json.Find("account_key", &account_key_json)) + account_key_json.AsString(&account_key_); if (cred_json.Find("client_id", &client_id_json)) client_id_json.AsString(&client_id_); } @@ -200,7 +200,7 @@ ASFileSystem::ASFileSystem(const std::string& path, const ASCredential& as_cred) std::make_shared(); LOG_VERBOSE(1) << "Using system-assigned Managed Identity"; } - client_ = std::make_shared( + client_ = std::make_shared(service_url, token_cred); service_url, token_cred); } else if (as_cred.auth_type_ == "default") { // DefaultAzureCredential chains multiple credential sources: @@ -209,7 +209,7 @@ ASFileSystem::ASFileSystem(const std::string& path, const ASCredential& as_cred) << account_name; auto token_cred = std::make_shared(); - client_ = std::make_shared( + client_ = std::make_shared(service_url, token_cred); service_url, token_cred); } else if (!as_cred.account_key_.empty()) { // Shared Key authentication (backwards-compatible default). From 38b404e3ccadbbc51bfc725d06e861775d4be035 Mon Sep 17 00:00:00 2001 From: J Wyman Date: Fri, 6 Mar 2026 13:51:43 -0500 Subject: [PATCH 3/8] fixup pre-commit complaints --- src/filesystem/implementations/as.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/filesystem/implementations/as.h b/src/filesystem/implementations/as.h index 8ae429fa8..120f8c33a 100644 --- a/src/filesystem/implementations/as.h +++ b/src/filesystem/implementations/as.h @@ -200,8 +200,8 @@ ASFileSystem::ASFileSystem(const std::string& path, const ASCredential& as_cred) std::make_shared(); LOG_VERBOSE(1) << "Using system-assigned Managed Identity"; } - client_ = std::make_shared(service_url, token_cred); - service_url, token_cred); + client_ = + std::make_shared(service_url, token_cred); } else if (as_cred.auth_type_ == "default") { // DefaultAzureCredential chains multiple credential sources: // environment variables → managed identity → Azure CLI → etc. @@ -209,8 +209,8 @@ ASFileSystem::ASFileSystem(const std::string& path, const ASCredential& as_cred) << account_name; auto token_cred = std::make_shared(); - client_ = std::make_shared(service_url, token_cred); - service_url, token_cred); + client_ = + std::make_shared(service_url, token_cred); } else if (!as_cred.account_key_.empty()) { // Shared Key authentication (backwards-compatible default). auto cred = std::make_shared( From 6881a86d197989f65fef7ed85eb2e3fb7f9efe6d Mon Sep 17 00:00:00 2001 From: J Wyman Date: Fri, 6 Mar 2026 13:57:07 -0500 Subject: [PATCH 4/8] fixup cmake pre-commit complaints --- src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 45b4759bd..39f2cad47 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -104,7 +104,7 @@ endif() if(${TRITON_ENABLE_AZURE_STORAGE}) find_package(azure-storage-blobs-cpp CONFIG REQUIRED) message(STATUS "Using Azure storage blobs ${azure-storage-blobs-cpp_VERSION}") - + find_package(azure-identity-cpp CONFIG REQUIRED) message(STATUS "Using Azure identity ${azure-identity-cpp_VERSION}") endif() From a18162d8b6ebc5c3203bf6c812c882d58ea66239 Mon Sep 17 00:00:00 2001 From: Damian Kowalczyk Date: Tue, 7 Apr 2026 13:43:24 -0700 Subject: [PATCH 5/8] fix: Use ManagedIdentityCredential(clientId) constructor ManagedIdentityCredentialOptions does not have a ClientId field. Use the constructor overload that accepts clientId as a string parameter instead. Also consolidate identity includes to the top-level azure/identity.hpp header and remove WIP comment. --- src/filesystem/implementations/as.h | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/filesystem/implementations/as.h b/src/filesystem/implementations/as.h index 120f8c33a..6f710cc23 100644 --- a/src/filesystem/implementations/as.h +++ b/src/filesystem/implementations/as.h @@ -25,13 +25,11 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #pragma once -#include -#include +#include #include #include #include "common.h" -// [WIP] below needed? #undef LOG_INFO #undef LOG_WARNING @@ -184,14 +182,13 @@ ASFileSystem::ASFileSystem(const std::string& path, const ASCredential& as_cred) // Token caching and refresh are handled by the Azure Identity SDK. LOG_VERBOSE(1) << "Using Azure Managed Identity authentication for " << account_name; - std::shared_ptr token_cred; + std::shared_ptr token_cred; if (!as_cred.client_id_.empty()) { - // User-assigned Managed Identity: specify the client ID. - Azure::Identity::ManagedIdentityCredentialOptions mi_opts; - mi_opts.ClientId = as_cred.client_id_; + // User-assigned Managed Identity: pass the client ID directly + // to the credential constructor. token_cred = std::make_shared( - mi_opts); + as_cred.client_id_); LOG_VERBOSE(1) << "Using user-assigned Managed Identity with client ID " << as_cred.client_id_; } else { From c5ec7ac5b58b42702e1733eb9846a08711ef101d Mon Sep 17 00:00:00 2001 From: mattwittwer Date: Tue, 14 Apr 2026 16:58:04 -0700 Subject: [PATCH 6/8] fix: RequestTracker counter mismatch (#483) * Fix RequestTracker counter mismatch in ScheduleSteps with parallel failures --- src/ensemble_scheduler/ensemble_scheduler.cc | 96 +++++++++++++++----- src/ensemble_scheduler/ensemble_scheduler.h | 2 + 2 files changed, 73 insertions(+), 25 deletions(-) diff --git a/src/ensemble_scheduler/ensemble_scheduler.cc b/src/ensemble_scheduler/ensemble_scheduler.cc index dd4f6b9b7..cc00d9b86 100644 --- a/src/ensemble_scheduler/ensemble_scheduler.cc +++ b/src/ensemble_scheduler/ensemble_scheduler.cc @@ -87,8 +87,24 @@ class RequestTracker { { } + // Accessed without additional synchronization while protected by + // EnsembleContext::mutex_. std::unique_ptr& Request() { return request_; } + // Used from paths where request_ may be released concurrently. + bool IsCancelled() + { + std::lock_guard lk(mtx_); + return (request_ == nullptr) || request_->IsCancelled(); + } + + std::string LogRequest() + { + std::lock_guard lk(mtx_); + return (request_ != nullptr) ? request_->LogRequest() + : std::string("[request released] "); + } + InferenceStatsAggregator* StatsAggregator() { return stats_aggregator_; } MetricModelReporter* MetricReporter() { return metric_reporter_; } @@ -141,6 +157,23 @@ class RequestTracker { status_ = status; } + void RespondIfError(const Status& status, FailureReason reason) + { + std::lock_guard lk(mtx_); + if (request_ != nullptr) { + InferenceRequest::RespondIfError( + request_, status, false /* release_request */, reason); + } + } + + void SendFlags(const uint32_t flags) + { + std::lock_guard lk(mtx_); + if (request_ != nullptr) { + request_->ResponseFactory()->SendFlags(flags); + } + } + private: std::mutex mtx_; uint32_t inflight_request_counter_; @@ -153,6 +186,7 @@ class RequestTracker { triton::common::ThreadPool* const callback_pool_; }; +using RequestTrackerReference = std::shared_ptr; // Step is used as 'userp' and keeps ensemble context alive // until no more internal requests are inflight. // Step contains metadata, and status for the @@ -188,6 +222,11 @@ struct Step { const bool preserve_responses_order_; size_t step_idx_; + + // Heap allocation passed as the release-callback userp. The allocation + // stores a shared_ptr so the tracker stays alive until the + // release callback or local failure cleanup drops this reference. + RequestTrackerReference* callback_tracker_ref_{nullptr}; }; struct TensorData { @@ -396,7 +435,7 @@ class EnsembleContext { // Objects related to the ensemble infer request Status ensemble_status_; - RequestTracker* request_tracker_; + std::shared_ptr request_tracker_; // Use in conjunction with 'is_decoupled_' in EnsembleInfo to // better distinguish ensemble ending behavior (see annotation in // FinishEnsemble for details). @@ -429,7 +468,7 @@ EnsembleContext::EnsembleContext( { uint64_t compute_start_ns = 0; INFER_STATS_SET_TIMESTAMP(compute_start_ns); - request_tracker_ = new RequestTracker( + request_tracker_ = std::make_shared( std::move(request), compute_start_ns, metric_reporter, stats_aggregator, callback_pool); @@ -646,16 +685,17 @@ void EnsembleContext::RequestComplete( TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) { - auto request_tracker = reinterpret_cast(userp); + auto callback_tracker_ref = reinterpret_cast(userp); + auto request_tracker = *callback_tracker_ref; auto pool = request_tracker->CallbackPool(); - auto fn = [request, flags, request_tracker]() { + auto fn = [request, flags, request_tracker, callback_tracker_ref]() { if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) { + std::unique_ptr managed_callback_tracker_ref( + callback_tracker_ref); LOG_TRITONSERVER_ERROR( TRITONSERVER_InferenceRequestDelete(request), "deleting ensemble inference request"); - if (request_tracker->DecrementCounter()) { - delete request_tracker; - } + request_tracker->DecrementCounter(); } }; @@ -1070,12 +1110,19 @@ EnsembleContext::InitStep( irequest->SetSecondaryStatsAggregator( &request_tracker_->ContextStatsAggregator()); #endif + // Heap-allocate the release-callback userp because the C callback API only + // stores a raw void*. The heap object itself is single-owner here, while the + // object stored inside it is a shared_ptr that keeps the + // tracker alive until the callback or local cleanup drops this reference. + auto callback_tracker_ref = + std::make_unique(request_tracker_); irequest->SetResponseCallback( reinterpret_cast(allocator_.get()), step->get(), ResponseComplete, step->get()); - irequest->SetReleaseCallback(RequestComplete, request_tracker_); + irequest->SetReleaseCallback(RequestComplete, callback_tracker_ref.get()); RETURN_IF_ERROR(irequest->PrepareForInference()); + (*step)->callback_tracker_ref_ = callback_tracker_ref.release(); #ifdef TRITON_ENABLE_TRACING auto& parent_trace = request_tracker_->Request()->TraceProxy(); @@ -1220,16 +1267,14 @@ EnsembleContext::FinishEnsemble(std::unique_ptr&& response) ensemble_status_ = Status( Status::Code::INVALID_ARG, "in ensemble '" + info_->ensemble_name_ + "', " + - request_tracker_->Request()->LogRequest() + + request_tracker_->LogRequest() + "unexpected deadlock, at least one output is not set while no " "more " "ensemble steps can be made"); - InferenceRequest::RespondIfError( - request_tracker_->Request(), ensemble_status_, - false /* release_requests */, FailureReason::OTHER); + request_tracker_->RespondIfError( + ensemble_status_, FailureReason::OTHER); } else { - request_tracker_->Request()->ResponseFactory()->SendFlags( - TRITONSERVER_RESPONSE_COMPLETE_FINAL); + request_tracker_->SendFlags(TRITONSERVER_RESPONSE_COMPLETE_FINAL); } } } else { @@ -1239,9 +1284,8 @@ EnsembleContext::FinishEnsemble(std::unique_ptr&& response) std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL, ensemble_status_); } else { - InferenceRequest::RespondIfError( - request_tracker_->Request(), ensemble_status_, - false /* release_requests */, FailureReason::OTHER); + request_tracker_->RespondIfError( + ensemble_status_, FailureReason::OTHER); } error_response_sent_ = true; } @@ -1251,10 +1295,8 @@ EnsembleContext::FinishEnsemble(std::unique_ptr&& response) // Reach here when the ensemble execution comes to the end, // 'ensemble_status_' at this point is representative. request_tracker_->SetStatus(ensemble_status_); - if (request_tracker_->DecrementCounter()) { - delete request_tracker_; - } - request_tracker_ = nullptr; + request_tracker_->DecrementCounter(); + request_tracker_.reset(); } return ensemble_status_; } @@ -1436,7 +1478,7 @@ EnsembleContext::ScheduleSteps( if (should_schedule) { // If the ensemble request is cancelled, propagate the cancellation to the // next request step. - if (context->request_tracker_->Request()->IsCancelled()) { + if (context->request_tracker_->IsCancelled()) { step->request_->Cancel(); } // Acquire a slot from the per-step shared limiter only for steps that @@ -1469,7 +1511,6 @@ EnsembleContext::ScheduleSteps( // Reaching here means the step is not being scheduled, update corresponding // counters and attempt to finish ensemble if it is the last step. - // Release the limiter slot if one was acquired, and update counters. if (should_schedule && !context->info_->step_inflight_request_limiters_.empty()) { @@ -1477,8 +1518,13 @@ EnsembleContext::ScheduleSteps( } std::lock_guard lock(context->mutex_); - // Decrement only when IncrementCounter was called. An unconditional - // decrement would underflow the counter and cause a use-after-free. + // The request never reaches the callback-owned release path, so drop the + // heap-allocated callback userp here. + delete step->callback_tracker_ref_; + step->callback_tracker_ref_ = nullptr; + // Only undo IncrementCounter() for steps that actually reached the + // scheduling path. Otherwise the counter can underflow and release the + // top-level request while FinishEnsemble is still using it. if (should_schedule) { context->request_tracker_->DecrementCounter(); } diff --git a/src/ensemble_scheduler/ensemble_scheduler.h b/src/ensemble_scheduler/ensemble_scheduler.h index f3cbc2480..12b5562e5 100644 --- a/src/ensemble_scheduler/ensemble_scheduler.h +++ b/src/ensemble_scheduler/ensemble_scheduler.h @@ -27,7 +27,9 @@ #ifdef TRITON_ENABLE_ENSEMBLE +#include #include +#include #include "metric_model_reporter.h" #include "model.h" From 6037019144b3ed2f4fbe7b197ef955e6c3308e02 Mon Sep 17 00:00:00 2001 From: Yingge He <157551214+yinggeh@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:27:36 -0700 Subject: [PATCH 7/8] fix: Safely handle filesystem exception (#491) --- src/backend_model.cc | 9 ++-- src/filesystem/api.cc | 36 +++++++++++----- src/filesystem/api.h | 12 +++--- src/model.cc | 5 ++- .../model_repository_manager.cc | 43 +++++++++++++------ 5 files changed, 74 insertions(+), 31 deletions(-) diff --git a/src/backend_model.cc b/src/backend_model.cc index c3b0fc2dc..81c865af8 100644 --- a/src/backend_model.cc +++ b/src/backend_model.cc @@ -363,9 +363,12 @@ TritonModel::GetBackendLibraryProperties( model_config->name() + "', searched: " + search_paths_str); } - if (IsChildPathEscapingParentPath( - *backend_libpath /* child_path */, - *backend_libdir /* parent_path */)) { + + bool is_escaped = false; + RETURN_IF_ERROR(IsChildPathEscapingParentPath( + *backend_libpath /* child_path */, *backend_libdir /* parent_path */, + &is_escaped)); + if (is_escaped) { return Status( Status::Code::INVALID_ARG, "backend library name '" + cpp_backend_libname + diff --git a/src/filesystem/api.cc b/src/filesystem/api.cc index 7b1d49c21..daf324759 100644 --- a/src/filesystem/api.cc +++ b/src/filesystem/api.cc @@ -1,4 +1,4 @@ -// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -399,17 +399,33 @@ IsAbsolutePath(const std::string& path) return !path.empty() && (path[0] == '/'); } -bool +Status IsChildPathEscapingParentPath( - const std::string& child_path, const std::string& parent_path) -{ - const std::string absolute_child_path = - std::filesystem::weakly_canonical(child_path).string(); - const std::string absolute_parent_path = - std::filesystem::canonical(parent_path).string(); + const std::string& child_path, const std::string& parent_path, + bool* is_escaped) +{ + std::string absolute_child_path; + std::string absolute_parent_path; + try { + absolute_child_path = + std::filesystem::weakly_canonical(child_path).string(); + } + catch (const std::exception& e) { + return Status( + Status::Code::INVALID_ARG, + "Invalid path '" + child_path + "': " + e.what()); + } + try { + absolute_parent_path = std::filesystem::canonical(parent_path).string(); + } + catch (const std::exception& e) { + return Status( + Status::Code::INVALID_ARG, + "Nonexistent path '" + parent_path + "': " + e.what()); + } // Can use starts_with() over rfind() in C++20. - bool is_escape = absolute_child_path.rfind(absolute_parent_path, 0) != 0; - return is_escape; + *is_escaped = absolute_child_path.rfind(absolute_parent_path, 0) != 0; + return Status::Success; } std::string diff --git a/src/filesystem/api.h b/src/filesystem/api.h index 1c5c5820e..1aac3fc77 100644 --- a/src/filesystem/api.h +++ b/src/filesystem/api.h @@ -1,4 +1,4 @@ -// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -131,10 +131,12 @@ bool IsAbsolutePath(const std::string& path); /// Check if the child path escapes from its parent path. /// \param child_path The child path. /// \param parent_path The parent path. The path must exist. -/// \return true if the child path escapes from its parent path, false if the -/// child path is within its parent path. -bool IsChildPathEscapingParentPath( - const std::string& child_path, const std::string& parent_path); +/// \param is_escaped returns true if the child path escapes from its parent +/// path, false if the child path is within its parent path. \return Error +/// status +Status IsChildPathEscapingParentPath( + const std::string& child_path, const std::string& parent_path, + bool* is_escaped); /// Join path segments into a longer path /// \param segments The path segments. diff --git a/src/model.cc b/src/model.cc index 9e80bd9b4..86d871868 100644 --- a/src/model.cc +++ b/src/model.cc @@ -115,7 +115,10 @@ Model::Init(const bool is_config_provided) if (!io.label_filename().empty()) { auto label_path = JoinPath({model_dir_, io.label_filename()}); - if (IsChildPathEscapingParentPath(label_path, model_dir_)) { + bool is_escaped = false; + RETURN_IF_ERROR( + IsChildPathEscapingParentPath(label_path, model_dir_, &is_escaped)); + if (is_escaped) { return Status( Status::Code::UNSUPPORTED, "label file path '" + label_path + "' for output '" + io.name() + diff --git a/src/model_repository_manager/model_repository_manager.cc b/src/model_repository_manager/model_repository_manager.cc index 06940e907..b4d174db0 100644 --- a/src/model_repository_manager/model_repository_manager.cc +++ b/src/model_repository_manager/model_repository_manager.cc @@ -107,8 +107,16 @@ class LocalizeRepoAgent : public TritonRepoAgent { RETURN_TRITONSERVER_ERROR_IF_ERROR( agent_model->AcquireMutableLocation( TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &temp_dir_cstr)); - const std::string temp_dir = - std::filesystem::canonical(temp_dir_cstr).string(); + std::string temp_dir; + try { + temp_dir = std::filesystem::canonical(temp_dir_cstr).string(); + } + catch (const std::exception& e) { + const std::string err_msg = std::string("Nonexistent path '") + + temp_dir_cstr + "': " + e.what(); + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, err_msg.c_str()); + } const auto& files = *reinterpret_cast*>( agent_model->State()); @@ -136,20 +144,31 @@ class LocalizeRepoAgent : public TritonRepoAgent { .c_str()); } + const std::string file_relpath = + file->Name().substr(file_prefix.size()); + std::string file_path; + try { + file_path = std::filesystem::weakly_canonical( + JoinPath({temp_dir, file_relpath})) + .string(); + } + catch (const std::exception& e) { + const std::string err_msg = + std::string("Invalid file parameter '") + file_relpath + + "': " + e.what(); + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, err_msg.c_str()); + } + // Resolve any relative paths or symlinks, and enforce that target // directory stays within model directory for security. - const std::string file_path = - std::filesystem::weakly_canonical( - JoinPath( - {temp_dir, file->Name().substr(file_prefix.size())})) - .string(); if (file_path.rfind(temp_dir, 0) != 0) { + const std::string msg = + std::string("Invalid file parameter '") + file->Name() + + "' with normalized path '" + file_path + + "' must stay within model directory."; return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("Invalid file parameter '") + file->Name() + - "' with normalized path '" + file_path + - "' must stay within model directory.") - .c_str()); + TRITONSERVER_ERROR_INVALID_ARG, msg.c_str()); } // Save model override file to the instructed directory using the From 0fc4a6df7bcb70b7fd1d127ea76eb1222ae36d83 Mon Sep 17 00:00:00 2001 From: "Damian K. Kowalczyk" Date: Tue, 21 Apr 2026 16:28:09 -0700 Subject: [PATCH 8/8] cc_model_filename no longer ignored --- .gitignore | 1 + src/model_config_utils.cc | 3 +- src/test/CMakeLists.txt | 50 ++++++++ src/test/model_config_utils_test.cc | 178 ++++++++++++++++++++++++++++ 4 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/test/model_config_utils_test.cc diff --git a/.gitignore b/.gitignore index 9948de3b5..39aea3b5a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /build +/build_test /.vscode *.so *__pycache__/ diff --git a/src/model_config_utils.cc b/src/model_config_utils.cc index 465ec0089..6c2c18f2e 100644 --- a/src/model_config_utils.cc +++ b/src/model_config_utils.cc @@ -1273,7 +1273,8 @@ AutoCompleteBackendFields( } } if (config->backend() == kPythonBackend) { - if (config->default_model_filename().empty()) { + if (config->default_model_filename().empty() && + config->cc_model_filenames().empty()) { config->set_default_model_filename(kPythonFilename); } return Status::Success; diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index f0225b5f8..0917a97fa 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -439,6 +439,56 @@ install( RUNTIME DESTINATION bin ) +# +# Unit test for AutoCompleteBackendFields in model_config_utils +# +add_executable( + model_config_utils_test + model_config_utils_test.cc + ../model_config_utils.cc + ../status.cc + ../filesystem/api.cc + ../model_config_utils.h + ../status.h + ../filesystem/api.h +) + +set_target_properties( + model_config_utils_test + PROPERTIES + SKIP_BUILD_RPATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE + INSTALL_RPATH "" +) + +target_include_directories( + model_config_utils_test + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/.. + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${GTEST_INCLUDE_DIRS} + ${Boost_INCLUDE_DIRS} +) + +target_link_libraries( + model_config_utils_test + PRIVATE + triton-common-error # from repo-common + triton-common-model-config # from repo-common + triton-common-json # from repo-common + triton-common-logging # from repo-common + proto-library # from repo-common + GTest::gtest + GTest::gtest_main + protobuf::libprotobuf +) + +install( + TARGETS model_config_utils_test + RUNTIME DESTINATION bin +) + if(${TRITON_ENABLE_METRICS}) # diff --git a/src/test/model_config_utils_test.cc b/src/test/model_config_utils_test.cc new file mode 100644 index 000000000..af3a7b546 --- /dev/null +++ b/src/test/model_config_utils_test.cc @@ -0,0 +1,178 @@ +// Copyright 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "model_config_utils.h" + +#include + +#include +#include + +#include "constants.h" +#include "filesystem/api.h" +#include "gtest/gtest.h" + +namespace tc = triton::core; + +namespace { + +// Helper to create a temporary model directory with a version subdirectory +// and an optional file inside it. +class TempModelDir { + public: + TempModelDir() + { + auto status = + tc::MakeTemporaryDirectory(tc::FileSystemType::LOCAL, &root_path_); + EXPECT_TRUE(status.IsOk()) << status.AsString(); + } + + ~TempModelDir() + { + // Best-effort cleanup + std::string cmd = "rm -rf " + root_path_; + (void)system(cmd.c_str()); + } + + // Create version subdir (e.g., "1") and optionally place a file in it. + void AddVersionWithFile( + const std::string& version, const std::string& filename) + { + std::string version_dir = tc::JoinPath({root_path_, version}); + mkdir(version_dir.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH); + if (!filename.empty()) { + std::ofstream f(tc::JoinPath({version_dir, filename})); + f << "# placeholder"; + } + } + + const std::string& Path() const { return root_path_; } + + private: + std::string root_path_; +}; + +class AutoCompleteBackendFieldsTest : public ::testing::Test {}; + +// When backend is "python" and default_model_filename is empty and +// cc_model_filenames is empty, default_model_filename should be set to +// "model.py". +TEST_F(AutoCompleteBackendFieldsTest, PythonBackendSetsDefaultFilename) +{ + TempModelDir dir; + dir.AddVersionWithFile("1", "model.py"); + + inference::ModelConfig config; + config.set_backend("python"); + // default_model_filename and cc_model_filenames are both empty + + auto status = + tc::AutoCompleteBackendFields("test_model", dir.Path(), &config); + ASSERT_TRUE(status.IsOk()) << status.AsString(); + EXPECT_EQ(config.default_model_filename(), "model.py"); +} + +// When backend is "python" and default_model_filename is empty but +// cc_model_filenames is populated, default_model_filename should NOT be +// auto-filled to "model.py". +TEST_F( + AutoCompleteBackendFieldsTest, + PythonBackendSkipsDefaultFilenameWhenCcModelFilenamesSet) +{ + TempModelDir dir; + dir.AddVersionWithFile("1", "custom_model.py"); + + inference::ModelConfig config; + config.set_backend("python"); + (*config.mutable_cc_model_filenames())["gpu"] = "custom_model.py"; + // default_model_filename is empty, cc_model_filenames is set + + auto status = + tc::AutoCompleteBackendFields("test_model", dir.Path(), &config); + ASSERT_TRUE(status.IsOk()) << status.AsString(); + EXPECT_EQ(config.default_model_filename(), "") + << "default_model_filename should remain empty when cc_model_filenames " + "is set"; +} + +// When backend is "python" and default_model_filename is already set, +// it should be preserved regardless of cc_model_filenames. +TEST_F( + AutoCompleteBackendFieldsTest, + PythonBackendPreservesExplicitDefaultFilename) +{ + TempModelDir dir; + dir.AddVersionWithFile("1", "my_model.py"); + + inference::ModelConfig config; + config.set_backend("python"); + config.set_default_model_filename("my_model.py"); + + auto status = + tc::AutoCompleteBackendFields("test_model", dir.Path(), &config); + ASSERT_TRUE(status.IsOk()) << status.AsString(); + EXPECT_EQ(config.default_model_filename(), "my_model.py"); +} + +// When backend is "python" and both default_model_filename and +// cc_model_filenames are set, both should be preserved as-is. +TEST_F( + AutoCompleteBackendFieldsTest, + PythonBackendPreservesBothDefaultAndCcModelFilenames) +{ + TempModelDir dir; + dir.AddVersionWithFile("1", "my_model.py"); + + inference::ModelConfig config; + config.set_backend("python"); + config.set_default_model_filename("my_model.py"); + (*config.mutable_cc_model_filenames())["gpu"] = "gpu_model.py"; + + auto status = + tc::AutoCompleteBackendFields("test_model", dir.Path(), &config); + ASSERT_TRUE(status.IsOk()) << status.AsString(); + EXPECT_EQ(config.default_model_filename(), "my_model.py"); + EXPECT_EQ(config.cc_model_filenames().at("gpu"), "gpu_model.py"); +} + +// When backend is empty but version dir contains model.py, backend should be +// auto-detected as "python" and default_model_filename set to "model.py". +TEST_F(AutoCompleteBackendFieldsTest, AutoDetectPythonBackendFromModelFile) +{ + TempModelDir dir; + dir.AddVersionWithFile("1", "model.py"); + + inference::ModelConfig config; + // backend, platform, default_model_filename all empty + + auto status = + tc::AutoCompleteBackendFields("test_model", dir.Path(), &config); + ASSERT_TRUE(status.IsOk()) << status.AsString(); + EXPECT_EQ(config.backend(), "python"); + EXPECT_EQ(config.default_model_filename(), "model.py"); +} + +} // namespace