From dcdbec6650e0bb2d21ad7c59d8ca41d32b59ed7c Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Tue, 18 Feb 2025 12:10:47 -0800 Subject: [PATCH 1/8] Commit for Draft Changes --- src/grpc/grpc_server.cc | 733 ++++++++++++++++++++-------------------- src/grpc/grpc_server.h | 8 +- 2 files changed, 378 insertions(+), 363 deletions(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 74ec443ae6..836e32d11e 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -1,4 +1,4 @@ -// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -79,6 +79,57 @@ namespace { // are deemed to be not performance critical. //========================================================================= +template +class CommonCallbackData { + public: + using CallbackFunc = + std::function; + + CommonCallbackData( + const std::string& name, + inference::GRPCInferenceService::CallbackService* service, + const CallbackFunc& callback, + const std::pair& restricted_kv) + : name_(name), service_(service), callback_(callback), + restricted_kv_(restricted_kv) + { + } + + void operator()(RequestType* request) + { + ResponseType response; + ::grpc::Status status; + + if (ExecutePrecondition()) { + callback_(*request, &response, &status); + } else { + status = ::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + std::string("This protocol is restricted, expecting header '") + + restricted_kv_.first + "'"); + } + + request->request()->Complete(status); + delete this; + } + + private: + bool ExecutePrecondition() + { + if (!restricted_kv_.first.empty()) { + const auto& metadata = request->context()->client_metadata(); + const auto it = metadata.find(restricted_kv_.first); + return (it != metadata.end()) && (it->second == restricted_kv_.second); + } + return true; + } + + const std::string name_; + inference::GRPCInferenceService::CallbackService* service_; + CallbackFunc callback_; + std::pair restricted_kv_; +}; + template class CommonCallData : public ICallData { public: @@ -264,7 +315,8 @@ class CommonHandler : public HandlerBase { TraceManager* trace_manager, inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, - ::grpc::ServerCompletionQueue* cq, + inference::GRPCInferenceService::CallbackService* + non_inference_callback_service, const RestrictedFeatures& restricted_keys, const uint64_t response_delay); // Descriptive name of of the handler. @@ -315,6 +367,9 @@ class CommonHandler : public HandlerBase { inference::GRPCInferenceService::AsyncService* service_; ::grpc::health::v1::Health::AsyncService* health_service_; + inference::GRPCInferenceService::CallbackService* + non_inference_callback_service_; + ::grpc::ServerCompletionQueue* cq_; std::unique_ptr thread_; RestrictedFeatures restricted_keys_{}; @@ -333,7 +388,8 @@ CommonHandler::CommonHandler( const uint64_t response_delay = 0) : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), trace_manager_(trace_manager), service_(service), - health_service_(health_service), cq_(cq), + health_service_(health_service), + non_inference_callback_service_(non_inference_callback_service), cq_(cq), restricted_keys_(restricted_keys), response_delay_(response_delay) { } @@ -464,23 +520,18 @@ CommonHandler::RegisterServerLive() false /* async */, cq_, restricted_kv, response_delay_); } +// This change leverages the callback API, simplifying the handling of the +// ServerReady request by directly using the non_inference_callback_service_. void CommonHandler::RegisterServerReady() { - auto OnRegisterServerReady = - [this]( - ::grpc::ServerContext* ctx, inference::ServerReadyRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestServerReady( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerReady = [this]( - inference::ServerReadyRequest& request, - inference::ServerReadyResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a ServerReadyRequest, + // a ServerReadyResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteServerReady function. + auto callback = [this]( + inference::ServerReadyRequest& request, + inference::ServerReadyResponse* response, + ::grpc::Status* status) { bool ready = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); @@ -493,33 +544,25 @@ CommonHandler::RegisterServerReady() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerReadyRequest, inference::ServerReadyResponse>( - "ServerReady", 0, OnRegisterServerReady, OnExecuteServerReady, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ServerReady to register the callback. + // This replaces the use of CommonCallData. + non_inference_callback_service_->ServerReady( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::ServerReadyRequest, inference::ServerReadyResponse>( + "ServerReady", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterHealthCheck() { - auto OnRegisterHealthCheck = - [this]( - ::grpc::ServerContext* ctx, - ::grpc::health::v1::HealthCheckRequest* request, - ::grpc::ServerAsyncResponseWriter< - ::grpc::health::v1::HealthCheckResponse>* responder, - void* tag) { - this->health_service_->RequestCheck( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteHealthCheck = [this]( - ::grpc::health::v1::HealthCheckRequest& - request, - ::grpc::health::v1::HealthCheckResponse* - response, - ::grpc::Status* status) { + auto callback = [this]( + ::grpc::health::v1::HealthCheckRequest& request, + ::grpc::health::v1::HealthCheckResponse* response, + ::grpc::Status* status) { bool live = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &live); @@ -540,32 +583,21 @@ CommonHandler::RegisterHealthCheck() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - ::grpc::health::v1::HealthCheckResponse>, - ::grpc::health::v1::HealthCheckRequest, - ::grpc::health::v1::HealthCheckResponse>( - "Check", 0, OnRegisterHealthCheck, OnExecuteHealthCheck, - false /* async */, cq_, restricted_kv, response_delay_); + + non_inference_callback_service_->Check( + new CommonCallbackData< + ::grpc::health::v1::HealthCheckRequest, + ::grpc::health::v1::HealthCheckResponse>( + "Check", non_inference_callback_service_, callback, restricted_kv)); } void CommonHandler::RegisterModelReady() { - auto OnRegisterModelReady = - [this]( - ::grpc::ServerContext* ctx, inference::ModelReadyRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelReady( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelReady = [this]( - inference::ModelReadyRequest& request, - inference::ModelReadyResponse* response, - ::grpc::Status* status) { + auto callback = [this]( + ::grpc::health::v1::HealthCheckRequest& request, + ::grpc::health::v1::HealthCheckResponse* response, + ::grpc::Status* status) { bool is_ready = false; int64_t requested_model_version; auto err = @@ -581,335 +613,314 @@ CommonHandler::RegisterModelReady() GrpcStatusUtil::Create(status, err); TRITONSERVER_ErrorDelete(err); }; - const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelReadyRequest, inference::ModelReadyResponse>( - "ModelReady", 0, OnRegisterModelReady, OnExecuteModelReady, - false /* async */, cq_, restricted_kv, response_delay_); + non_inference_callback_service_->ModelReady( + new CommonCallbackData< + ::grpc::health::v1::HealthCheckRequest, + ::grpc::health::v1::HealthCheckResponse>( + "ModelReady", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterServerMetadata() { - auto OnRegisterServerMetadata = - [this]( - ::grpc::ServerContext* ctx, inference::ServerMetadataRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestServerMetadata( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerMetadata = - [this]( - inference::ServerMetadataRequest& request, - inference::ServerMetadataResponse* response, ::grpc::Status* status) { - TRITONSERVER_Message* server_metadata_message = nullptr; - TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( - tritonserver_.get(), &server_metadata_message); - GOTO_IF_ERR(err, earlyexit); - - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - server_metadata_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); - - { - triton::common::TritonJson::Value server_metadata_json; - err = server_metadata_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); - + // Define a lambda function 'callback' that takes a ServerMetadataRequest, + // a ServerMetadataResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteServerMetadata function. + auto callback = [this]( + inference::ServerMetadataRequest& request, + inference::ServerMetadataResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Message* server_metadata_message = nullptr; + TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( + tritonserver_.get(), &server_metadata_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + server_metadata_message, &buffer, &byte_size); + if (err == nullptr) { + triton::common::TritonJson::Value server_metadata_json; + err = server_metadata_json.Parse(buffer, byte_size); + if (err == nullptr) { const char* name; size_t namelen; err = server_metadata_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - - const char* version; - size_t versionlen; - err = server_metadata_json.MemberAsString( - "version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - - response->set_name(std::string(name, namelen)); - response->set_version(std::string(version, versionlen)); - - if (server_metadata_json.Find("extensions")) { - triton::common::TritonJson::Value extensions_json; - err = server_metadata_json.MemberAsArray( - "extensions", &extensions_json); - GOTO_IF_ERR(err, earlyexit); - - for (size_t idx = 0; idx < extensions_json.ArraySize(); ++idx) { - const char* ext; - size_t extlen; - err = extensions_json.IndexAsString(idx, &ext, &extlen); - GOTO_IF_ERR(err, earlyexit); - response->add_extensions(std::string(ext, extlen)); + if (err == nullptr) { + const char* version; + size_t versionlen; + err = server_metadata_json.MemberAsString( + "version", &version, &versionlen); + if (err == nullptr) { + response->set_name(std::string(name, namelen)); + response->set_version(std::string(version, versionlen)); + + if (server_metadata_json.Find("extensions")) { + triton::common::TritonJson::Value extensions_json; + err = server_metadata_json.MemberAsArray( + "extensions", &extensions_json); + if (err == nullptr) { + for (size_t idx = 0; idx < extensions_json.ArraySize(); + ++idx) { + const char* ext; + size_t extlen; + err = extensions_json.IndexAsString(idx, &ext, &extlen); + if (err == nullptr) { + response->add_extensions(std::string(ext, extlen)); + } + } + } + } } } - TRITONSERVER_MessageDelete(server_metadata_message); } + } + TRITONSERVER_MessageDelete(server_metadata_message); + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::METADATA); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerMetadataRequest, inference::ServerMetadataResponse>( - "ServerMetadata", 0, OnRegisterServerMetadata, OnExecuteServerMetadata, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ServerMetadata to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->ServerMetadata( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::ServerMetadataRequest, inference::ServerMetadataResponse>( + "ServerMetadata", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterModelMetadata() { - auto OnRegisterModelMetadata = - [this]( - ::grpc::ServerContext* ctx, inference::ModelMetadataRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelMetadata( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelMetadata = [this]( - inference::ModelMetadataRequest& request, - inference::ModelMetadataResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a ModelMetadataRequest, + // a ModelMetadataResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteModelMetadata function. + auto callback = [this]( + inference::ModelMetadataRequest& request, + inference::ModelMetadataResponse* response, + ::grpc::Status* status) { int64_t requested_model_version; auto err = GetModelVersionFromString(request.version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); - { - TRITONSERVER_Message* model_metadata_message = nullptr; - err = TRITONSERVER_ServerModelMetadata( - tritonserver_.get(), request.name().c_str(), requested_model_version, - &model_metadata_message); - GOTO_IF_ERR(err, earlyexit); + TRITONSERVER_Message* model_metadata_message = nullptr; + err = TRITONSERVER_ServerModelMetadata( + tritonserver_.get(), request.name().c_str(), requested_model_version, + &model_metadata_message); + GOTO_IF_ERR(err, earlyexit); - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_metadata_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_metadata_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - triton::common::TritonJson::Value model_metadata_json; - err = model_metadata_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + triton::common::TritonJson::Value model_metadata_json; + err = model_metadata_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = model_metadata_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + const char* name; + size_t namelen; + err = model_metadata_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - response->set_name(std::string(name, namelen)); + response->set_name(std::string(name, namelen)); - if (model_metadata_json.Find("versions")) { - triton::common::TritonJson::Value versions_json; - err = model_metadata_json.MemberAsArray("versions", &versions_json); - GOTO_IF_ERR(err, earlyexit); + if (model_metadata_json.Find("versions")) { + triton::common::TritonJson::Value versions_json; + err = model_metadata_json.MemberAsArray("versions", &versions_json); + GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < versions_json.ArraySize(); ++idx) { - const char* version; - size_t versionlen; - err = versions_json.IndexAsString(idx, &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - response->add_versions(std::string(version, versionlen)); - } + for (size_t idx = 0; idx < versions_json.ArraySize(); ++idx) { + const char* version; + size_t versionlen; + err = versions_json.IndexAsString(idx, &version, &versionlen); + GOTO_IF_ERR(err, earlyexit); + response->add_versions(std::string(version, versionlen)); } + } + + const char* platform; + size_t platformlen; + err = + model_metadata_json.MemberAsString("platform", &platform, &platformlen); + GOTO_IF_ERR(err, earlyexit); + response->set_platform(std::string(platform, platformlen)); - const char* platform; - size_t platformlen; - err = model_metadata_json.MemberAsString( - "platform", &platform, &platformlen); + if (model_metadata_json.Find("inputs")) { + triton::common::TritonJson::Value inputs_json; + err = model_metadata_json.MemberAsArray("inputs", &inputs_json); GOTO_IF_ERR(err, earlyexit); - response->set_platform(std::string(platform, platformlen)); - if (model_metadata_json.Find("inputs")) { - triton::common::TritonJson::Value inputs_json; - err = model_metadata_json.MemberAsArray("inputs", &inputs_json); + for (size_t idx = 0; idx < inputs_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value io_json; + err = inputs_json.IndexAsObject(idx, &io_json); GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < inputs_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value io_json; - err = inputs_json.IndexAsObject(idx, &io_json); - GOTO_IF_ERR(err, earlyexit); + inference::ModelMetadataResponse::TensorMetadata* io = + response->add_inputs(); - inference::ModelMetadataResponse::TensorMetadata* io = - response->add_inputs(); + const char* name; + size_t namelen; + err = io_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = io_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + const char* datatype; + size_t datatypelen; + err = io_json.MemberAsString("datatype", &datatype, &datatypelen); + GOTO_IF_ERR(err, earlyexit); - const char* datatype; - size_t datatypelen; - err = io_json.MemberAsString("datatype", &datatype, &datatypelen); - GOTO_IF_ERR(err, earlyexit); + io->set_name(std::string(name, namelen)); + io->set_datatype(std::string(datatype, datatypelen)); - io->set_name(std::string(name, namelen)); - io->set_datatype(std::string(datatype, datatypelen)); + if (io_json.Find("shape")) { + triton::common::TritonJson::Value shape_json; + err = io_json.MemberAsArray("shape", &shape_json); + GOTO_IF_ERR(err, earlyexit); - if (io_json.Find("shape")) { - triton::common::TritonJson::Value shape_json; - err = io_json.MemberAsArray("shape", &shape_json); + for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { + int64_t d; + err = shape_json.IndexAsInt(sidx, &d); GOTO_IF_ERR(err, earlyexit); - - for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { - int64_t d; - err = shape_json.IndexAsInt(sidx, &d); - GOTO_IF_ERR(err, earlyexit); - - io->add_shape(d); - } + io->add_shape(d); } } } + } + + if (model_metadata_json.Find("outputs")) { + triton::common::TritonJson::Value outputs_json; + err = model_metadata_json.MemberAsArray("outputs", &outputs_json); + GOTO_IF_ERR(err, earlyexit); - if (model_metadata_json.Find("outputs")) { - triton::common::TritonJson::Value outputs_json; - err = model_metadata_json.MemberAsArray("outputs", &outputs_json); + for (size_t idx = 0; idx < outputs_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value io_json; + err = outputs_json.IndexAsObject(idx, &io_json); GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < outputs_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value io_json; - err = outputs_json.IndexAsObject(idx, &io_json); - GOTO_IF_ERR(err, earlyexit); + inference::ModelMetadataResponse::TensorMetadata* io = + response->add_outputs(); + + const char* name; + size_t namelen; + err = io_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - inference::ModelMetadataResponse::TensorMetadata* io = - response->add_outputs(); + const char* datatype; + size_t datatypelen; + err = io_json.MemberAsString("datatype", &datatype, &datatypelen); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = io_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + io->set_name(std::string(name, namelen)); + io->set_datatype(std::string(datatype, datatypelen)); - const char* datatype; - size_t datatypelen; - err = io_json.MemberAsString("datatype", &datatype, &datatypelen); + if (io_json.Find("shape")) { + triton::common::TritonJson::Value shape_json; + err = io_json.MemberAsArray("shape", &shape_json); GOTO_IF_ERR(err, earlyexit); - io->set_name(std::string(name, namelen)); - io->set_datatype(std::string(datatype, datatypelen)); - - if (io_json.Find("shape")) { - triton::common::TritonJson::Value shape_json; - err = io_json.MemberAsArray("shape", &shape_json); + for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { + int64_t d; + err = shape_json.IndexAsInt(sidx, &d); GOTO_IF_ERR(err, earlyexit); - - for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { - int64_t d; - err = shape_json.IndexAsInt(sidx, &d); - GOTO_IF_ERR(err, earlyexit); - - io->add_shape(d); - } + io->add_shape(d); } } } - - TRITONSERVER_MessageDelete(model_metadata_message); } earlyexit: + TRITONSERVER_MessageDelete(model_metadata_message); GrpcStatusUtil::Create(status, err); TRITONSERVER_ErrorDelete(err); }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::METADATA); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelMetadataRequest, inference::ModelMetadataResponse>( - "ModelMetadata", 0, OnRegisterModelMetadata, OnExecuteModelMetadata, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ModelMetadata to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->ModelMetadata( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::ModelMetadataRequest, inference::ModelMetadataResponse>( + "ModelMetadata", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterModelConfig() { - auto OnRegisterModelConfig = - [this]( - ::grpc::ServerContext* ctx, inference::ModelConfigRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelConfig( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelConfig = [this]( - inference::ModelConfigRequest& request, - inference::ModelConfigResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a ModelConfigRequest, + // a ModelConfigResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteModelConfig function. + auto callback = [this]( + inference::ModelConfigRequest& request, + inference::ModelConfigResponse* response, + ::grpc::Status* status) { int64_t requested_model_version; auto err = GetModelVersionFromString(request.version(), &requested_model_version); - if (err == nullptr) { - TRITONSERVER_Message* model_config_message = nullptr; - err = TRITONSERVER_ServerModelConfig( - tritonserver_.get(), request.name().c_str(), requested_model_version, - 1 /* config_version */, &model_config_message); - if (err == nullptr) { - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_config_message, &buffer, &byte_size); - if (err == nullptr) { - ::google::protobuf::util::JsonStringToMessage( - ::google::protobuf::stringpiece_internal::StringPiece( - buffer, (int)byte_size), - response->mutable_config()); - } - TRITONSERVER_MessageDelete(model_config_message); - } - } + GOTO_IF_ERR(err, earlyexit); + TRITONSERVER_Message* model_config_message = nullptr; + err = TRITONSERVER_ServerModelConfig( + tritonserver_.get(), request.name().c_str(), requested_model_version, + 1 /* config_version */, &model_config_message); + GOTO_IF_ERR(err, earlyexit); + + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_config_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); + + ::google::protobuf::util::JsonStringToMessage( + ::google::protobuf::stringpiece_internal::StringPiece( + buffer, static_cast(byte_size)), + response->mutable_config()); + + earlyexit: + TRITONSERVER_MessageDelete(model_config_message); GrpcStatusUtil::Create(status, err); TRITONSERVER_ErrorDelete(err); }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelConfigRequest, inference::ModelConfigResponse>( - "ModelConfig", 0, OnRegisterModelConfig, OnExecuteModelConfig, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ModelConfig to register the callback. + // This replaces the use of CommonCallData. + non_inference_callback_service_->ModelConfig( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::ModelConfigRequest, inference::ModelConfigResponse>( + "ModelConfig", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterModelStatistics() { - auto OnRegisterModelStatistics = - [this]( - ::grpc::ServerContext* ctx, - inference::ModelStatisticsRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelStatistics( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelStatistics = [this]( - inference::ModelStatisticsRequest& - request, - inference::ModelStatisticsResponse* - response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a ModelStatisticsRequest, + // a ModelStatisticsResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteModelStatistics function. + auto callback = [this]( + inference::ModelStatisticsRequest& request, + inference::ModelStatisticsResponse* response, + ::grpc::Status* status) { #ifdef TRITON_ENABLE_STATS triton::common::TritonJson::Value model_stats_json; @@ -918,24 +929,22 @@ CommonHandler::RegisterModelStatistics() GetModelVersionFromString(request.version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); - { - TRITONSERVER_Message* model_stats_message = nullptr; - err = TRITONSERVER_ServerModelStatistics( - tritonserver_.get(), request.name().c_str(), requested_model_version, - &model_stats_message); - GOTO_IF_ERR(err, earlyexit); + TRITONSERVER_Message* model_stats_message = nullptr; + err = TRITONSERVER_ServerModelStatistics( + tritonserver_.get(), request.name().c_str(), requested_model_version, + &model_stats_message); + GOTO_IF_ERR(err, earlyexit); - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_stats_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_stats_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - err = model_stats_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + err = model_stats_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - TRITONSERVER_MessageDelete(model_stats_message); - } + TRITONSERVER_MessageDelete(model_stats_message); if (model_stats_json.Find("model_stats")) { triton::common::TritonJson::Value stats_json; @@ -1133,11 +1142,17 @@ CommonHandler::RegisterModelStatistics() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::STATISTICS); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelStatisticsRequest, inference::ModelStatisticsResponse>( - "ModelStatistics", 0, OnRegisterModelStatistics, OnExecuteModelStatistics, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ModelStatistics to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->ModelStatistics( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::ModelStatisticsRequest, + inference::ModelStatisticsResponse>( + "ModelStatistics", non_inference_callback_service_, callback, + restricted_kv)); } template @@ -1163,20 +1178,13 @@ CommonHandler::SetStatisticsDuration( void CommonHandler::RegisterTrace() { - auto OnRegisterTrace = - [this]( - ::grpc::ServerContext* ctx, inference::TraceSettingRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestTraceSetting( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteTrace = [this]( - inference::TraceSettingRequest& request, - inference::TraceSettingResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a TraceSettingRequest, + // a TraceSettingResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteTrace function. + auto callback = [this]( + inference::TraceSettingRequest& request, + inference::TraceSettingResponse* response, + ::grpc::Status* status) { #ifdef TRITON_ENABLE_TRACING TRITONSERVER_Error* err = nullptr; TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED; @@ -1447,30 +1455,28 @@ CommonHandler::RegisterTrace() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::TRACE); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::TraceSettingRequest, inference::TraceSettingResponse>( - "Trace", 0, OnRegisterTrace, OnExecuteTrace, false /* async */, cq_, - restricted_kv, response_delay_); + + // Use non_inference_callback_service_->TraceSetting to register the callback. + // This replaces the use of CommonCallData. + non_inference_callback_service_->TraceSetting( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::TraceSettingRequest, inference::TraceSettingResponse>( + "TraceSetting", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterLogging() { - auto OnRegisterLogging = - [this]( - ::grpc::ServerContext* ctx, inference::LogSettingsRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestLogSettings( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteLogging = [this]( - inference::LogSettingsRequest& request, - inference::LogSettingsResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a LogSettingsRequest, + // a LogSettingsResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteLogging function. + auto callback = [this]( + inference::LogSettingsRequest& request, + inference::LogSettingsResponse* response, + ::grpc::Status* status) { #ifdef TRITON_ENABLE_LOGGING TRITONSERVER_Error* err = nullptr; @@ -1634,11 +1640,16 @@ CommonHandler::RegisterLogging() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::LOGGING); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::LogSettingsRequest, inference::LogSettingsResponse>( - "Logging", 0, OnRegisterLogging, OnExecuteLogging, false /* async */, cq_, - restricted_kv, response_delay_); + + // Use non_inference_callback_service_->LogSettings to register the callback. + // This replaces the use of CommonCallData. + non_inference_callback_service_->LogSettings( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::LogSettingsRequest, inference::LogSettingsResponse>( + "LogSettings", non_inference_callback_service_, callback, + restricted_kv)); } void @@ -2285,6 +2296,7 @@ Server::Server( builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); builder_.RegisterService(&service_); builder_.RegisterService(&health_service_); + builder_.RegisterService(&non_inference_callback_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); @@ -2383,8 +2395,8 @@ Server::Server( // A common Handler for other non-inference requests common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, common_cq_.get(), options.restricted_protocols_, - response_delay)); + &health_service_, &non_inference_callback_service_, common_cq_.get(), + options.restricted_protocols_, response_delay)); // [FIXME] "register" logic is different for infer // Handler for model inference requests. @@ -2546,6 +2558,7 @@ Server::Start() (std::string("Socket '") + server_addr_ + "' already in use ").c_str()); } + // Remove this common_handler_->Start(); for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Start(); diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index 89d8dc7388..2a7a5ff0ba 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -1,4 +1,4 @@ -// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -139,14 +139,16 @@ class Server { inference::GRPCInferenceService::AsyncService service_; ::grpc::health::v1::Health::AsyncService health_service_; + inference::GRPCInferenceService::CallbackService + non_inference_callback_service_; std::unique_ptr<::grpc::Server> server_; - std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; + // std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_infer_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_stream_infer_cq_; - std::unique_ptr common_handler_; + // std::unique_ptr common_handler_; std::vector> model_infer_handlers_; std::vector> model_stream_infer_handlers_; From 9c17ed6aba0b29b6f8c600955388b28fbea6e3a3 Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Wed, 19 Feb 2025 08:02:42 -0800 Subject: [PATCH 2/8] Convert all Non Inference RPCs --- src/grpc/grpc_server.cc | 862 +++++++++++++++++++--------------------- 1 file changed, 411 insertions(+), 451 deletions(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 836e32d11e..a15912eb41 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -1655,115 +1655,103 @@ CommonHandler::RegisterLogging() void CommonHandler::RegisterSystemSharedMemoryStatus() { - auto OnRegisterSystemSharedMemoryStatus = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryStatusRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryStatusResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryStatus( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryStatus = - [this]( - inference::SystemSharedMemoryStatusRequest& request, - inference::SystemSharedMemoryStatusResponse* response, - ::grpc::Status* status) { - triton::common::TritonJson::Value shm_status_json( - triton::common::TritonJson::ValueType::ARRAY); - TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); - GOTO_IF_ERR(err, earlyexit); + // Define a lambda function 'callback' that takes a + // SystemSharedMemoryStatusRequest, a SystemSharedMemoryStatusResponse, and a + // grpc::Status. This function performs the same logic as the original + // OnExecuteSystemSharedMemoryStatus function. + auto callback = [this]( + inference::SystemSharedMemoryStatusRequest& request, + inference::SystemSharedMemoryStatusResponse* response, + ::grpc::Status* status) { + triton::common::TritonJson::Value shm_status_json( + triton::common::TritonJson::ValueType::ARRAY); + TRITONSERVER_Error* err = shm_manager_->GetStatus( + request.name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); + GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value shm_region_json; - err = shm_status_json.IndexAsObject(idx, &shm_region_json); - GOTO_IF_ERR(err, earlyexit); + for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value shm_region_json; + err = shm_status_json.IndexAsObject(idx, &shm_region_json); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = shm_region_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + const char* name; + size_t namelen; + err = shm_region_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - const char* key; - size_t keylen; - err = shm_region_json.MemberAsString("key", &key, &keylen); - GOTO_IF_ERR(err, earlyexit); + const char* key; + size_t keylen; + err = shm_region_json.MemberAsString("key", &key, &keylen); + GOTO_IF_ERR(err, earlyexit); - uint64_t offset; - err = shm_region_json.MemberAsUInt("offset", &offset); - GOTO_IF_ERR(err, earlyexit); + uint64_t offset; + err = shm_region_json.MemberAsUInt("offset", &offset); + GOTO_IF_ERR(err, earlyexit); - uint64_t byte_size; - err = shm_region_json.MemberAsUInt("byte_size", &byte_size); - GOTO_IF_ERR(err, earlyexit); + uint64_t byte_size; + err = shm_region_json.MemberAsUInt("byte_size", &byte_size); + GOTO_IF_ERR(err, earlyexit); - inference::SystemSharedMemoryStatusResponse::RegionStatus - region_status; - region_status.set_name(std::string(name, namelen)); - region_status.set_key(std::string(key, keylen)); - region_status.set_offset(offset); - region_status.set_byte_size(byte_size); + inference::SystemSharedMemoryStatusResponse::RegionStatus region_status; + region_status.set_name(std::string(name, namelen)); + region_status.set_key(std::string(key, keylen)); + region_status.set_offset(offset); + region_status.set_byte_size(byte_size); - (*response->mutable_regions())[name] = region_status; - } + (*response->mutable_regions())[name] = region_status; + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + earlyexit: + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryStatusResponse>, - inference::SystemSharedMemoryStatusRequest, - inference::SystemSharedMemoryStatusResponse>( - "SystemSharedMemoryStatus", 0, OnRegisterSystemSharedMemoryStatus, - OnExecuteSystemSharedMemoryStatus, false /* async */, cq_, restricted_kv, - response_delay_); + + // Use non_inference_callback_service_->SystemSharedMemoryStatus to register + // the callback. This replaces the use of CommonCallData. + non_inference_callback_service_->SystemSharedMemoryStatus( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::SystemSharedMemoryStatusRequest, + inference::SystemSharedMemoryStatusResponse>( + "SystemSharedMemoryStatus", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterSystemSharedMemoryRegister() { - auto OnRegisterSystemSharedMemoryRegister = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryRegisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryRegisterResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryRegister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryRegister = - [this]( - inference::SystemSharedMemoryRegisterRequest& request, - inference::SystemSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( - request.name(), request.key(), request.offset(), - request.byte_size()); + // Define a lambda function 'callback' that takes a + // SystemSharedMemoryRegisterRequest, a SystemSharedMemoryRegisterResponse, + // and a grpc::Status. This function performs the same logic as the original + // OnExecuteSystemSharedMemoryRegister function. + auto callback = [this]( + inference::SystemSharedMemoryRegisterRequest& request, + inference::SystemSharedMemoryRegisterResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( + request.name(), request.key(), request.offset(), request.byte_size()); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryRegisterResponse>, - inference::SystemSharedMemoryRegisterRequest, - inference::SystemSharedMemoryRegisterResponse>( - "SystemSharedMemoryRegister", 0, OnRegisterSystemSharedMemoryRegister, - OnExecuteSystemSharedMemoryRegister, false /* async */, cq_, - restricted_kv, response_delay_); + + // Use non_inference_callback_service_->SystemSharedMemoryRegister to register + // the callback. This replaces the use of CommonCallData. + non_inference_callback_service_->SystemSharedMemoryRegister( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::SystemSharedMemoryRegisterRequest, + inference::SystemSharedMemoryRegisterResponse>( + "SystemSharedMemoryRegister", non_inference_callback_service_, + callback, restricted_kv)); } void @@ -1812,447 +1800,419 @@ CommonHandler::RegisterSystemSharedMemoryUnregister() void CommonHandler::RegisterCudaSharedMemoryStatus() { - auto OnRegisterCudaSharedMemoryStatus = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryStatusRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryStatusResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryStatus( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - auto OnExecuteCudaSharedMemoryStatus = - [this]( - inference::CudaSharedMemoryStatusRequest& request, - inference::CudaSharedMemoryStatusResponse* response, - ::grpc::Status* status) { - triton::common::TritonJson::Value shm_status_json( - triton::common::TritonJson::ValueType::ARRAY); - TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); - GOTO_IF_ERR(err, earlyexit); - - for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value shm_region_json; - err = shm_status_json.IndexAsObject(idx, &shm_region_json); - GOTO_IF_ERR(err, earlyexit); + // Define a lambda function 'callback' that takes a + // CudaSharedMemoryStatusRequest, a CudaSharedMemoryStatusResponse, and a + // grpc::Status. This function performs the same logic as the original + // OnExecuteCudaSharedMemoryStatus function. + auto callback = [this]( + inference::CudaSharedMemoryStatusRequest& request, + inference::CudaSharedMemoryStatusResponse* response, + ::grpc::Status* status) { + triton::common::TritonJson::Value shm_status_json( + triton::common::TritonJson::ValueType::ARRAY); + TRITONSERVER_Error* err = shm_manager_->GetStatus( + request.name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = shm_region_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value shm_region_json; + err = shm_status_json.IndexAsObject(idx, &shm_region_json); + GOTO_IF_ERR(err, earlyexit); - uint64_t device_id; - err = shm_region_json.MemberAsUInt("device_id", &device_id); - GOTO_IF_ERR(err, earlyexit); + const char* name; + size_t namelen; + err = shm_region_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - uint64_t byte_size; - err = shm_region_json.MemberAsUInt("byte_size", &byte_size); - GOTO_IF_ERR(err, earlyexit); + uint64_t device_id; + err = shm_region_json.MemberAsUInt("device_id", &device_id); + GOTO_IF_ERR(err, earlyexit); + uint64_t byte_size; + err = shm_region_json.MemberAsUInt("byte_size", &byte_size); + GOTO_IF_ERR(err, earlyexit); - inference::CudaSharedMemoryStatusResponse::RegionStatus region_status; - region_status.set_name(std::string(name, namelen)); - region_status.set_device_id(device_id); - region_status.set_byte_size(byte_size); + inference::CudaSharedMemoryStatusResponse::RegionStatus region_status; + region_status.set_name(std::string(name, namelen)); + region_status.set_device_id(device_id); + region_status.set_byte_size(byte_size); - (*response->mutable_regions())[name] = region_status; - } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + (*response->mutable_regions())[name] = region_status; + } + earlyexit: + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryStatusResponse>, - inference::CudaSharedMemoryStatusRequest, - inference::CudaSharedMemoryStatusResponse>( - "CudaSharedMemoryStatus", 0, OnRegisterCudaSharedMemoryStatus, - OnExecuteCudaSharedMemoryStatus, false /* async */, cq_, restricted_kv, - response_delay_); + + // Use non_inference_callback_service_->CudaSharedMemoryStatus to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->CudaSharedMemoryStatus( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::CudaSharedMemoryStatusRequest, + inference::CudaSharedMemoryStatusResponse>( + "CudaSharedMemoryStatus", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterCudaSharedMemoryRegister() { - auto OnRegisterCudaSharedMemoryRegister = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryRegisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryRegisterResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryRegister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteCudaSharedMemoryRegister = - [this]( - inference::CudaSharedMemoryRegisterRequest& request, - inference::CudaSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; + // Define a lambda function 'callback' that takes a + // CudaSharedMemoryRegisterRequest, a CudaSharedMemoryRegisterResponse, and a + // grpc::Status. This function performs the same logic as the original + // OnExecuteCudaSharedMemoryRegister function. + auto callback = [this]( + inference::CudaSharedMemoryRegisterRequest& request, + inference::CudaSharedMemoryRegisterResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = nullptr; #ifdef TRITON_ENABLE_GPU - err = shm_manager_->RegisterCUDASharedMemory( - request.name(), - reinterpret_cast( - request.raw_handle().c_str()), - request.byte_size(), request.device_id()); + err = shm_manager_->RegisterCUDASharedMemory( + request.name(), + reinterpret_cast( + request.raw_handle().c_str()), + request.byte_size(), request.device_id()); #else - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "failed to register CUDA shared memory region: '" + - request.name() + "', GPUs not supported") - .c_str()); + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to register CUDA shared memory region: '" + request.name() + + "', GPUs not supported") + .c_str()); #endif // TRITON_ENABLE_GPU - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryRegisterResponse>, - inference::CudaSharedMemoryRegisterRequest, - inference::CudaSharedMemoryRegisterResponse>( - "CudaSharedMemoryRegister", 0, OnRegisterCudaSharedMemoryRegister, - OnExecuteCudaSharedMemoryRegister, false /* async */, cq_, restricted_kv, - response_delay_); + + // Use non_inference_callback_service_->CudaSharedMemoryRegister to register + // the callback. This replaces the use of CommonCallData. + non_inference_callback_service_->CudaSharedMemoryRegister( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::CudaSharedMemoryRegisterRequest, + inference::CudaSharedMemoryRegisterResponse>( + "CudaSharedMemoryRegister", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterCudaSharedMemoryUnregister() { - auto OnRegisterCudaSharedMemoryUnregister = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryUnregisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryUnregisterResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryUnregister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; + // Define a lambda function 'callback' that takes a + // CudaSharedMemoryUnregisterRequest, a CudaSharedMemoryUnregisterResponse, + // and a grpc::Status. This function performs the same logic as the original + // OnExecuteCudaSharedMemoryUnregister function. + auto callback = [this]( + inference::CudaSharedMemoryUnregisterRequest& request, + inference::CudaSharedMemoryUnregisterResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = nullptr; + if (request.name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); + } else { + err = shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_GPU); + } - auto OnExecuteCudaSharedMemoryUnregister = - [this]( - inference::CudaSharedMemoryUnregisterRequest& request, - inference::CudaSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); - } else { - err = - shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_GPU); - } + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryUnregisterResponse>, - inference::CudaSharedMemoryUnregisterRequest, - inference::CudaSharedMemoryUnregisterResponse>( - "CudaSharedMemoryUnregister", 0, OnRegisterCudaSharedMemoryUnregister, - OnExecuteCudaSharedMemoryUnregister, false /* async */, cq_, - restricted_kv, response_delay_); + // Use non_inference_callback_service_->CudaSharedMemoryUnregister to register + // the callback. This replaces the use of CommonCallData. + non_inference_callback_service_->CudaSharedMemoryUnregister( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::CudaSharedMemoryUnregisterRequest, + inference::CudaSharedMemoryUnregisterResponse>( + "CudaSharedMemoryUnregister", non_inference_callback_service_, + callback, restricted_kv)); } void CommonHandler::RegisterRepositoryIndex() { - auto OnRegisterRepositoryIndex = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryIndexRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestRepositoryIndex( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryIndex = - [this]( - inference::RepositoryIndexRequest& request, - inference::RepositoryIndexResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - uint32_t flags = 0; - if (request.ready()) { - flags |= TRITONSERVER_INDEX_FLAG_READY; - } - - TRITONSERVER_Message* model_index_message = nullptr; - err = TRITONSERVER_ServerModelIndex( - tritonserver_.get(), flags, &model_index_message); - GOTO_IF_ERR(err, earlyexit); + // Define a lambda function 'callback' that takes a RepositoryIndexRequest, + // a RepositoryIndexResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteRepositoryIndex function. + auto callback = [this]( + inference::RepositoryIndexRequest& request, + inference::RepositoryIndexResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = nullptr; + if (request.repository_name().empty()) { + uint32_t flags = 0; + if (request.ready()) { + flags |= TRITONSERVER_INDEX_FLAG_READY; + } - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_index_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + TRITONSERVER_Message* model_index_message = nullptr; + err = TRITONSERVER_ServerModelIndex( + tritonserver_.get(), flags, &model_index_message); + GOTO_IF_ERR(err, earlyexit); - triton::common::TritonJson::Value model_index_json; - err = model_index_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_index_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - err = model_index_json.AssertType( - triton::common::TritonJson::ValueType::ARRAY); - GOTO_IF_ERR(err, earlyexit); + triton::common::TritonJson::Value model_index_json; + err = model_index_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value index_json; - err = model_index_json.IndexAsObject(idx, &index_json); - GOTO_IF_ERR(err, earlyexit); + err = model_index_json.AssertType( + triton::common::TritonJson::ValueType::ARRAY); + GOTO_IF_ERR(err, earlyexit); - auto model_index = response->add_models(); + for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value index_json; + err = model_index_json.IndexAsObject(idx, &index_json); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = index_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_name(std::string(name, namelen)); + auto model_index = response->add_models(); - if (index_json.Find("version")) { - const char* version; - size_t versionlen; - err = index_json.MemberAsString("version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_version(std::string(version, versionlen)); - } - if (index_json.Find("state")) { - const char* state; - size_t statelen; - err = index_json.MemberAsString("state", &state, &statelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_state(std::string(state, statelen)); - } - if (index_json.Find("reason")) { - const char* reason; - size_t reasonlen; - err = index_json.MemberAsString("reason", &reason, &reasonlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_reason(std::string(reason, reasonlen)); - } - } + const char* name; + size_t namelen; + err = index_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); + model_index->set_name(std::string(name, namelen)); - TRITONSERVER_MessageDelete(model_index_message); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); + if (index_json.Find("version")) { + const char* version; + size_t versionlen; + err = index_json.MemberAsString("version", &version, &versionlen); + GOTO_IF_ERR(err, earlyexit); + model_index->set_version(std::string(version, versionlen)); + } + if (index_json.Find("state")) { + const char* state; + size_t statelen; + err = index_json.MemberAsString("state", &state, &statelen); + GOTO_IF_ERR(err, earlyexit); + model_index->set_state(std::string(state, statelen)); } + if (index_json.Find("reason")) { + const char* reason; + size_t reasonlen; + err = index_json.MemberAsString("reason", &reason, &reasonlen); + GOTO_IF_ERR(err, earlyexit); + model_index->set_reason(std::string(reason, reasonlen)); + } + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + TRITONSERVER_MessageDelete(model_index_message); + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + earlyexit: + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::RepositoryIndexRequest, inference::RepositoryIndexResponse>( - "RepositoryIndex", 0, OnRegisterRepositoryIndex, OnExecuteRepositoryIndex, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->RepositoryIndex to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->RepositoryIndex( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::RepositoryIndexRequest, + inference::RepositoryIndexResponse>( + "RepositoryIndex", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterRepositoryModelLoad() { - auto OnRegisterRepositoryModelLoad = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryModelLoadRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelLoadResponse>* responder, - void* tag) { - this->service_->RequestRepositoryModelLoad( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryModelLoad = - [this]( - inference::RepositoryModelLoadRequest& request, - inference::RepositoryModelLoadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - std::vector params; - // WAR for the const-ness check - std::vector const_params; - for (const auto& param_proto : request.parameters()) { - if (param_proto.first == "config") { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kStringParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected string_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterNew( - param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, - param_proto.second.string_param().c_str()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); - break; - } - } - } else if (param_proto.first.rfind("file:", 0) == 0) { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBytesParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected bytes_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterBytesNew( - param_proto.first.c_str(), - param_proto.second.bytes_param().data(), - param_proto.second.bytes_param().length()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); - break; - } - } + // Define a lambda function 'callback' that takes a + // RepositoryModelLoadRequest, a RepositoryModelLoadResponse, and a + // grpc::Status. This function performs the same logic as the original + // OnExecuteRepositoryModelLoad function. + auto callback = [this]( + inference::RepositoryModelLoadRequest& request, + inference::RepositoryModelLoadResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = nullptr; + if (request.repository_name().empty()) { + std::vector params; + // WAR for the const-ness check + std::vector const_params; + for (const auto& param_proto : request.parameters()) { + if (param_proto.first == "config") { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kStringParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected string_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterNew( + param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, + param_proto.second.string_param().c_str()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); } else { err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("unrecognized load parameter '") + - param_proto.first + "'.") - .c_str()); + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); break; } } - if (err == nullptr) { - err = TRITONSERVER_ServerLoadModelWithParameters( - tritonserver_.get(), request.model_name().c_str(), - const_params.data(), const_params.size()); - } - // Assumes no further 'params' access after load API returns - for (auto& param : params) { - TRITONSERVER_ParameterDelete(param); + } else if (param_proto.first.rfind("file:", 0) == 0) { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBytesParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected bytes_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterBytesNew( + param_proto.first.c_str(), + param_proto.second.bytes_param().data(), + param_proto.second.bytes_param().length()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); + break; + } } } else { err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unrecognized load parameter '") + + param_proto.first + "'.") + .c_str()); + break; } + } + if (err == nullptr) { + err = TRITONSERVER_ServerLoadModelWithParameters( + tritonserver_.get(), request.model_name().c_str(), + const_params.data(), const_params.size()); + } + // Assumes no further 'params' access after load API returns + for (auto& param : params) { + TRITONSERVER_ParameterDelete(param); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::RepositoryModelLoadRequest, - inference::RepositoryModelLoadResponse>( - "RepositoryModelLoad", 0, OnRegisterRepositoryModelLoad, - OnExecuteRepositoryModelLoad, true /* async */, cq_, restricted_kv, - response_delay_); + + // Use non_inference_callback_service_->RepositoryModelLoad to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->RepositoryModelLoad( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::RepositoryModelLoadRequest, + inference::RepositoryModelLoadResponse>( + "RepositoryModelLoad", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterRepositoryModelUnload() { - auto OnRegisterRepositoryModelUnload = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryModelUnloadRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelUnloadResponse>* responder, - void* tag) { - this->service_->RequestRepositoryModelUnload( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryModelUnload = - [this]( - inference::RepositoryModelUnloadRequest& request, - inference::RepositoryModelUnloadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - // Check if the dependent models should be removed - bool unload_dependents = false; - for (auto param : request.parameters()) { - if (param.first.compare("unload_dependents") == 0) { - const auto& unload_param = param.second; - if (unload_param.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBoolParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - "invalid value type for 'unload_dependents' parameter, " - "expected " - "bool_param."); - } - unload_dependents = unload_param.bool_param(); - break; - } - } - if (err == nullptr) { - if (unload_dependents) { - err = TRITONSERVER_ServerUnloadModelAndDependents( - tritonserver_.get(), request.model_name().c_str()); - } else { - err = TRITONSERVER_ServerUnloadModel( - tritonserver_.get(), request.model_name().c_str()); - } + // Define a lambda function 'callback' that takes a + // RepositoryModelUnloadRequest, a RepositoryModelUnloadResponse, and a + // grpc::Status. This function performs the same logic as the original + // OnExecuteRepositoryModelUnload function. + auto callback = [this]( + inference::RepositoryModelUnloadRequest& request, + inference::RepositoryModelUnloadResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = nullptr; + if (request.repository_name().empty()) { + // Check if the dependent models should be removed + bool unload_dependents = false; + for (auto param : request.parameters()) { + if (param.first.compare("unload_dependents") == 0) { + const auto& unload_param = param.second; + if (unload_param.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBoolParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "invalid value type for 'unload_dependents' parameter, " + "expected " + "bool_param."); } + unload_dependents = unload_param.bool_param(); + break; + } + } + if (err == nullptr) { + if (unload_dependents) { + err = TRITONSERVER_ServerUnloadModelAndDependents( + tritonserver_.get(), request.model_name().c_str()); } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); + err = TRITONSERVER_ServerUnloadModel( + tritonserver_.get(), request.model_name().c_str()); } + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelUnloadResponse>, - inference::RepositoryModelUnloadRequest, - inference::RepositoryModelUnloadResponse>( - "RepositoryModelUnload", 0, OnRegisterRepositoryModelUnload, - OnExecuteRepositoryModelUnload, true /* async */, cq_, restricted_kv, - response_delay_); + + // Use non_inference_callback_service_->RepositoryModelUnload to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->RepositoryModelUnload( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::RepositoryModelUnloadRequest, + inference::RepositoryModelUnloadResponse>( + "RepositoryModelUnload", non_inference_callback_service_, callback, + restricted_kv)); } } // namespace From e52cde3e5d23b706412c38f1432a138d8eeda36a Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Mon, 24 Feb 2025 11:03:13 -0800 Subject: [PATCH 3/8] Non Infrence Code Added --- src/grpc/grpc_handler.h | 11 +- src/grpc/grpc_server.cc | 2033 ++++++++++++++------------------------- src/grpc/grpc_server.h | 2 +- 3 files changed, 731 insertions(+), 1315 deletions(-) diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h index 4f1bcdfac0..ce9a30667e 100644 --- a/src/grpc/grpc_handler.h +++ b/src/grpc/grpc_handler.h @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -25,14 +25,23 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #pragma once +#include + #include +#include "grpc_service.grpc.pb.h" + namespace triton { namespace server { namespace grpc { class HandlerBase { public: virtual ~HandlerBase() = default; virtual void Start() = 0; virtual void Stop() = 0; + virtual inference::GRPCInferenceService::CallbackService* + GetUnifiedCallbackService() + { + return nullptr; + } }; class ICallData { diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index a15912eb41..ab644e3a07 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -79,560 +79,201 @@ namespace { // are deemed to be not performance critical. //========================================================================= -template -class CommonCallbackData { +// Define your unified callback service that implements several non-inference +// RPCs. +class UnifiedCallbackService + : public inference::GRPCInferenceService::CallbackService { public: - using CallbackFunc = - std::function; - - CommonCallbackData( - const std::string& name, - inference::GRPCInferenceService::CallbackService* service, - const CallbackFunc& callback, - const std::pair& restricted_kv) - : name_(name), service_(service), callback_(callback), - restricted_kv_(restricted_kv) - { - } - - void operator()(RequestType* request) - { - ResponseType response; - ::grpc::Status status; - - if (ExecutePrecondition()) { - callback_(*request, &response, &status); - } else { - status = ::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - std::string("This protocol is restricted, expecting header '") + - restricted_kv_.first + "'"); - } - - request->request()->Complete(status); - delete this; - } - - private: - bool ExecutePrecondition() - { - if (!restricted_kv_.first.empty()) { - const auto& metadata = request->context()->client_metadata(); - const auto it = metadata.find(restricted_kv_.first); - return (it != metadata.end()) && (it->second == restricted_kv_.second); - } - return true; - } - - const std::string name_; - inference::GRPCInferenceService::CallbackService* service_; - CallbackFunc callback_; - std::pair restricted_kv_; -}; - -template -class CommonCallData : public ICallData { - public: - using StandardRegisterFunc = std::function; - using StandardCallbackFunc = - std::function; - - CommonCallData( - const std::string& name, const uint64_t id, - const StandardRegisterFunc OnRegister, - const StandardCallbackFunc OnExecute, const bool async, - ::grpc::ServerCompletionQueue* cq, - const std::pair& restricted_kv, - const uint64_t& response_delay = 0) - : name_(name), id_(id), OnRegister_(OnRegister), OnExecute_(OnExecute), - async_(async), cq_(cq), responder_(&ctx_), step_(Steps::START), - restricted_kv_(restricted_kv), response_delay_(response_delay) - { - OnRegister_(&ctx_, &request_, &responder_, this); - LOG_VERBOSE(1) << "Ready for RPC '" << name_ << "', " << id_; - } - - ~CommonCallData() + UnifiedCallbackService( + const std::shared_ptr& server, + const std::shared_ptr& shm_manager, + const std::pair& restrictedKV) + : tritonserver_(server), shm_manager_(shm_manager), + restricted_kv_(restrictedKV) { - if (async_thread_.joinable()) { - async_thread_.join(); - } - } - - bool Process(bool ok) override; - - std::string Name() override { return name_; } - - uint64_t Id() override { return id_; } - - private: - void Execute(); - void AddToCompletionQueue(); - void WriteResponse(); - bool ExecutePrecondition(); - - const std::string name_; - const uint64_t id_; - const StandardRegisterFunc OnRegister_; - const StandardCallbackFunc OnExecute_; - const bool async_; - ::grpc::ServerCompletionQueue* cq_; - - ::grpc::ServerContext ctx_; - ::grpc::Alarm alarm_; - - ResponderType responder_; - RequestType request_; - ResponseType response_; - ::grpc::Status status_; - - std::thread async_thread_; - - Steps step_; - - std::pair restricted_kv_{"", ""}; - - const uint64_t response_delay_; -}; - -template -bool -CommonCallData::Process(bool rpc_ok) -{ - LOG_VERBOSE(1) << "Process for " << name_ << ", rpc_ok=" << rpc_ok << ", " - << id_ << " step " << step_; - - // If RPC failed on a new request then the server is shutting down - // and so we should do nothing (including not registering for a new - // request). If RPC failed on a non-START step then there is nothing - // we can do since we one execute one step. - const bool shutdown = (!rpc_ok && (step_ == Steps::START)); - if (shutdown) { - if (async_thread_.joinable()) { - async_thread_.join(); - } - step_ = Steps::FINISH; - } - - if (step_ == Steps::START) { - // Start a new request to replace this one... - if (!shutdown) { - new CommonCallData( - name_, id_ + 1, OnRegister_, OnExecute_, async_, cq_, restricted_kv_, - response_delay_); - } - - if (!async_) { - // For synchronous calls, execute and write response - // here. - Execute(); - WriteResponse(); - } else { - // For asynchronous calls, delegate the execution to another - // thread. - step_ = Steps::ISSUED; - async_thread_ = std::thread(&CommonCallData::Execute, this); - } - } else if (step_ == Steps::WRITEREADY) { - // Will only come here for asynchronous mode. - WriteResponse(); - } else if (step_ == Steps::COMPLETE) { - step_ = Steps::FINISH; - } - - return step_ != Steps::FINISH; -} - -template -void -CommonCallData::Execute() -{ - if (ExecutePrecondition()) { - OnExecute_(request_, &response_, &status_); - } else { - status_ = ::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - std::string("This protocol is restricted, expecting header '") + - restricted_kv_.first + "'"); - } - step_ = Steps::WRITEREADY; - - if (async_) { - // For asynchronous operation, need to add itself onto the completion - // queue so that the response can be written once the object is - // taken up next for execution. - AddToCompletionQueue(); } -} -template -bool -CommonCallData::ExecutePrecondition() -{ - if (!restricted_kv_.first.empty()) { - const auto& metadata = ctx_.client_metadata(); - const auto it = metadata.find(restricted_kv_.first); - return (it != metadata.end()) && (it->second == restricted_kv_.second); - } - return true; -} - -template -void -CommonCallData::AddToCompletionQueue() -{ - alarm_.Set(cq_, gpr_now(gpr_clock_type::GPR_CLOCK_REALTIME), this); -} - -template -void -CommonCallData::WriteResponse() -{ - if (response_delay_ != 0) { - // Will delay the write of the response by the specified time. - // This can be used to test the flow where there are other - // responses available to be written. - LOG_VERBOSE(1) << "Delaying the write of the response by " - << response_delay_ << " seconds"; - std::this_thread::sleep_for(std::chrono::seconds(response_delay_)); - } - step_ = Steps::COMPLETE; - responder_.Finish(response_, status_, this); -} - -// -// CommonHandler -// -// A common handler for all non-inference requests. -// -class CommonHandler : public HandlerBase { - public: - CommonHandler( - const std::string& name, - const std::shared_ptr& tritonserver, - const std::shared_ptr& shm_manager, - TraceManager* trace_manager, - inference::GRPCInferenceService::AsyncService* service, - ::grpc::health::v1::Health::AsyncService* health_service, - inference::GRPCInferenceService::CallbackService* - non_inference_callback_service, - const RestrictedFeatures& restricted_keys, const uint64_t response_delay); - - // Descriptive name of of the handler. - const std::string& Name() const { return name_; } - - // Start handling requests. - void Start() override; - - // Stop handling requests. - void Stop() override; - - private: - void SetUpAllRequests(); - - // [FIXME] turn into generated code - void RegisterServerLive(); - void RegisterServerReady(); - void RegisterHealthCheck(); - void RegisterModelReady(); - void RegisterServerMetadata(); - void RegisterModelMetadata(); - void RegisterModelConfig(); - void RegisterModelStatistics(); - void RegisterTrace(); - void RegisterLogging(); - void RegisterSystemSharedMemoryStatus(); - void RegisterSystemSharedMemoryRegister(); - void RegisterSystemSharedMemoryUnregister(); - void RegisterCudaSharedMemoryStatus(); - void RegisterCudaSharedMemoryRegister(); - void RegisterCudaSharedMemoryUnregister(); - void RegisterRepositoryIndex(); - void RegisterRepositoryModelLoad(); - void RegisterRepositoryModelUnload(); - - // Set count and cumulative duration for 'RegisterModelStatistics()' template TRITONSERVER_Error* SetStatisticsDuration( triton::common::TritonJson::Value& statistics_json, const std::string& statistics_name, - PBTYPE* mutable_statistics_duration_protobuf) const; - - const std::string name_; - std::shared_ptr tritonserver_; - - std::shared_ptr shm_manager_; - TraceManager* trace_manager_; - - inference::GRPCInferenceService::AsyncService* service_; - ::grpc::health::v1::Health::AsyncService* health_service_; - inference::GRPCInferenceService::CallbackService* - non_inference_callback_service_; - - ::grpc::ServerCompletionQueue* cq_; - std::unique_ptr thread_; - RestrictedFeatures restricted_keys_{}; - const uint64_t response_delay_ = 0; -}; + PBTYPE* mutable_statistics_duration_protobuf) + { + triton::common::TritonJson::Value statistics_duration_json; + RETURN_IF_ERR(statistics_json.MemberAsObject( + statistics_name.c_str(), &statistics_duration_json)); + + uint64_t value; + RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("count", &value)); + mutable_statistics_duration_protobuf->set_count(value); + RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("ns", &value)); + mutable_statistics_duration_protobuf->set_ns(value); + return nullptr; + } -CommonHandler::CommonHandler( - const std::string& name, - const std::shared_ptr& tritonserver, - const std::shared_ptr& shm_manager, - TraceManager* trace_manager, - inference::GRPCInferenceService::AsyncService* service, - ::grpc::health::v1::Health::AsyncService* health_service, - ::grpc::ServerCompletionQueue* cq, - const RestrictedFeatures& restricted_keys, - const uint64_t response_delay = 0) - : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), - trace_manager_(trace_manager), service_(service), - health_service_(health_service), - non_inference_callback_service_(non_inference_callback_service), cq_(cq), - restricted_keys_(restricted_keys), response_delay_(response_delay) -{ -} + // Example RPC method: ServerLive + ::grpc::ServerUnaryReactor* ServerLive( + ::grpc::CallbackServerContext* context, + const inference::ServerLiveRequest* request, + inference::ServerLiveResponse* response) override + { + auto* reactor = context->DefaultReactor(); -void -CommonHandler::Start() -{ - // Use a barrier to make sure we don't return until thread has - // started. - auto barrier = std::make_shared(2); - - thread_.reset(new std::thread([this, barrier] { - SetUpAllRequests(); - barrier->Wait(); - - void* tag; - bool ok; - - while (cq_->Next(&tag, &ok)) { - ICallData* call_data = static_cast(tag); - if (!call_data->Process(ok)) { - LOG_VERBOSE(1) << "Done for " << call_data->Name() << ", " - << call_data->Id(); - delete call_data; + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; } } - })); - - barrier->Wait(); - LOG_VERBOSE(1) << "Thread started for " << Name(); -} - -void -CommonHandler::Stop() -{ - if (thread_->joinable()) { - thread_->join(); - } - - LOG_VERBOSE(1) << "Thread exited for " << Name(); -} -void -CommonHandler::SetUpAllRequests() -{ - // Define all the RPCs to be handled by this handler below - // - // Within each of the Register function, the format of RPC specification is: - // 1. A OnRegister function: This will be called when the - // server is ready to receive the requests for this RPC. - // 2. A OnExecute function: This will be called when the - // to process the request. - // 3. Create a CommonCallData object with the above callback - // functions - - // health (GRPC standard) - RegisterHealthCheck(); - // health (Triton) - RegisterServerLive(); - RegisterServerReady(); - RegisterModelReady(); - - // Metadata - RegisterServerMetadata(); - RegisterModelMetadata(); - - // model config - RegisterModelConfig(); - - // shared memory - // system.. - RegisterSystemSharedMemoryStatus(); - RegisterSystemSharedMemoryRegister(); - RegisterSystemSharedMemoryUnregister(); - // cuda.. - RegisterCudaSharedMemoryStatus(); - RegisterCudaSharedMemoryRegister(); - RegisterCudaSharedMemoryUnregister(); - - // model repository - RegisterRepositoryIndex(); - RegisterRepositoryModelLoad(); - RegisterRepositoryModelUnload(); - - // statistics - RegisterModelStatistics(); - - // trace - RegisterTrace(); - - // logging - RegisterLogging(); -} - -void -CommonHandler::RegisterServerLive() -{ - auto OnRegisterServerLive = - [this]( - ::grpc::ServerContext* ctx, inference::ServerLiveRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestServerLive( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerLive = [this]( - inference::ServerLiveRequest& request, - inference::ServerLiveResponse* response, - ::grpc::Status* status) { + // Business logic for ServerLive. bool live = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsLive(tritonserver_.get(), &live); - response->set_live((err == nullptr) && live); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerLiveRequest, inference::ServerLiveResponse>( - "ServerLive", 0, OnRegisterServerLive, OnExecuteServerLive, - false /* async */, cq_, restricted_kv, response_delay_); -} + reactor->Finish(status); + return reactor; + } -// This change leverages the callback API, simplifying the handling of the -// ServerReady request by directly using the non_inference_callback_service_. -void -CommonHandler::RegisterServerReady() -{ - // Define a lambda function 'callback' that takes a ServerReadyRequest, - // a ServerReadyResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteServerReady function. - auto callback = [this]( - inference::ServerReadyRequest& request, - inference::ServerReadyResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ServerReady( + ::grpc::CallbackServerContext* context, + const inference::ServerReadyRequest* request, + inference::ServerReadyResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Business logic for ServerReady. bool ready = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); - response->set_ready((err == nullptr) && ready); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - - // Use non_inference_callback_service_->ServerReady to register the callback. - // This replaces the use of CommonCallData. - non_inference_callback_service_->ServerReady( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ServerReadyRequest, inference::ServerReadyResponse>( - "ServerReady", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterHealthCheck() -{ - auto callback = [this]( - ::grpc::health::v1::HealthCheckRequest& request, - ::grpc::health::v1::HealthCheckResponse* response, - ::grpc::Status* status) { - bool live = false; - TRITONSERVER_Error* err = - TRITONSERVER_ServerIsReady(tritonserver_.get(), &live); + // ::grpc::ServerUnaryReactor* Check( + // ::grpc::CallbackServerContext* context, + // const ::grpc::health::v1::HealthCheckRequest* request, + // ::grpc::health::v1::HealthCheckResponse* response) override { + // auto* reactor = context->DefaultReactor(); + + // // (Optionally) Check client metadata for restricted access. + // if (!restricted_kv_.first.empty()) { + // const auto& metadata = context->client_metadata(); + // auto it = metadata.find(restricted_kv_.first); + // if (it == metadata.end() || it->second != restricted_kv_.second) { + // reactor->Finish(::grpc::Status(::grpc::StatusCode::UNAVAILABLE, + // "Missing or mismatched restricted header")); + // return reactor; + // } + // } + + // // Business logic for HealthCheck. + // bool live = false; + // TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), + // &live); + + // auto serving_status = + // ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; if (err == + // nullptr) { + // serving_status = live + // ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING + // : + // ::grpc::health::v1::HealthCheckResponse_ServingStatus_NOT_SERVING; + // } + // response->set_status(serving_status); + + // ::grpc::Status status; + // GrpcStatusUtil::Create(&status, err); + // TRITONSERVER_ErrorDelete(err); + // reactor->Finish(status); + // return reactor; + // } + + ::grpc::ServerUnaryReactor* ModelReady( + ::grpc::CallbackServerContext* context, + const inference::ModelReadyRequest* request, + inference::ModelReadyResponse* response) override + { + auto* reactor = context->DefaultReactor(); - auto serving_status = - ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; - if (err == nullptr) { - serving_status = - live ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING - : ::grpc::health::v1:: - HealthCheckResponse_ServingStatus_NOT_SERVING; + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } } - response->set_status(serving_status); - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - non_inference_callback_service_->Check( - new CommonCallbackData< - ::grpc::health::v1::HealthCheckRequest, - ::grpc::health::v1::HealthCheckResponse>( - "Check", non_inference_callback_service_, callback, restricted_kv)); -} - -void -CommonHandler::RegisterModelReady() -{ - auto callback = [this]( - ::grpc::health::v1::HealthCheckRequest& request, - ::grpc::health::v1::HealthCheckResponse* response, - ::grpc::Status* status) { + // Business logic for ModelReady. bool is_ready = false; int64_t requested_model_version; - auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + TRITONSERVER_Error* err = + GetModelVersionFromString(request->version(), &requested_model_version); if (err == nullptr) { err = TRITONSERVER_ServerModelIsReady( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, &is_ready); } response->set_ready(is_ready); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - non_inference_callback_service_->ModelReady( - new CommonCallbackData< - ::grpc::health::v1::HealthCheckRequest, - ::grpc::health::v1::HealthCheckResponse>( - "ModelReady", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterServerMetadata() -{ - // Define a lambda function 'callback' that takes a ServerMetadataRequest, - // a ServerMetadataResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteServerMetadata function. - auto callback = [this]( - inference::ServerMetadataRequest& request, - inference::ServerMetadataResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ServerMetadata( + ::grpc::CallbackServerContext* context, + const inference::ServerMetadataRequest* request, + inference::ServerMetadataResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Business logic for ServerMetadata. TRITONSERVER_Message* server_metadata_message = nullptr; TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( tritonserver_.get(), &server_metadata_message); @@ -680,271 +321,271 @@ CommonHandler::RegisterServerMetadata() TRITONSERVER_MessageDelete(server_metadata_message); } - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::METADATA); - - // Use non_inference_callback_service_->ServerMetadata to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->ServerMetadata( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ServerMetadataRequest, inference::ServerMetadataResponse>( - "ServerMetadata", non_inference_callback_service_, callback, - restricted_kv)); -} - -void -CommonHandler::RegisterModelMetadata() -{ - // Define a lambda function 'callback' that takes a ModelMetadataRequest, - // a ModelMetadataResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteModelMetadata function. - auto callback = [this]( - inference::ModelMetadataRequest& request, - inference::ModelMetadataResponse* response, - ::grpc::Status* status) { - int64_t requested_model_version; - auto err = - GetModelVersionFromString(request.version(), &requested_model_version); - GOTO_IF_ERR(err, earlyexit); + reactor->Finish(status); + return reactor; + } - TRITONSERVER_Message* model_metadata_message = nullptr; - err = TRITONSERVER_ServerModelMetadata( - tritonserver_.get(), request.name().c_str(), requested_model_version, - &model_metadata_message); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* ModelMetadata( + ::grpc::CallbackServerContext* context, + const inference::ModelMetadataRequest* request, + inference::ModelMetadataResponse* response) override + { + auto* reactor = context->DefaultReactor(); - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_metadata_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - triton::common::TritonJson::Value model_metadata_json; - err = model_metadata_json.Parse(buffer, byte_size); + // Core business logic - kept same as original + int64_t requested_model_version; + auto err = + GetModelVersionFromString(request->version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); + { + TRITONSERVER_Message* model_metadata_message = nullptr; + err = TRITONSERVER_ServerModelMetadata( + tritonserver_.get(), request->name().c_str(), requested_model_version, + &model_metadata_message); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = model_metadata_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_metadata_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - response->set_name(std::string(name, namelen)); + triton::common::TritonJson::Value model_metadata_json; + err = model_metadata_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - if (model_metadata_json.Find("versions")) { - triton::common::TritonJson::Value versions_json; - err = model_metadata_json.MemberAsArray("versions", &versions_json); + const char* name; + size_t namelen; + err = model_metadata_json.MemberAsString("name", &name, &namelen); GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < versions_json.ArraySize(); ++idx) { - const char* version; - size_t versionlen; - err = versions_json.IndexAsString(idx, &version, &versionlen); + response->set_name(std::string(name, namelen)); + + if (model_metadata_json.Find("versions")) { + triton::common::TritonJson::Value versions_json; + err = model_metadata_json.MemberAsArray("versions", &versions_json); GOTO_IF_ERR(err, earlyexit); - response->add_versions(std::string(version, versionlen)); - } - } - const char* platform; - size_t platformlen; - err = - model_metadata_json.MemberAsString("platform", &platform, &platformlen); - GOTO_IF_ERR(err, earlyexit); - response->set_platform(std::string(platform, platformlen)); + for (size_t idx = 0; idx < versions_json.ArraySize(); ++idx) { + const char* version; + size_t versionlen; + err = versions_json.IndexAsString(idx, &version, &versionlen); + GOTO_IF_ERR(err, earlyexit); + response->add_versions(std::string(version, versionlen)); + } + } - if (model_metadata_json.Find("inputs")) { - triton::common::TritonJson::Value inputs_json; - err = model_metadata_json.MemberAsArray("inputs", &inputs_json); + const char* platform; + size_t platformlen; + err = model_metadata_json.MemberAsString( + "platform", &platform, &platformlen); GOTO_IF_ERR(err, earlyexit); + response->set_platform(std::string(platform, platformlen)); - for (size_t idx = 0; idx < inputs_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value io_json; - err = inputs_json.IndexAsObject(idx, &io_json); + if (model_metadata_json.Find("inputs")) { + triton::common::TritonJson::Value inputs_json; + err = model_metadata_json.MemberAsArray("inputs", &inputs_json); GOTO_IF_ERR(err, earlyexit); - inference::ModelMetadataResponse::TensorMetadata* io = - response->add_inputs(); - - const char* name; - size_t namelen; - err = io_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + for (size_t idx = 0; idx < inputs_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value io_json; + err = inputs_json.IndexAsObject(idx, &io_json); + GOTO_IF_ERR(err, earlyexit); - const char* datatype; - size_t datatypelen; - err = io_json.MemberAsString("datatype", &datatype, &datatypelen); - GOTO_IF_ERR(err, earlyexit); + inference::ModelMetadataResponse::TensorMetadata* io = + response->add_inputs(); - io->set_name(std::string(name, namelen)); - io->set_datatype(std::string(datatype, datatypelen)); + const char* name; + size_t namelen; + err = io_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - if (io_json.Find("shape")) { - triton::common::TritonJson::Value shape_json; - err = io_json.MemberAsArray("shape", &shape_json); + const char* datatype; + size_t datatypelen; + err = io_json.MemberAsString("datatype", &datatype, &datatypelen); GOTO_IF_ERR(err, earlyexit); - for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { - int64_t d; - err = shape_json.IndexAsInt(sidx, &d); + io->set_name(std::string(name, namelen)); + io->set_datatype(std::string(datatype, datatypelen)); + + if (io_json.Find("shape")) { + triton::common::TritonJson::Value shape_json; + err = io_json.MemberAsArray("shape", &shape_json); GOTO_IF_ERR(err, earlyexit); - io->add_shape(d); + + for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { + int64_t d; + err = shape_json.IndexAsInt(sidx, &d); + GOTO_IF_ERR(err, earlyexit); + io->add_shape(d); + } } } } - } - - if (model_metadata_json.Find("outputs")) { - triton::common::TritonJson::Value outputs_json; - err = model_metadata_json.MemberAsArray("outputs", &outputs_json); - GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < outputs_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value io_json; - err = outputs_json.IndexAsObject(idx, &io_json); + if (model_metadata_json.Find("outputs")) { + triton::common::TritonJson::Value outputs_json; + err = model_metadata_json.MemberAsArray("outputs", &outputs_json); GOTO_IF_ERR(err, earlyexit); - inference::ModelMetadataResponse::TensorMetadata* io = - response->add_outputs(); - - const char* name; - size_t namelen; - err = io_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + for (size_t idx = 0; idx < outputs_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value io_json; + err = outputs_json.IndexAsObject(idx, &io_json); + GOTO_IF_ERR(err, earlyexit); - const char* datatype; - size_t datatypelen; - err = io_json.MemberAsString("datatype", &datatype, &datatypelen); - GOTO_IF_ERR(err, earlyexit); + inference::ModelMetadataResponse::TensorMetadata* io = + response->add_outputs(); - io->set_name(std::string(name, namelen)); - io->set_datatype(std::string(datatype, datatypelen)); + const char* name; + size_t namelen; + err = io_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - if (io_json.Find("shape")) { - triton::common::TritonJson::Value shape_json; - err = io_json.MemberAsArray("shape", &shape_json); + const char* datatype; + size_t datatypelen; + err = io_json.MemberAsString("datatype", &datatype, &datatypelen); GOTO_IF_ERR(err, earlyexit); - for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { - int64_t d; - err = shape_json.IndexAsInt(sidx, &d); + io->set_name(std::string(name, namelen)); + io->set_datatype(std::string(datatype, datatypelen)); + + if (io_json.Find("shape")) { + triton::common::TritonJson::Value shape_json; + err = io_json.MemberAsArray("shape", &shape_json); GOTO_IF_ERR(err, earlyexit); - io->add_shape(d); + + for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { + int64_t d; + err = shape_json.IndexAsInt(sidx, &d); + GOTO_IF_ERR(err, earlyexit); + io->add_shape(d); + } } } } + TRITONSERVER_MessageDelete(model_metadata_message); } - earlyexit: - TRITONSERVER_MessageDelete(model_metadata_message); - GrpcStatusUtil::Create(status, err); + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::METADATA); - - // Use non_inference_callback_service_->ModelMetadata to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->ModelMetadata( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ModelMetadataRequest, inference::ModelMetadataResponse>( - "ModelMetadata", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterModelConfig() -{ - // Define a lambda function 'callback' that takes a ModelConfigRequest, - // a ModelConfigResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteModelConfig function. - auto callback = [this]( - inference::ModelConfigRequest& request, - inference::ModelConfigResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ModelConfig( + ::grpc::CallbackServerContext* context, + const inference::ModelConfigRequest* request, + inference::ModelConfigResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); - GOTO_IF_ERR(err, earlyexit); - - TRITONSERVER_Message* model_config_message = nullptr; - err = TRITONSERVER_ServerModelConfig( - tritonserver_.get(), request.name().c_str(), requested_model_version, - 1 /* config_version */, &model_config_message); - GOTO_IF_ERR(err, earlyexit); + GetModelVersionFromString(request->version(), &requested_model_version); + if (err == nullptr) { + TRITONSERVER_Message* model_config_message = nullptr; + err = TRITONSERVER_ServerModelConfig( + tritonserver_.get(), request->name().c_str(), requested_model_version, + 1 /* config_version */, &model_config_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_config_message, &buffer, &byte_size); + if (err == nullptr) { + ::google::protobuf::util::JsonStringToMessage( + ::google::protobuf::stringpiece_internal::StringPiece( + buffer, (int)byte_size), + response->mutable_config()); + } + TRITONSERVER_MessageDelete(model_config_message); + } + } - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_config_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - ::google::protobuf::util::JsonStringToMessage( - ::google::protobuf::stringpiece_internal::StringPiece( - buffer, static_cast(byte_size)), - response->mutable_config()); + // Other RPC methods (e.g., ServerReady, HealthCheck) would be implemented + // similarly. + ::grpc::ServerUnaryReactor* ModelStatistics( + ::grpc::CallbackServerContext* context, + const inference::ModelStatisticsRequest* request, + inference::ModelStatisticsResponse* response) override + { + auto* reactor = context->DefaultReactor(); - earlyexit: - TRITONSERVER_MessageDelete(model_config_message); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); - - // Use non_inference_callback_service_->ModelConfig to register the callback. - // This replaces the use of CommonCallData. - non_inference_callback_service_->ModelConfig( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ModelConfigRequest, inference::ModelConfigResponse>( - "ModelConfig", non_inference_callback_service_, callback, - restricted_kv)); -} + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterModelStatistics() -{ - // Define a lambda function 'callback' that takes a ModelStatisticsRequest, - // a ModelStatisticsResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteModelStatistics function. - auto callback = [this]( - inference::ModelStatisticsRequest& request, - inference::ModelStatisticsResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original #ifdef TRITON_ENABLE_STATS triton::common::TritonJson::Value model_stats_json; int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); - GOTO_IF_ERR(err, earlyexit); - - TRITONSERVER_Message* model_stats_message = nullptr; - err = TRITONSERVER_ServerModelStatistics( - tritonserver_.get(), request.name().c_str(), requested_model_version, - &model_stats_message); + GetModelVersionFromString(request->version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); + { + TRITONSERVER_Message* model_stats_message = nullptr; + err = TRITONSERVER_ServerModelStatistics( + tritonserver_.get(), request->name().c_str(), requested_model_version, + &model_stats_message); + GOTO_IF_ERR(err, earlyexit); - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_stats_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_stats_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - err = model_stats_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + err = model_stats_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - TRITONSERVER_MessageDelete(model_stats_message); + TRITONSERVER_MessageDelete(model_stats_message); + } if (model_stats_json.Find("model_stats")) { triton::common::TritonJson::Value stats_json; @@ -1129,62 +770,42 @@ CommonHandler::RegisterModelStatistics() } earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support model statistics"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::STATISTICS); - - // Use non_inference_callback_service_->ModelStatistics to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->ModelStatistics( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ModelStatisticsRequest, - inference::ModelStatisticsResponse>( - "ModelStatistics", non_inference_callback_service_, callback, - restricted_kv)); -} -template -TRITONSERVER_Error* -CommonHandler::SetStatisticsDuration( - triton::common::TritonJson::Value& statistics_json, - const std::string& statistics_name, - PBTYPE* mutable_statistics_duration_protobuf) const -{ - triton::common::TritonJson::Value statistics_duration_json; - RETURN_IF_ERR(statistics_json.MemberAsObject( - statistics_name.c_str(), &statistics_duration_json)); + reactor->Finish(status); + return reactor; + } - uint64_t value; - RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("count", &value)); - mutable_statistics_duration_protobuf->set_count(value); - RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("ns", &value)); - mutable_statistics_duration_protobuf->set_ns(value); + ::grpc::ServerUnaryReactor* TraceSetting( + ::grpc::CallbackServerContext* context, + const inference::TraceSettingRequest* request, + inference::TraceSettingResponse* response) override + { + auto* reactor = context->DefaultReactor(); - return nullptr; -} + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterTrace() -{ - // Define a lambda function 'callback' that takes a TraceSettingRequest, - // a TraceSettingResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteTrace function. - auto callback = [this]( - inference::TraceSettingRequest& request, - inference::TraceSettingResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original #ifdef TRITON_ENABLE_TRACING TRITONSERVER_Error* err = nullptr; TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED; @@ -1195,29 +816,29 @@ CommonHandler::RegisterTrace() InferenceTraceMode trace_mode; TraceConfigMap config_map; - if (!request.model_name().empty()) { + if (!request->model_name().empty()) { bool ready = false; - GOTO_IF_ERR( - TRITONSERVER_ServerModelIsReady( - tritonserver_.get(), request.model_name().c_str(), - -1 /* model version */, &ready), - earlyexit); + err = TRITONSERVER_ServerModelIsReady( + tritonserver_.get(), request->model_name().c_str(), + -1 /* model version */, &ready); + GOTO_IF_ERR(err, earlyexit); if (!ready) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, - (std::string("Request for unknown model : ") + request.model_name()) + (std::string("Request for unknown model : ") + + request->model_name()) .c_str()); GOTO_IF_ERR(err, earlyexit); } } // Update trace setting - if (!request.settings().empty()) { + if (!request->settings().empty()) { TraceManager::NewSetting new_setting; { static std::string setting_name = "trace_file"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "trace file location can not be updated through network " @@ -1227,8 +848,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_level"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_level_ = true; } else { @@ -1258,8 +879,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_rate"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_rate_ = true; } else if (it->second.value().size() == 1) { @@ -1298,8 +919,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_count"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_count_ = true; } else if (it->second.value().size() == 1) { @@ -1347,8 +968,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "log_frequency"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_log_frequency_ = true; } else if (it->second.value().size() == 1) { @@ -1386,16 +1007,16 @@ CommonHandler::RegisterTrace() } } - err = - trace_manager_->UpdateTraceSetting(request.model_name(), new_setting); + err = trace_manager_->UpdateTraceSetting( + request->model_name(), new_setting); GOTO_IF_ERR(err, earlyexit); } - // Get current trace setting, this is needed even if the setting - // has been updated above as some values may not be provided in the request. + // Get current trace setting trace_manager_->GetTraceSetting( - request.model_name(), &level, &rate, &count, &log_frequency, &filepath, + request->model_name(), &level, &rate, &count, &log_frequency, &filepath, &trace_mode, &config_map); + // level { inference::TraceSettingResponse::SettingValue level_setting; @@ -1411,6 +1032,7 @@ CommonHandler::RegisterTrace() } (*response->mutable_settings())["trace_level"] = level_setting; } + (*response->mutable_settings())["trace_rate"].add_value( std::to_string(rate)); (*response->mutable_settings())["trace_count"].add_value( @@ -1442,52 +1064,53 @@ CommonHandler::RegisterTrace() } } } + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support trace"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::TRACE); - - // Use non_inference_callback_service_->TraceSetting to register the callback. - // This replaces the use of CommonCallData. - non_inference_callback_service_->TraceSetting( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::TraceSettingRequest, inference::TraceSettingResponse>( - "TraceSetting", non_inference_callback_service_, callback, - restricted_kv)); -} -void -CommonHandler::RegisterLogging() -{ - // Define a lambda function 'callback' that takes a LogSettingsRequest, - // a LogSettingsResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteLogging function. - auto callback = [this]( - inference::LogSettingsRequest& request, - inference::LogSettingsResponse* response, - ::grpc::Status* status) { + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* LogSettings( + ::grpc::CallbackServerContext* context, + const inference::LogSettingsRequest* request, + inference::LogSettingsResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + // Core business logic - kept same as original #ifdef TRITON_ENABLE_LOGGING TRITONSERVER_Error* err = nullptr; // Update log settings // Server and Core repos do not have the same Logger object // Each update must be applied to both server and core repo versions - if (!request.settings().empty()) { + if (!request->settings().empty()) { { static std::string setting_name = "log_file"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "log file location can not be updated through network protocol"); @@ -1496,8 +1119,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_info"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1516,8 +1139,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_warning"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1536,8 +1159,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_error"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1556,8 +1179,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_verbose_level"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1576,8 +1199,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_format"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1616,6 +1239,7 @@ CommonHandler::RegisterLogging() } GOTO_IF_ERR(err, earlyexit); } + (*response->mutable_settings())["log_file"].set_string_param(LOG_FILE); (*response->mutable_settings())["log_info"].set_bool_param(LOG_INFO_IS_ON); (*response->mutable_settings())["log_warning"].set_bool_param( @@ -1626,47 +1250,79 @@ CommonHandler::RegisterLogging() LOG_VERBOSE_LEVEL); (*response->mutable_settings())["log_format"].set_string_param( LOG_FORMAT_STRING); + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support dynamic logging"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::LOGGING); - - // Use non_inference_callback_service_->LogSettings to register the callback. - // This replaces the use of CommonCallData. - non_inference_callback_service_->LogSettings( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::LogSettingsRequest, inference::LogSettingsResponse>( - "LogSettings", non_inference_callback_service_, callback, - restricted_kv)); -} -void -CommonHandler::RegisterSystemSharedMemoryStatus() -{ - // Define a lambda function 'callback' that takes a - // SystemSharedMemoryStatusRequest, a SystemSharedMemoryStatusResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteSystemSharedMemoryStatus function. - auto callback = [this]( - inference::SystemSharedMemoryStatusRequest& request, - inference::SystemSharedMemoryStatusResponse* response, - ::grpc::Status* status) { + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* SystemSharedMemoryRegister( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryRegisterRequest* request, + inference::SystemSharedMemoryRegisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic - kept same as original + TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( + request->name(), request->key(), request->offset(), + request->byte_size()); + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* SystemSharedMemoryStatus( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryStatusRequest* request, + inference::SystemSharedMemoryStatusResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic - kept same as original triton::common::TritonJson::Value shm_status_json( triton::common::TritonJson::ValueType::ARRAY); TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); + request->name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); GOTO_IF_ERR(err, earlyexit); for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { @@ -1702,116 +1358,80 @@ CommonHandler::RegisterSystemSharedMemoryStatus() } earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->SystemSharedMemoryStatus to register - // the callback. This replaces the use of CommonCallData. - non_inference_callback_service_->SystemSharedMemoryStatus( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::SystemSharedMemoryStatusRequest, - inference::SystemSharedMemoryStatusResponse>( - "SystemSharedMemoryStatus", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterSystemSharedMemoryRegister() -{ - // Define a lambda function 'callback' that takes a - // SystemSharedMemoryRegisterRequest, a SystemSharedMemoryRegisterResponse, - // and a grpc::Status. This function performs the same logic as the original - // OnExecuteSystemSharedMemoryRegister function. - auto callback = [this]( - inference::SystemSharedMemoryRegisterRequest& request, - inference::SystemSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( - request.name(), request.key(), request.offset(), request.byte_size()); + ::grpc::ServerUnaryReactor* CudaSharedMemoryRegister( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryRegisterRequest* request, + inference::CudaSharedMemoryRegisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; +#ifdef TRITON_ENABLE_GPU + err = shm_manager_->RegisterCUDASharedMemory( + request->name(), + reinterpret_cast( + request->raw_handle().c_str()), + request->byte_size(), request->device_id()); +#else + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to register CUDA shared memory region: '" + + request->name() + "', GPUs not supported") + .c_str()); +#endif // TRITON_ENABLE_GPU - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->SystemSharedMemoryRegister to register - // the callback. This replaces the use of CommonCallData. - non_inference_callback_service_->SystemSharedMemoryRegister( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::SystemSharedMemoryRegisterRequest, - inference::SystemSharedMemoryRegisterResponse>( - "SystemSharedMemoryRegister", non_inference_callback_service_, - callback, restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterSystemSharedMemoryUnregister() -{ - auto OnRegisterSystemSharedMemoryUnregister = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryUnregisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryUnregisterResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryUnregister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryUnregister = - [this]( - inference::SystemSharedMemoryUnregisterRequest& request, - inference::SystemSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); - } else { - err = - shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_CPU); - } + ::grpc::ServerUnaryReactor* CudaSharedMemoryStatus( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryStatusRequest* request, + inference::CudaSharedMemoryStatusResponse* response) override + { + auto* reactor = context->DefaultReactor(); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryUnregisterResponse>, - inference::SystemSharedMemoryUnregisterRequest, - inference::SystemSharedMemoryUnregisterResponse>( - "SystemSharedMemoryUnregister", 0, OnRegisterSystemSharedMemoryUnregister, - OnExecuteSystemSharedMemoryUnregister, false /* async */, cq_, - restricted_kv, response_delay_); -} + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterCudaSharedMemoryStatus() -{ - // Define a lambda function 'callback' that takes a - // CudaSharedMemoryStatusRequest, a CudaSharedMemoryStatusResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteCudaSharedMemoryStatus function. - auto callback = [this]( - inference::CudaSharedMemoryStatusRequest& request, - inference::CudaSharedMemoryStatusResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original triton::common::TritonJson::Value shm_status_json( triton::common::TritonJson::ValueType::ARRAY); TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); + request->name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); GOTO_IF_ERR(err, earlyexit); for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { @@ -1839,380 +1459,163 @@ CommonHandler::RegisterCudaSharedMemoryStatus() (*response->mutable_regions())[name] = region_status; } + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->CudaSharedMemoryStatus to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->CudaSharedMemoryStatus( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::CudaSharedMemoryStatusRequest, - inference::CudaSharedMemoryStatusResponse>( - "CudaSharedMemoryStatus", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterCudaSharedMemoryRegister() -{ - // Define a lambda function 'callback' that takes a - // CudaSharedMemoryRegisterRequest, a CudaSharedMemoryRegisterResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteCudaSharedMemoryRegister function. - auto callback = [this]( - inference::CudaSharedMemoryRegisterRequest& request, - inference::CudaSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; -#ifdef TRITON_ENABLE_GPU - err = shm_manager_->RegisterCUDASharedMemory( - request.name(), - reinterpret_cast( - request.raw_handle().c_str()), - request.byte_size(), request.device_id()); -#else - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "failed to register CUDA shared memory region: '" + request.name() + - "', GPUs not supported") - .c_str()); -#endif // TRITON_ENABLE_GPU + ::grpc::ServerUnaryReactor* SystemSharedMemoryUnregister( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryUnregisterRequest* request, + inference::SystemSharedMemoryUnregisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->CudaSharedMemoryRegister to register - // the callback. This replaces the use of CommonCallData. - non_inference_callback_service_->CudaSharedMemoryRegister( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::CudaSharedMemoryRegisterRequest, - inference::CudaSharedMemoryRegisterResponse>( - "CudaSharedMemoryRegister", non_inference_callback_service_, callback, - restricted_kv)); -} + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterCudaSharedMemoryUnregister() -{ - // Define a lambda function 'callback' that takes a - // CudaSharedMemoryUnregisterRequest, a CudaSharedMemoryUnregisterResponse, - // and a grpc::Status. This function performs the same logic as the original - // OnExecuteCudaSharedMemoryUnregister function. - auto callback = [this]( - inference::CudaSharedMemoryUnregisterRequest& request, - inference::CudaSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); + if (request->name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); } else { - err = shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_GPU); + err = shm_manager_->Unregister(request->name(), TRITONSERVER_MEMORY_CPU); } - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->CudaSharedMemoryUnregister to register - // the callback. This replaces the use of CommonCallData. - non_inference_callback_service_->CudaSharedMemoryUnregister( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::CudaSharedMemoryUnregisterRequest, - inference::CudaSharedMemoryUnregisterResponse>( - "CudaSharedMemoryUnregister", non_inference_callback_service_, - callback, restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterRepositoryIndex() -{ - // Define a lambda function 'callback' that takes a RepositoryIndexRequest, - // a RepositoryIndexResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteRepositoryIndex function. - auto callback = [this]( - inference::RepositoryIndexRequest& request, - inference::RepositoryIndexResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - uint32_t flags = 0; - if (request.ready()) { - flags |= TRITONSERVER_INDEX_FLAG_READY; - } + // Add here + ::grpc::ServerUnaryReactor* CudaSharedMemoryUnregister( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryUnregisterRequest* request, + inference::CudaSharedMemoryUnregisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); - TRITONSERVER_Message* model_index_message = nullptr; - err = TRITONSERVER_ServerModelIndex( - tritonserver_.get(), flags, &model_index_message); - GOTO_IF_ERR(err, earlyexit); + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_index_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + // Core business logic - kept same as original + TRITONSERVER_Error* err = nullptr; + if (request->name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); + } else { + err = shm_manager_->Unregister(request->name(), TRITONSERVER_MEMORY_GPU); + } - triton::common::TritonJson::Value model_index_json; - err = model_index_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - err = model_index_json.AssertType( - triton::common::TritonJson::ValueType::ARRAY); - GOTO_IF_ERR(err, earlyexit); + private: + std::shared_ptr tritonserver_; + std::shared_ptr shm_manager_; + std::pair restricted_kv_; +}; - for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value index_json; - err = model_index_json.IndexAsObject(idx, &index_json); - GOTO_IF_ERR(err, earlyexit); +// +// CommonHandler +// +// A common handler for all non-inference requests. +// +class CommonHandler : public HandlerBase { + public: + CommonHandler( + const std::string& name, + const std::shared_ptr& tritonserver, + const std::shared_ptr& shm_manager, + TraceManager* trace_manager, + inference::GRPCInferenceService::AsyncService* service, + ::grpc::health::v1::Health::AsyncService* health_service, + inference::GRPCInferenceService::CallbackService* + non_inference_callback_service, + const RestrictedFeatures& restricted_keys, const uint64_t response_delay); - auto model_index = response->add_models(); + // Implement pure virtual functions + void Start() override {} // No-op for callback implementation + void Stop() override {} // No-op for callback implementation - const char* name; - size_t namelen; - err = index_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_name(std::string(name, namelen)); + // Descriptive name of of the handler. + const std::string& Name() const { return name_; } - if (index_json.Find("version")) { - const char* version; - size_t versionlen; - err = index_json.MemberAsString("version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_version(std::string(version, versionlen)); - } - if (index_json.Find("state")) { - const char* state; - size_t statelen; - err = index_json.MemberAsString("state", &state, &statelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_state(std::string(state, statelen)); - } - if (index_json.Find("reason")) { - const char* reason; - size_t reasonlen; - err = index_json.MemberAsString("reason", &reason, &reasonlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_reason(std::string(reason, reasonlen)); - } - } + void CreateUnifiedCallbackService(); - TRITONSERVER_MessageDelete(model_index_message); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); - } + // Add a new public method to return the non_inference_callback_service_ + inference::GRPCInferenceService::CallbackService* GetUnifiedCallbackService() + { + return non_inference_callback_service_; + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - - // Use non_inference_callback_service_->RepositoryIndex to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->RepositoryIndex( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::RepositoryIndexRequest, - inference::RepositoryIndexResponse>( - "RepositoryIndex", non_inference_callback_service_, callback, - restricted_kv)); -} + private: + const std::string name_; + std::shared_ptr tritonserver_; + std::shared_ptr shm_manager_; + TraceManager* trace_manager_; + inference::GRPCInferenceService::AsyncService* service_; + ::grpc::health::v1::Health::AsyncService* health_service_; + inference::GRPCInferenceService::CallbackService* + non_inference_callback_service_; + std::unique_ptr thread_; + RestrictedFeatures restricted_keys_; + const uint64_t response_delay_; +}; -void -CommonHandler::RegisterRepositoryModelLoad() +CommonHandler::CommonHandler( + const std::string& name, + const std::shared_ptr& tritonserver, + const std::shared_ptr& shm_manager, + TraceManager* trace_manager, + inference::GRPCInferenceService::AsyncService* service, + ::grpc::health::v1::Health::AsyncService* health_service, + inference::GRPCInferenceService::CallbackService* + non_inference_callback_service, + const RestrictedFeatures& restricted_keys, const uint64_t response_delay) + : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), + trace_manager_(trace_manager), service_(service), + health_service_(health_service), + non_inference_callback_service_(non_inference_callback_service), + restricted_keys_(restricted_keys), response_delay_(response_delay) { - // Define a lambda function 'callback' that takes a - // RepositoryModelLoadRequest, a RepositoryModelLoadResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteRepositoryModelLoad function. - auto callback = [this]( - inference::RepositoryModelLoadRequest& request, - inference::RepositoryModelLoadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - std::vector params; - // WAR for the const-ness check - std::vector const_params; - for (const auto& param_proto : request.parameters()) { - if (param_proto.first == "config") { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kStringParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected string_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterNew( - param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, - param_proto.second.string_param().c_str()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); - break; - } - } - } else if (param_proto.first.rfind("file:", 0) == 0) { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBytesParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected bytes_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterBytesNew( - param_proto.first.c_str(), - param_proto.second.bytes_param().data(), - param_proto.second.bytes_param().length()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); - break; - } - } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("unrecognized load parameter '") + - param_proto.first + "'.") - .c_str()); - break; - } - } - if (err == nullptr) { - err = TRITONSERVER_ServerLoadModelWithParameters( - tritonserver_.get(), request.model_name().c_str(), - const_params.data(), const_params.size()); - } - // Assumes no further 'params' access after load API returns - for (auto& param : params) { - TRITONSERVER_ParameterDelete(param); - } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); - } - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - - // Use non_inference_callback_service_->RepositoryModelLoad to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->RepositoryModelLoad( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::RepositoryModelLoadRequest, - inference::RepositoryModelLoadResponse>( - "RepositoryModelLoad", non_inference_callback_service_, callback, - restricted_kv)); + CreateUnifiedCallbackService(); } void -CommonHandler::RegisterRepositoryModelUnload() +CommonHandler::CreateUnifiedCallbackService() { - // Define a lambda function 'callback' that takes a - // RepositoryModelUnloadRequest, a RepositoryModelUnloadResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteRepositoryModelUnload function. - auto callback = [this]( - inference::RepositoryModelUnloadRequest& request, - inference::RepositoryModelUnloadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - // Check if the dependent models should be removed - bool unload_dependents = false; - for (auto param : request.parameters()) { - if (param.first.compare("unload_dependents") == 0) { - const auto& unload_param = param.second; - if (unload_param.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBoolParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - "invalid value type for 'unload_dependents' parameter, " - "expected " - "bool_param."); - } - unload_dependents = unload_param.bool_param(); - break; - } - } - if (err == nullptr) { - if (unload_dependents) { - err = TRITONSERVER_ServerUnloadModelAndDependents( - tritonserver_.get(), request.model_name().c_str()); - } else { - err = TRITONSERVER_ServerUnloadModel( - tritonserver_.get(), request.model_name().c_str()); - } - } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); - } - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - - // Use non_inference_callback_service_->RepositoryModelUnload to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->RepositoryModelUnload( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::RepositoryModelUnloadRequest, - inference::RepositoryModelUnloadResponse>( - "RepositoryModelUnload", non_inference_callback_service_, callback, - restricted_kv)); + const auto& restrictedKV = restricted_keys_.Get(RestrictedCategory::HEALTH); + // Create a single unified callback service instance. + non_inference_callback_service_ = + new UnifiedCallbackService(tritonserver_, shm_manager_, restrictedKV); } } // namespace @@ -2254,9 +1657,8 @@ Server::Server( builder_.AddListeningPort(server_addr_, credentials, &bound_port_); builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); - builder_.RegisterService(&service_); - builder_.RegisterService(&health_service_); - builder_.RegisterService(&non_inference_callback_service_); + // builder_.RegisterService(&service_); + // builder_.RegisterService(&health_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); @@ -2342,7 +1744,6 @@ Server::Server( LOG_TABLE_VERBOSE(1, table_printer); } - common_cq_ = builder_.AddCompletionQueue(); model_infer_cq_ = builder_.AddCompletionQueue(); model_stream_infer_cq_ = builder_.AddCompletionQueue(); @@ -2355,8 +1756,15 @@ Server::Server( // A common Handler for other non-inference requests common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, &non_inference_callback_service_, common_cq_.get(), + &health_service_, &non_inference_callback_service_, options.restricted_protocols_, response_delay)); + // Use common_handler_ and register + // builder_.RegisterService(non_inference_callback_service_); here Cast to + // CommonHandler to access the method + auto* handler = dynamic_cast(common_handler_.get()); + if (handler != nullptr) { + builder_.RegisterService(handler->GetUnifiedCallbackService()); + } // [FIXME] "register" logic is different for infer // Handler for model inference requests. @@ -2519,7 +1927,6 @@ Server::Start() } // Remove this - common_handler_->Start(); for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Start(); } @@ -2543,13 +1950,13 @@ Server::Stop() // Always shutdown the completion queue after the server. server_->Shutdown(); - common_cq_->Shutdown(); + // common_cq_->Shutdown(); model_infer_cq_->Shutdown(); model_stream_infer_cq_->Shutdown(); // Must stop all handlers explicitly to wait for all the handler // threads to join since they are referencing completion queue, etc. - common_handler_->Stop(); + // common_handler_->Stop(); for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Stop(); } diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index 2a7a5ff0ba..89203d5d0b 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -148,7 +148,7 @@ class Server { std::unique_ptr<::grpc::ServerCompletionQueue> model_infer_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_stream_infer_cq_; - // std::unique_ptr common_handler_; + std::unique_ptr common_handler_; std::vector> model_infer_handlers_; std::vector> model_stream_infer_handlers_; From c3c7c90f6c1ab710bdea78f6ccda2dbc8e120db0 Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Wed, 26 Feb 2025 10:04:49 -0800 Subject: [PATCH 4/8] Non Infrence Migrated --- src/grpc/CMakeLists.txt | 3 ++- src/grpc/grpc_handler.h | 3 ++- src/grpc/grpc_server.cc | 13 +++++++------ src/grpc/grpc_server.h | 3 ++- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/grpc/CMakeLists.txt b/src/grpc/CMakeLists.txt index 0cd027a30a..1b0544c37c 100644 --- a/src/grpc/CMakeLists.txt +++ b/src/grpc/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -67,6 +67,7 @@ target_link_libraries( triton-common-json # from repo-common grpc-health-library # from repo-common grpc-service-library # from repo-common + grpccallback-service-library # from repo-common triton-core-serverapi # from repo-core triton-core-serverstub # from repo-core gRPC::grpc++ diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h index ce9a30667e..ad5f551b70 100644 --- a/src/grpc/grpc_handler.h +++ b/src/grpc/grpc_handler.h @@ -30,6 +30,7 @@ #include #include "grpc_service.grpc.pb.h" +#include "grpccallback_service.grpc.pb.h" namespace triton { namespace server { namespace grpc { class HandlerBase { @@ -37,7 +38,7 @@ class HandlerBase { virtual ~HandlerBase() = default; virtual void Start() = 0; virtual void Stop() = 0; - virtual inference::GRPCInferenceService::CallbackService* + virtual inference::GRPCInferenceServiceCallback::CallbackService* GetUnifiedCallbackService() { return nullptr; diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index ab644e3a07..1922801ac6 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -82,7 +82,7 @@ namespace { // Define your unified callback service that implements several non-inference // RPCs. class UnifiedCallbackService - : public inference::GRPCInferenceService::CallbackService { + : public inference::GRPCInferenceServiceCallback::CallbackService { public: UnifiedCallbackService( const std::shared_ptr& server, @@ -1557,7 +1557,7 @@ class CommonHandler : public HandlerBase { TraceManager* trace_manager, inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, - inference::GRPCInferenceService::CallbackService* + inference::GRPCInferenceServiceCallback::CallbackService* non_inference_callback_service, const RestrictedFeatures& restricted_keys, const uint64_t response_delay); @@ -1571,7 +1571,8 @@ class CommonHandler : public HandlerBase { void CreateUnifiedCallbackService(); // Add a new public method to return the non_inference_callback_service_ - inference::GRPCInferenceService::CallbackService* GetUnifiedCallbackService() + inference::GRPCInferenceServiceCallback::CallbackService* + GetUnifiedCallbackService() { return non_inference_callback_service_; } @@ -1583,7 +1584,7 @@ class CommonHandler : public HandlerBase { TraceManager* trace_manager_; inference::GRPCInferenceService::AsyncService* service_; ::grpc::health::v1::Health::AsyncService* health_service_; - inference::GRPCInferenceService::CallbackService* + inference::GRPCInferenceServiceCallback::CallbackService* non_inference_callback_service_; std::unique_ptr thread_; RestrictedFeatures restricted_keys_; @@ -1597,7 +1598,7 @@ CommonHandler::CommonHandler( TraceManager* trace_manager, inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, - inference::GRPCInferenceService::CallbackService* + inference::GRPCInferenceServiceCallback::CallbackService* non_inference_callback_service, const RestrictedFeatures& restricted_keys, const uint64_t response_delay) : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), @@ -1657,7 +1658,7 @@ Server::Server( builder_.AddListeningPort(server_addr_, credentials, &bound_port_); builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); - // builder_.RegisterService(&service_); + builder_.RegisterService(&service_); // builder_.RegisterService(&health_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index 89203d5d0b..eceb7f9b85 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -36,6 +36,7 @@ #include "grpc_handler.h" #include "grpc_service.grpc.pb.h" #include "grpc_utils.h" +#include "grpccallback_service.grpc.pb.h" #include "health.grpc.pb.h" #include "infer_handler.h" #include "stream_infer_handler.h" @@ -139,7 +140,7 @@ class Server { inference::GRPCInferenceService::AsyncService service_; ::grpc::health::v1::Health::AsyncService health_service_; - inference::GRPCInferenceService::CallbackService + inference::GRPCInferenceServiceCallback::CallbackService non_inference_callback_service_; std::unique_ptr<::grpc::Server> server_; From ad4a2a281ac07e1c31039025ed3af3f6013f802a Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Mon, 10 Mar 2025 19:15:47 -0700 Subject: [PATCH 5/8] Add missing RPCs --- src/grpc/grpc_server.cc | 325 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 321 insertions(+), 4 deletions(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 1922801ac6..88039b86db 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -82,7 +82,8 @@ namespace { // Define your unified callback service that implements several non-inference // RPCs. class UnifiedCallbackService - : public inference::GRPCInferenceServiceCallback::CallbackService { + : public inference::GRPCInferenceServiceCallback::CallbackService, + public ::grpc::health::v1::Health::CallbackService { public: UnifiedCallbackService( const std::shared_ptr& server, @@ -1537,6 +1538,323 @@ class UnifiedCallbackService return reactor; } + ::grpc::ServerUnaryReactor* RepositoryIndex( + ::grpc::CallbackServerContext* context, + const inference::RepositoryIndexRequest* request, + inference::RepositoryIndexResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + uint32_t flags = 0; + if (request->ready()) { + flags |= TRITONSERVER_INDEX_FLAG_READY; + } + + TRITONSERVER_Message* model_index_message = nullptr; + err = TRITONSERVER_ServerModelIndex( + tritonserver_.get(), flags, &model_index_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_index_message, &buffer, &byte_size); + if (err == nullptr) { + triton::common::TritonJson::Value model_index_json; + err = model_index_json.Parse(buffer, byte_size); + if (err == nullptr) { + err = model_index_json.AssertType( + triton::common::TritonJson::ValueType::ARRAY); + if (err == nullptr) { + for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value index_json; + err = model_index_json.IndexAsObject(idx, &index_json); + if (err != nullptr) { + break; + } + + auto model_index = response->add_models(); + + const char* name; + size_t namelen; + err = index_json.MemberAsString("name", &name, &namelen); + if (err != nullptr) { + break; + } + model_index->set_name(std::string(name, namelen)); + + if (index_json.Find("version")) { + const char* version; + size_t versionlen; + err = index_json.MemberAsString( + "version", &version, &versionlen); + if (err != nullptr) { + break; + } + model_index->set_version(std::string(version, versionlen)); + } + if (index_json.Find("state")) { + const char* state; + size_t statelen; + err = index_json.MemberAsString("state", &state, &statelen); + if (err != nullptr) { + break; + } + model_index->set_state(std::string(state, statelen)); + } + if (index_json.Find("reason")) { + const char* reason; + size_t reasonlen; + err = + index_json.MemberAsString("reason", &reason, &reasonlen); + if (err != nullptr) { + break; + } + model_index->set_reason(std::string(reason, reasonlen)); + } + } + } + } + } + TRITONSERVER_MessageDelete(model_index_message); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* RepositoryModelLoad( + ::grpc::CallbackServerContext* context, + const inference::RepositoryModelLoadRequest* request, + inference::RepositoryModelLoadResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + std::vector params; + // WAR for the const-ness check + std::vector const_params; + + for (const auto& param_proto : request->parameters()) { + if (param_proto.first == "config") { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kStringParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected string_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterNew( + param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, + param_proto.second.string_param().c_str()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); + break; + } + } + } else if (param_proto.first.rfind("file:", 0) == 0) { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBytesParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected bytes_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterBytesNew( + param_proto.first.c_str(), + param_proto.second.bytes_param().data(), + param_proto.second.bytes_param().length()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); + break; + } + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unrecognized load parameter '") + + param_proto.first + "'.") + .c_str()); + break; + } + } + + if (err == nullptr) { + err = TRITONSERVER_ServerLoadModelWithParameters( + tritonserver_.get(), request->model_name().c_str(), + const_params.data(), const_params.size()); + } + + // Assumes no further 'params' access after load API returns + for (auto& param : params) { + TRITONSERVER_ParameterDelete(param); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* RepositoryModelUnload( + ::grpc::CallbackServerContext* context, + const inference::RepositoryModelUnloadRequest* request, + inference::RepositoryModelUnloadResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + // Check if the dependent models should be removed + bool unload_dependents = false; + for (const auto& param : request->parameters()) { + if (param.first.compare("unload_dependents") == 0) { + const auto& unload_param = param.second; + if (unload_param.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBoolParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "invalid value type for 'unload_dependents' parameter, " + "expected bool_param."); + } + unload_dependents = unload_param.bool_param(); + break; + } + } + + if (err == nullptr) { + if (unload_dependents) { + err = TRITONSERVER_ServerUnloadModelAndDependents( + tritonserver_.get(), request->model_name().c_str()); + } else { + err = TRITONSERVER_ServerUnloadModel( + tritonserver_.get(), request->model_name().c_str()); + } + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* Check( + ::grpc::CallbackServerContext* context, + const ::grpc::health::v1::HealthCheckRequest* request, + ::grpc::health::v1::HealthCheckResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Check if server is ready + bool ready = false; + TRITONSERVER_Error* err = + TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); + + // Set health status based on server readiness + if (err == nullptr && ready) { + response->set_status(::grpc::health::v1::HealthCheckResponse::SERVING); + } else { + response->set_status( + ::grpc::health::v1::HealthCheckResponse::NOT_SERVING); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + private: std::shared_ptr tritonserver_; std::shared_ptr shm_manager_; @@ -1759,11 +2077,10 @@ Server::Server( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, &health_service_, &non_inference_callback_service_, options.restricted_protocols_, response_delay)); - // Use common_handler_ and register - // builder_.RegisterService(non_inference_callback_service_); here Cast to - // CommonHandler to access the method + // Use common_handler_ and register services auto* handler = dynamic_cast(common_handler_.get()); if (handler != nullptr) { + // Register the unified service directly without casting builder_.RegisterService(handler->GetUnifiedCallbackService()); } From b896688e04a4a6fb18485784d834b19762954d5f Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Mon, 10 Mar 2025 19:50:41 -0700 Subject: [PATCH 6/8] Health API fixed --- src/grpc/grpc_handler.h | 7 +++ src/grpc/grpc_server.cc | 95 +++++++++++++++++++++++++++++++++++------ src/grpc/grpc_server.h | 1 + 3 files changed, 89 insertions(+), 14 deletions(-) diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h index ad5f551b70..405a78d737 100644 --- a/src/grpc/grpc_handler.h +++ b/src/grpc/grpc_handler.h @@ -31,6 +31,7 @@ #include "grpc_service.grpc.pb.h" #include "grpccallback_service.grpc.pb.h" +#include "health.grpc.pb.h" namespace triton { namespace server { namespace grpc { class HandlerBase { @@ -43,6 +44,12 @@ class HandlerBase { { return nullptr; } + + virtual ::grpc::health::v1::Health::CallbackService* + GetHealthCallbackService() + { + return nullptr; + } }; class ICallData { diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 88039b86db..357f63d3ac 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -79,8 +79,61 @@ namespace { // are deemed to be not performance critical. //========================================================================= -// Define your unified callback service that implements several non-inference -// RPCs. +// Define a dedicated health service that implements the health check RPC +class HealthCallbackService + : public ::grpc::health::v1::Health::CallbackService { + public: + HealthCallbackService( + const std::shared_ptr& server, + const std::pair& restrictedKV) + : tritonserver_(server), restricted_kv_(restrictedKV) + { + } + + ::grpc::ServerUnaryReactor* Check( + ::grpc::CallbackServerContext* context, + const ::grpc::health::v1::HealthCheckRequest* request, + ::grpc::health::v1::HealthCheckResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // Check restricted access if configured + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Check if server is ready + bool ready = false; + TRITONSERVER_Error* err = + TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); + + // Set health status based on server readiness + if (err == nullptr && ready) { + response->set_status(::grpc::health::v1::HealthCheckResponse::SERVING); + } else { + response->set_status( + ::grpc::health::v1::HealthCheckResponse::NOT_SERVING); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + private: + std::shared_ptr tritonserver_; + std::pair restricted_kv_; +}; + class UnifiedCallbackService : public inference::GRPCInferenceServiceCallback::CallbackService, public ::grpc::health::v1::Health::CallbackService { @@ -1886,15 +1939,20 @@ class CommonHandler : public HandlerBase { // Descriptive name of of the handler. const std::string& Name() const { return name_; } - void CreateUnifiedCallbackService(); + void CreateCallbackServices(); - // Add a new public method to return the non_inference_callback_service_ + // Add methods to return the callback services inference::GRPCInferenceServiceCallback::CallbackService* GetUnifiedCallbackService() { return non_inference_callback_service_; } + ::grpc::health::v1::Health::CallbackService* GetHealthCallbackService() + { + return health_callback_service_; + } + private: const std::string name_; std::shared_ptr tritonserver_; @@ -1904,6 +1962,7 @@ class CommonHandler : public HandlerBase { ::grpc::health::v1::Health::AsyncService* health_service_; inference::GRPCInferenceServiceCallback::CallbackService* non_inference_callback_service_; + ::grpc::health::v1::Health::CallbackService* health_callback_service_; std::unique_ptr thread_; RestrictedFeatures restricted_keys_; const uint64_t response_delay_; @@ -1923,18 +1982,26 @@ CommonHandler::CommonHandler( trace_manager_(trace_manager), service_(service), health_service_(health_service), non_inference_callback_service_(non_inference_callback_service), - restricted_keys_(restricted_keys), response_delay_(response_delay) + health_callback_service_(nullptr), restricted_keys_(restricted_keys), + response_delay_(response_delay) { - CreateUnifiedCallbackService(); + CreateCallbackServices(); } void -CommonHandler::CreateUnifiedCallbackService() +CommonHandler::CreateCallbackServices() { - const auto& restrictedKV = restricted_keys_.Get(RestrictedCategory::HEALTH); - // Create a single unified callback service instance. - non_inference_callback_service_ = - new UnifiedCallbackService(tritonserver_, shm_manager_, restrictedKV); + // Create the unified callback service for non-inference operations + const auto& inference_restrictedKV = + restricted_keys_.Get(RestrictedCategory::INFERENCE); + non_inference_callback_service_ = new UnifiedCallbackService( + tritonserver_, shm_manager_, inference_restrictedKV); + + // Create the health callback service + const auto& health_restrictedKV = + restricted_keys_.Get(RestrictedCategory::HEALTH); + health_callback_service_ = + new HealthCallbackService(tritonserver_, health_restrictedKV); } } // namespace @@ -1977,7 +2044,6 @@ Server::Server( builder_.AddListeningPort(server_addr_, credentials, &bound_port_); builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); builder_.RegisterService(&service_); - // builder_.RegisterService(&health_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); @@ -2075,13 +2141,14 @@ Server::Server( // A common Handler for other non-inference requests common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, &non_inference_callback_service_, + &health_service_, nullptr /* non_inference_callback_service */, options.restricted_protocols_, response_delay)); // Use common_handler_ and register services auto* handler = dynamic_cast(common_handler_.get()); if (handler != nullptr) { - // Register the unified service directly without casting + // Register both the unified service and health service builder_.RegisterService(handler->GetUnifiedCallbackService()); + builder_.RegisterService(handler->GetHealthCallbackService()); } // [FIXME] "register" logic is different for infer diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index eceb7f9b85..5777059fcf 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -143,6 +143,7 @@ class Server { inference::GRPCInferenceServiceCallback::CallbackService non_inference_callback_service_; + ::grpc::health::v1::Health::CallbackService* health_callback_service_; std::unique_ptr<::grpc::Server> server_; // std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; From 636b933df784d1041ed9eba1a8571ad967e858f7 Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Tue, 11 Mar 2025 11:36:07 -0700 Subject: [PATCH 7/8] Test Script for new Service Names --- test_grpc_callbacks.sh | 148 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100755 test_grpc_callbacks.sh diff --git a/test_grpc_callbacks.sh b/test_grpc_callbacks.sh new file mode 100755 index 0000000000..8058af5cdd --- /dev/null +++ b/test_grpc_callbacks.sh @@ -0,0 +1,148 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note: Before running this script, start Triton server in explicit model control mode: +# tritonserver --model-repository=/path/to/model/repository --model-control-mode=explicit + +# Default server URL +SERVER_URL=${1:-"localhost:8001"} +PROTO_PATH="/mnt/builddir/triton-server/_deps/repo-common-src/protobuf" +PROTO_FILE="${PROTO_PATH}/grpccallback_service.proto" +HEALTH_PROTO="${PROTO_PATH}/health.proto" + +# Colors for output +GREEN='\033[0;32m' +RED='\033[0;31m' +NC='\033[0m' # No Color +BOLD='\033[1m' + +# Function to print test results +print_result() { + local test_name=$1 + local result=$2 + if [ $result -eq 0 ]; then + echo -e "${test_name}: ${GREEN}PASSED${NC}" + else + echo -e "${test_name}: ${RED}FAILED${NC}" + fi +} + +echo -e "\n${BOLD}Testing gRPC Callback RPCs against ${SERVER_URL}${NC}\n" + +# Test Health Check +echo -e "\n${BOLD}Testing Health Check:${NC}" +grpcurl -proto ${HEALTH_PROTO} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + grpc.health.v1.Health/Check +print_result "Health Check" $? + +# Test Repository Index +echo -e "\n${BOLD}Testing Repository Index:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryIndex +print_result "Repository Index" $? + +# Test Model Load +echo -e "\n${BOLD}Testing Model Load:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelLoad +print_result "Model Load" $? + +# Wait for model to load +sleep 2 + +# Test Model Unload +echo -e "\n${BOLD}Testing Model Unload:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelUnload +print_result "Model Unload" $? + +# Test Server Live +echo -e "\n${BOLD}Testing Server Live:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerLive +print_result "Server Live" $? + +# Test Server Ready +echo -e "\n${BOLD}Testing Server Ready:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerReady +print_result "Server Ready" $? + +# Load model again before testing Model Ready +echo -e "\n${BOLD}Loading model for Model Ready test:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelLoad +print_result "Model Load" $? + +# Wait for model to load +sleep 2 + +# Test Model Ready +echo -e "\n${BOLD}Testing Model Ready:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ModelReady +print_result "Model Ready" $? + +# Test Server Metadata +echo -e "\n${BOLD}Testing Server Metadata:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerMetadata +print_result "Server Metadata" $? + +# Test Model Metadata +echo -e "\n${BOLD}Testing Model Metadata:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ModelMetadata +print_result "Model Metadata" $? + +echo -e "\n${BOLD}Test Summary:${NC}" +echo "----------------------------------------" \ No newline at end of file From bb626ad2e22bfb1c18c23209e8d279159dec0fe9 Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Tue, 11 Mar 2025 18:15:48 -0700 Subject: [PATCH 8/8] Fix RestrictedFeatures --- src/grpc/grpc_server.cc | 255 +++++++++++++++++----------------------- 1 file changed, 105 insertions(+), 150 deletions(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 357f63d3ac..54dc98dfc9 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -85,8 +85,8 @@ class HealthCallbackService public: HealthCallbackService( const std::shared_ptr& server, - const std::pair& restrictedKV) - : tritonserver_(server), restricted_kv_(restrictedKV) + RestrictedFeatures& restricted_keys_) + : tritonserver_(server), restricted_keys_(restricted_keys_) { } @@ -98,10 +98,12 @@ class HealthCallbackService auto* reactor = context->DefaultReactor(); // Check restricted access if configured - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -131,19 +133,18 @@ class HealthCallbackService private: std::shared_ptr tritonserver_; - std::pair restricted_kv_; + RestrictedFeatures restricted_keys_; }; class UnifiedCallbackService - : public inference::GRPCInferenceServiceCallback::CallbackService, - public ::grpc::health::v1::Health::CallbackService { + : public inference::GRPCInferenceServiceCallback::CallbackService { public: UnifiedCallbackService( const std::shared_ptr& server, const std::shared_ptr& shm_manager, - const std::pair& restrictedKV) + RestrictedFeatures& restricted_keys_) : tritonserver_(server), shm_manager_(shm_manager), - restricted_kv_(restrictedKV) + restricted_keys_(restricted_keys_) { } @@ -174,10 +175,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -206,10 +209,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -230,45 +235,6 @@ class UnifiedCallbackService return reactor; } - // ::grpc::ServerUnaryReactor* Check( - // ::grpc::CallbackServerContext* context, - // const ::grpc::health::v1::HealthCheckRequest* request, - // ::grpc::health::v1::HealthCheckResponse* response) override { - // auto* reactor = context->DefaultReactor(); - - // // (Optionally) Check client metadata for restricted access. - // if (!restricted_kv_.first.empty()) { - // const auto& metadata = context->client_metadata(); - // auto it = metadata.find(restricted_kv_.first); - // if (it == metadata.end() || it->second != restricted_kv_.second) { - // reactor->Finish(::grpc::Status(::grpc::StatusCode::UNAVAILABLE, - // "Missing or mismatched restricted header")); - // return reactor; - // } - // } - - // // Business logic for HealthCheck. - // bool live = false; - // TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), - // &live); - - // auto serving_status = - // ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; if (err == - // nullptr) { - // serving_status = live - // ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING - // : - // ::grpc::health::v1::HealthCheckResponse_ServingStatus_NOT_SERVING; - // } - // response->set_status(serving_status); - - // ::grpc::Status status; - // GrpcStatusUtil::Create(&status, err); - // TRITONSERVER_ErrorDelete(err); - // reactor->Finish(status); - // return reactor; - // } - ::grpc::ServerUnaryReactor* ModelReady( ::grpc::CallbackServerContext* context, const inference::ModelReadyRequest* request, @@ -277,10 +243,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -316,10 +284,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -390,10 +360,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -551,10 +523,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -603,10 +577,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::STATISTICS); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -848,10 +824,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::TRACE); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1143,10 +1121,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::LOGGING); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1330,10 +1310,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1361,10 +1343,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1427,10 +1411,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1470,10 +1456,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1530,10 +1518,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1565,10 +1555,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1599,10 +1591,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1706,10 +1700,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1815,10 +1811,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1869,49 +1867,10 @@ class UnifiedCallbackService return reactor; } - ::grpc::ServerUnaryReactor* Check( - ::grpc::CallbackServerContext* context, - const ::grpc::health::v1::HealthCheckRequest* request, - ::grpc::health::v1::HealthCheckResponse* response) override - { - auto* reactor = context->DefaultReactor(); - - // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { - const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { - reactor->Finish(::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - "Missing or mismatched restricted header")); - return reactor; - } - } - - // Check if server is ready - bool ready = false; - TRITONSERVER_Error* err = - TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); - - // Set health status based on server readiness - if (err == nullptr && ready) { - response->set_status(::grpc::health::v1::HealthCheckResponse::SERVING); - } else { - response->set_status( - ::grpc::health::v1::HealthCheckResponse::NOT_SERVING); - } - - ::grpc::Status status; - GrpcStatusUtil::Create(&status, err); - TRITONSERVER_ErrorDelete(err); - reactor->Finish(status); - return reactor; - } - private: std::shared_ptr tritonserver_; std::shared_ptr shm_manager_; - std::pair restricted_kv_; + RestrictedFeatures restricted_keys_; }; // @@ -1992,16 +1951,12 @@ void CommonHandler::CreateCallbackServices() { // Create the unified callback service for non-inference operations - const auto& inference_restrictedKV = - restricted_keys_.Get(RestrictedCategory::INFERENCE); - non_inference_callback_service_ = new UnifiedCallbackService( - tritonserver_, shm_manager_, inference_restrictedKV); + non_inference_callback_service_ = + new UnifiedCallbackService(tritonserver_, shm_manager_, restricted_keys_); // Create the health callback service - const auto& health_restrictedKV = - restricted_keys_.Get(RestrictedCategory::HEALTH); health_callback_service_ = - new HealthCallbackService(tritonserver_, health_restrictedKV); + new HealthCallbackService(tritonserver_, restricted_keys_); } } // namespace