diff --git a/src/grpc/CMakeLists.txt b/src/grpc/CMakeLists.txt index 0cd027a30a..1b0544c37c 100644 --- a/src/grpc/CMakeLists.txt +++ b/src/grpc/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -67,6 +67,7 @@ target_link_libraries( triton-common-json # from repo-common grpc-health-library # from repo-common grpc-service-library # from repo-common + grpccallback-service-library # from repo-common triton-core-serverapi # from repo-core triton-core-serverstub # from repo-core gRPC::grpc++ diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h index 4f1bcdfac0..405a78d737 100644 --- a/src/grpc/grpc_handler.h +++ b/src/grpc/grpc_handler.h @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -25,14 +25,31 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#pragma once +#include + #include +#include "grpc_service.grpc.pb.h" +#include "grpccallback_service.grpc.pb.h" +#include "health.grpc.pb.h" + namespace triton { namespace server { namespace grpc { class HandlerBase { public: virtual ~HandlerBase() = default; virtual void Start() = 0; virtual void Stop() = 0; + virtual inference::GRPCInferenceServiceCallback::CallbackService* + GetUnifiedCallbackService() + { + return nullptr; + } + + virtual ::grpc::health::v1::Health::CallbackService* + GetHealthCallbackService() + { + return nullptr; + } }; class ICallData { diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 5beb3aba72..7d6b71f632 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -79,622 +79,309 @@ namespace { // are deemed to be not performance critical. //========================================================================= -template -class CommonCallData : public ICallData { +// Define a dedicated health service that implements the health check RPC +class HealthCallbackService + : public ::grpc::health::v1::Health::CallbackService { public: - using StandardRegisterFunc = std::function; - using StandardCallbackFunc = - std::function; - - CommonCallData( - const std::string& name, const uint64_t id, - const StandardRegisterFunc OnRegister, - const StandardCallbackFunc OnExecute, const bool async, - ::grpc::ServerCompletionQueue* cq, - const std::pair& restricted_kv, - const uint64_t& response_delay = 0) - : name_(name), id_(id), OnRegister_(OnRegister), OnExecute_(OnExecute), - async_(async), cq_(cq), responder_(&ctx_), step_(Steps::START), - restricted_kv_(restricted_kv), response_delay_(response_delay) + HealthCallbackService( + const std::shared_ptr& server, + RestrictedFeatures& restricted_keys_) + : tritonserver_(server), restricted_keys_(restricted_keys_) { - OnRegister_(&ctx_, &request_, &responder_, this); - LOG_VERBOSE(1) << "Ready for RPC '" << name_ << "', " << id_; } - ~CommonCallData() + 
::grpc::ServerUnaryReactor* Check( + ::grpc::CallbackServerContext* context, + const ::grpc::health::v1::HealthCheckRequest* request, + ::grpc::health::v1::HealthCheckResponse* response) override { - if (async_thread_.joinable()) { - async_thread_.join(); - } - } - - bool Process(bool ok) override; - - std::string Name() override { return name_; } - - uint64_t Id() override { return id_; } - - private: - void Execute(); - void AddToCompletionQueue(); - void WriteResponse(); - bool ExecutePrecondition(); - - const std::string name_; - const uint64_t id_; - const StandardRegisterFunc OnRegister_; - const StandardCallbackFunc OnExecute_; - const bool async_; - ::grpc::ServerCompletionQueue* cq_; - - ::grpc::ServerContext ctx_; - ::grpc::Alarm alarm_; - - ResponderType responder_; - RequestType request_; - ResponseType response_; - ::grpc::Status status_; - - std::thread async_thread_; - - Steps step_; - - std::pair restricted_kv_{"", ""}; - - const uint64_t response_delay_; -}; - -template -bool -CommonCallData::Process(bool rpc_ok) -{ - LOG_VERBOSE(1) << "Process for " << name_ << ", rpc_ok=" << rpc_ok << ", " - << id_ << " step " << step_; - - // If RPC failed on a new request then the server is shutting down - // and so we should do nothing (including not registering for a new - // request). If RPC failed on a non-START step then there is nothing - // we can do since we one execute one step. 
- const bool shutdown = (!rpc_ok && (step_ == Steps::START)); - if (shutdown) { - if (async_thread_.joinable()) { - async_thread_.join(); + auto* reactor = context->DefaultReactor(); + + // Check restricted access if configured + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } } - step_ = Steps::FINISH; - } - if (step_ == Steps::START) { - // Start a new request to replace this one... - if (!shutdown) { - new CommonCallData( - name_, id_ + 1, OnRegister_, OnExecute_, async_, cq_, restricted_kv_, - response_delay_); - } + // Check if server is ready + bool ready = false; + TRITONSERVER_Error* err = + TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); - if (!async_) { - // For synchronous calls, execute and write response - // here. - Execute(); - WriteResponse(); + // Set health status based on server readiness + if (err == nullptr && ready) { + response->set_status(::grpc::health::v1::HealthCheckResponse::SERVING); } else { - // For asynchronous calls, delegate the execution to another - // thread. - step_ = Steps::ISSUED; - async_thread_ = std::thread(&CommonCallData::Execute, this); + response->set_status( + ::grpc::health::v1::HealthCheckResponse::NOT_SERVING); } - } else if (step_ == Steps::WRITEREADY) { - // Will only come here for asynchronous mode. 
- WriteResponse(); - } else if (step_ == Steps::COMPLETE) { - step_ = Steps::FINISH; - } - - return step_ != Steps::FINISH; -} - -template -void -CommonCallData::Execute() -{ - if (ExecutePrecondition()) { - OnExecute_(request_, &response_, &status_); - } else { - status_ = ::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - std::string("This protocol is restricted, expecting header '") + - restricted_kv_.first + "'"); - } - step_ = Steps::WRITEREADY; - - if (async_) { - // For asynchronous operation, need to add itself onto the completion - // queue so that the response can be written once the object is - // taken up next for execution. - AddToCompletionQueue(); - } -} -template -bool -CommonCallData::ExecutePrecondition() -{ - if (!restricted_kv_.first.empty()) { - const auto& metadata = ctx_.client_metadata(); - const auto it = metadata.find(restricted_kv_.first); - return (it != metadata.end()) && (it->second == restricted_kv_.second); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; } - return true; -} - -template -void -CommonCallData::AddToCompletionQueue() -{ - alarm_.Set(cq_, gpr_now(gpr_clock_type::GPR_CLOCK_REALTIME), this); -} -template -void -CommonCallData::WriteResponse() -{ - if (response_delay_ != 0) { - // Will delay the write of the response by the specified time. - // This can be used to test the flow where there are other - // responses available to be written. - LOG_VERBOSE(1) << "Delaying the write of the response by " - << response_delay_ << " seconds"; - std::this_thread::sleep_for(std::chrono::seconds(response_delay_)); - } - step_ = Steps::COMPLETE; - responder_.Finish(response_, status_, this); -} + private: + std::shared_ptr tritonserver_; + RestrictedFeatures restricted_keys_; +}; -// -// CommonHandler -// -// A common handler for all non-inference requests. 
-// -class CommonHandler : public HandlerBase { +class UnifiedCallbackService + : public inference::GRPCInferenceServiceCallback::CallbackService { public: - CommonHandler( - const std::string& name, - const std::shared_ptr& tritonserver, + UnifiedCallbackService( + const std::shared_ptr& server, const std::shared_ptr& shm_manager, - TraceManager* trace_manager, - inference::GRPCInferenceService::AsyncService* service, - ::grpc::health::v1::Health::AsyncService* health_service, - ::grpc::ServerCompletionQueue* cq, - const RestrictedFeatures& restricted_keys, const uint64_t response_delay); - - // Descriptive name of of the handler. - const std::string& Name() const { return name_; } - - // Start handling requests. - void Start() override; - - // Stop handling requests. - void Stop() override; + RestrictedFeatures& restricted_keys_) + : tritonserver_(server), shm_manager_(shm_manager), + restricted_keys_(restricted_keys_) + { + } - private: - void SetUpAllRequests(); - - // [FIXME] turn into generated code - void RegisterServerLive(); - void RegisterServerReady(); - void RegisterHealthCheck(); - void RegisterModelReady(); - void RegisterServerMetadata(); - void RegisterModelMetadata(); - void RegisterModelConfig(); - void RegisterModelStatistics(); - void RegisterTrace(); - void RegisterLogging(); - void RegisterSystemSharedMemoryStatus(); - void RegisterSystemSharedMemoryRegister(); - void RegisterSystemSharedMemoryUnregister(); - void RegisterCudaSharedMemoryStatus(); - void RegisterCudaSharedMemoryRegister(); - void RegisterCudaSharedMemoryUnregister(); - void RegisterRepositoryIndex(); - void RegisterRepositoryModelLoad(); - void RegisterRepositoryModelUnload(); - - // Set count and cumulative duration for 'RegisterModelStatistics()' template TRITONSERVER_Error* SetStatisticsDuration( triton::common::TritonJson::Value& statistics_json, const std::string& statistics_name, - PBTYPE* mutable_statistics_duration_protobuf) const; - - const std::string name_; - 
std::shared_ptr tritonserver_; - - std::shared_ptr shm_manager_; - TraceManager* trace_manager_; - - inference::GRPCInferenceService::AsyncService* service_; - ::grpc::health::v1::Health::AsyncService* health_service_; - ::grpc::ServerCompletionQueue* cq_; - std::unique_ptr thread_; - RestrictedFeatures restricted_keys_{}; - const uint64_t response_delay_ = 0; -}; - -CommonHandler::CommonHandler( - const std::string& name, - const std::shared_ptr& tritonserver, - const std::shared_ptr& shm_manager, - TraceManager* trace_manager, - inference::GRPCInferenceService::AsyncService* service, - ::grpc::health::v1::Health::AsyncService* health_service, - ::grpc::ServerCompletionQueue* cq, - const RestrictedFeatures& restricted_keys, - const uint64_t response_delay = 0) - : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), - trace_manager_(trace_manager), service_(service), - health_service_(health_service), cq_(cq), - restricted_keys_(restricted_keys), response_delay_(response_delay) -{ -} + PBTYPE* mutable_statistics_duration_protobuf) + { + triton::common::TritonJson::Value statistics_duration_json; + RETURN_IF_ERR(statistics_json.MemberAsObject( + statistics_name.c_str(), &statistics_duration_json)); + + uint64_t value; + RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("count", &value)); + mutable_statistics_duration_protobuf->set_count(value); + RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("ns", &value)); + mutable_statistics_duration_protobuf->set_ns(value); + return nullptr; + } -void -CommonHandler::Start() -{ - // Use a barrier to make sure we don't return until thread has - // started. 
- auto barrier = std::make_shared(2); - - thread_.reset(new std::thread([this, barrier] { - SetUpAllRequests(); - barrier->Wait(); - - void* tag; - bool ok; - - while (cq_->Next(&tag, &ok)) { - ICallData* call_data = static_cast(tag); - if (!call_data->Process(ok)) { - LOG_VERBOSE(1) << "Done for " << call_data->Name() << ", " - << call_data->Id(); - delete call_data; + // Example RPC method: ServerLive + ::grpc::ServerUnaryReactor* ServerLive( + ::grpc::CallbackServerContext* context, + const inference::ServerLiveRequest* request, + inference::ServerLiveResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; } } - })); - - barrier->Wait(); - LOG_VERBOSE(1) << "Thread started for " << Name(); -} - -void -CommonHandler::Stop() -{ - if (thread_->joinable()) { - thread_->join(); - } - - LOG_VERBOSE(1) << "Thread exited for " << Name(); -} - -void -CommonHandler::SetUpAllRequests() -{ - // Define all the RPCs to be handled by this handler below - // - // Within each of the Register function, the format of RPC specification is: - // 1. A OnRegister function: This will be called when the - // server is ready to receive the requests for this RPC. - // 2. A OnExecute function: This will be called when the - // to process the request. - // 3. 
Create a CommonCallData object with the above callback - // functions - - // health (GRPC standard) - RegisterHealthCheck(); - // health (Triton) - RegisterServerLive(); - RegisterServerReady(); - RegisterModelReady(); - - // Metadata - RegisterServerMetadata(); - RegisterModelMetadata(); - - // model config - RegisterModelConfig(); - - // shared memory - // system.. - RegisterSystemSharedMemoryStatus(); - RegisterSystemSharedMemoryRegister(); - RegisterSystemSharedMemoryUnregister(); - // cuda.. - RegisterCudaSharedMemoryStatus(); - RegisterCudaSharedMemoryRegister(); - RegisterCudaSharedMemoryUnregister(); - - // model repository - RegisterRepositoryIndex(); - RegisterRepositoryModelLoad(); - RegisterRepositoryModelUnload(); - - // statistics - RegisterModelStatistics(); - - // trace - RegisterTrace(); - - // logging - RegisterLogging(); -} -void -CommonHandler::RegisterServerLive() -{ - auto OnRegisterServerLive = - [this]( - ::grpc::ServerContext* ctx, inference::ServerLiveRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestServerLive( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerLive = [this]( - inference::ServerLiveRequest& request, - inference::ServerLiveResponse* response, - ::grpc::Status* status) { + // Business logic for ServerLive. 
bool live = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsLive(tritonserver_.get(), &live); - response->set_live((err == nullptr) && live); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerLiveRequest, inference::ServerLiveResponse>( - "ServerLive", 0, OnRegisterServerLive, OnExecuteServerLive, - false /* async */, cq_, restricted_kv, response_delay_); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterServerReady() -{ - auto OnRegisterServerReady = - [this]( - ::grpc::ServerContext* ctx, inference::ServerReadyRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestServerReady( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerReady = [this]( - inference::ServerReadyRequest& request, - inference::ServerReadyResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ServerReady( + ::grpc::CallbackServerContext* context, + const inference::ServerReadyRequest* request, + inference::ServerReadyResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Business logic for ServerReady. 
bool ready = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); - response->set_ready((err == nullptr) && ready); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerReadyRequest, inference::ServerReadyResponse>( - "ServerReady", 0, OnRegisterServerReady, OnExecuteServerReady, - false /* async */, cq_, restricted_kv, response_delay_); -} - -void -CommonHandler::RegisterHealthCheck() -{ - auto OnRegisterHealthCheck = - [this]( - ::grpc::ServerContext* ctx, - ::grpc::health::v1::HealthCheckRequest* request, - ::grpc::ServerAsyncResponseWriter< - ::grpc::health::v1::HealthCheckResponse>* responder, - void* tag) { - this->health_service_->RequestCheck( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteHealthCheck = [this]( - ::grpc::health::v1::HealthCheckRequest& - request, - ::grpc::health::v1::HealthCheckResponse* - response, - ::grpc::Status* status) { - bool live = false; - TRITONSERVER_Error* err = - TRITONSERVER_ServerIsReady(tritonserver_.get(), &live); + reactor->Finish(status); + return reactor; + } - auto serving_status = - ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; - if (err == nullptr) { - serving_status = - live ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING - : ::grpc::health::v1:: - HealthCheckResponse_ServingStatus_NOT_SERVING; + ::grpc::ServerUnaryReactor* ModelReady( + ::grpc::CallbackServerContext* context, + const inference::ModelReadyRequest* request, + inference::ModelReadyResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } } - response->set_status(serving_status); - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - ::grpc::health::v1::HealthCheckResponse>, - ::grpc::health::v1::HealthCheckRequest, - ::grpc::health::v1::HealthCheckResponse>( - "Check", 0, OnRegisterHealthCheck, OnExecuteHealthCheck, - false /* async */, cq_, restricted_kv, response_delay_); -} -void -CommonHandler::RegisterModelReady() -{ - auto OnRegisterModelReady = - [this]( - ::grpc::ServerContext* ctx, inference::ModelReadyRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelReady( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelReady = [this]( - inference::ModelReadyRequest& request, - inference::ModelReadyResponse* response, - ::grpc::Status* status) { + // Business logic for ModelReady. 
bool is_ready = false; int64_t requested_model_version; - auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + TRITONSERVER_Error* err = + GetModelVersionFromString(request->version(), &requested_model_version); if (err == nullptr) { err = TRITONSERVER_ServerModelIsReady( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, &is_ready); } response->set_ready(is_ready); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelReadyRequest, inference::ModelReadyResponse>( - "ModelReady", 0, OnRegisterModelReady, OnExecuteModelReady, - false /* async */, cq_, restricted_kv, response_delay_); -} - -void -CommonHandler::RegisterServerMetadata() -{ - auto OnRegisterServerMetadata = - [this]( - ::grpc::ServerContext* ctx, inference::ServerMetadataRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestServerMetadata( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerMetadata = - [this]( - inference::ServerMetadataRequest& request, - inference::ServerMetadataResponse* response, ::grpc::Status* status) { - TRITONSERVER_Message* server_metadata_message = nullptr; - TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( - tritonserver_.get(), &server_metadata_message); - GOTO_IF_ERR(err, earlyexit); - - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - server_metadata_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + reactor->Finish(status); + return reactor; + } - { - triton::common::TritonJson::Value server_metadata_json; - err = server_metadata_json.Parse(buffer, 
byte_size); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* ServerMetadata( + ::grpc::CallbackServerContext* context, + const inference::ServerMetadataRequest* request, + inference::ServerMetadataResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + // Business logic for ServerMetadata. + TRITONSERVER_Message* server_metadata_message = nullptr; + TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( + tritonserver_.get(), &server_metadata_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + server_metadata_message, &buffer, &byte_size); + if (err == nullptr) { + triton::common::TritonJson::Value server_metadata_json; + err = server_metadata_json.Parse(buffer, byte_size); + if (err == nullptr) { const char* name; size_t namelen; err = server_metadata_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - - const char* version; - size_t versionlen; - err = server_metadata_json.MemberAsString( - "version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - - response->set_name(std::string(name, namelen)); - response->set_version(std::string(version, versionlen)); - - if (server_metadata_json.Find("extensions")) { - triton::common::TritonJson::Value extensions_json; - err = server_metadata_json.MemberAsArray( - "extensions", &extensions_json); - GOTO_IF_ERR(err, earlyexit); - - for (size_t idx = 0; idx < extensions_json.ArraySize(); 
++idx) { - const char* ext; - size_t extlen; - err = extensions_json.IndexAsString(idx, &ext, &extlen); - GOTO_IF_ERR(err, earlyexit); - response->add_extensions(std::string(ext, extlen)); + if (err == nullptr) { + const char* version; + size_t versionlen; + err = server_metadata_json.MemberAsString( + "version", &version, &versionlen); + if (err == nullptr) { + response->set_name(std::string(name, namelen)); + response->set_version(std::string(version, versionlen)); + + if (server_metadata_json.Find("extensions")) { + triton::common::TritonJson::Value extensions_json; + err = server_metadata_json.MemberAsArray( + "extensions", &extensions_json); + if (err == nullptr) { + for (size_t idx = 0; idx < extensions_json.ArraySize(); + ++idx) { + const char* ext; + size_t extlen; + err = extensions_json.IndexAsString(idx, &ext, &extlen); + if (err == nullptr) { + response->add_extensions(std::string(ext, extlen)); + } + } + } + } } } - TRITONSERVER_MessageDelete(server_metadata_message); } + } + TRITONSERVER_MessageDelete(server_metadata_message); + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::METADATA); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerMetadataRequest, inference::ServerMetadataResponse>( - "ServerMetadata", 0, OnRegisterServerMetadata, OnExecuteServerMetadata, - false /* async */, cq_, restricted_kv, response_delay_); -} + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterModelMetadata() -{ - auto OnRegisterModelMetadata = - [this]( - ::grpc::ServerContext* ctx, inference::ModelMetadataRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelMetadata( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto 
OnExecuteModelMetadata = [this]( - inference::ModelMetadataRequest& request, - inference::ModelMetadataResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ModelMetadata( + ::grpc::CallbackServerContext* context, + const inference::ModelMetadataRequest* request, + inference::ModelMetadataResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic - kept same as original int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + GetModelVersionFromString(request->version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); - { TRITONSERVER_Message* model_metadata_message = nullptr; err = TRITONSERVER_ServerModelMetadata( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, &model_metadata_message); GOTO_IF_ERR(err, earlyexit); @@ -771,7 +458,6 @@ CommonHandler::RegisterModelMetadata() int64_t d; err = shape_json.IndexAsInt(sidx, &d); GOTO_IF_ERR(err, earlyexit); - io->add_shape(d); } } @@ -813,54 +499,51 @@ CommonHandler::RegisterModelMetadata() int64_t d; err = shape_json.IndexAsInt(sidx, &d); GOTO_IF_ERR(err, earlyexit); - io->add_shape(d); } } } } - TRITONSERVER_MessageDelete(model_metadata_message); } - earlyexit: - GrpcStatusUtil::Create(status, err); + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); 
TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::METADATA); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelMetadataRequest, inference::ModelMetadataResponse>( - "ModelMetadata", 0, OnRegisterModelMetadata, OnExecuteModelMetadata, - false /* async */, cq_, restricted_kv, response_delay_); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterModelConfig() -{ - auto OnRegisterModelConfig = - [this]( - ::grpc::ServerContext* ctx, inference::ModelConfigRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelConfig( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelConfig = [this]( - inference::ModelConfigRequest& request, - inference::ModelConfigResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ModelConfig( + ::grpc::CallbackServerContext* context, + const inference::ModelConfigRequest* request, + inference::ModelConfigResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + GetModelVersionFromString(request->version(), &requested_model_version); if (err == nullptr) { TRITONSERVER_Message* model_config_message = nullptr; err = TRITONSERVER_ServerModelConfig( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, 1 /* config_version */, &model_config_message); if (err == nullptr) { const char* buffer; @@ -877,51 +560,48 @@ CommonHandler::RegisterModelConfig() } } - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelConfigRequest, inference::ModelConfigResponse>( - "ModelConfig", 0, OnRegisterModelConfig, OnExecuteModelConfig, - false /* async */, cq_, restricted_kv, response_delay_); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterModelStatistics() -{ - auto OnRegisterModelStatistics = - [this]( - ::grpc::ServerContext* ctx, - inference::ModelStatisticsRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelStatistics( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelStatistics = [this]( - 
inference::ModelStatisticsRequest& - request, - inference::ModelStatisticsResponse* - response, - ::grpc::Status* status) { + // Other RPC methods (e.g., ServerReady, HealthCheck) would be implemented + // similarly. + ::grpc::ServerUnaryReactor* ModelStatistics( + ::grpc::CallbackServerContext* context, + const inference::ModelStatisticsRequest* request, + inference::ModelStatisticsResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::STATISTICS); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic - kept same as original #ifdef TRITON_ENABLE_STATS triton::common::TritonJson::Value model_stats_json; int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + GetModelVersionFromString(request->version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); - { TRITONSERVER_Message* model_stats_message = nullptr; err = TRITONSERVER_ServerModelStatistics( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, &model_stats_message); GOTO_IF_ERR(err, earlyexit); @@ -1120,63 +800,44 @@ CommonHandler::RegisterModelStatistics() } earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support model statistics"); - GrpcStatusUtil::Create(status, err); + 
::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::STATISTICS); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelStatisticsRequest, inference::ModelStatisticsResponse>( - "ModelStatistics", 0, OnRegisterModelStatistics, OnExecuteModelStatistics, - false /* async */, cq_, restricted_kv, response_delay_); -} - -template -TRITONSERVER_Error* -CommonHandler::SetStatisticsDuration( - triton::common::TritonJson::Value& statistics_json, - const std::string& statistics_name, - PBTYPE* mutable_statistics_duration_protobuf) const -{ - triton::common::TritonJson::Value statistics_duration_json; - RETURN_IF_ERR(statistics_json.MemberAsObject( - statistics_name.c_str(), &statistics_duration_json)); - uint64_t value; - RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("count", &value)); - mutable_statistics_duration_protobuf->set_count(value); - RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("ns", &value)); - mutable_statistics_duration_protobuf->set_ns(value); + reactor->Finish(status); + return reactor; + } - return nullptr; -} + ::grpc::ServerUnaryReactor* TraceSetting( + ::grpc::CallbackServerContext* context, + const inference::TraceSettingRequest* request, + inference::TraceSettingResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::TRACE); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterTrace() -{ - auto OnRegisterTrace = - [this]( - ::grpc::ServerContext* ctx, inference::TraceSettingRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestTraceSetting( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteTrace = [this]( - inference::TraceSettingRequest& request, - inference::TraceSettingResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original #ifdef TRITON_ENABLE_TRACING TRITONSERVER_Error* err = nullptr; TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED; @@ -1187,29 +848,29 @@ CommonHandler::RegisterTrace() InferenceTraceMode trace_mode; TraceConfigMap config_map; - if (!request.model_name().empty()) { + if (!request->model_name().empty()) { bool ready = false; - GOTO_IF_ERR( - TRITONSERVER_ServerModelIsReady( - tritonserver_.get(), request.model_name().c_str(), - -1 /* model version */, &ready), - earlyexit); + err = TRITONSERVER_ServerModelIsReady( + tritonserver_.get(), request->model_name().c_str(), + -1 /* model version */, &ready); + GOTO_IF_ERR(err, earlyexit); if (!ready) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, - (std::string("Request for unknown model : ") + request.model_name()) + (std::string("Request for unknown model : ") + + request->model_name()) .c_str()); GOTO_IF_ERR(err, earlyexit); } } // Update trace setting - if (!request.settings().empty()) { + if (!request->settings().empty()) { TraceManager::NewSetting 
new_setting; { static std::string setting_name = "trace_file"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "trace file location can not be updated through network " @@ -1219,8 +880,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_level"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_level_ = true; } else { @@ -1250,8 +911,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_rate"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_rate_ = true; } else if (it->second.value().size() == 1) { @@ -1290,8 +951,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_count"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_count_ = true; } else if (it->second.value().size() == 1) { @@ -1339,8 +1000,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "log_frequency"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_log_frequency_ = true; } else if (it->second.value().size() == 1) { @@ -1378,16 +1039,16 @@ 
CommonHandler::RegisterTrace() } } - err = - trace_manager_->UpdateTraceSetting(request.model_name(), new_setting); + err = trace_manager_->UpdateTraceSetting( + request->model_name(), new_setting); GOTO_IF_ERR(err, earlyexit); } - // Get current trace setting, this is needed even if the setting - // has been updated above as some values may not be provided in the request. + // Get current trace setting trace_manager_->GetTraceSetting( - request.model_name(), &level, &rate, &count, &log_frequency, &filepath, + request->model_name(), &level, &rate, &count, &log_frequency, &filepath, &trace_mode, &config_map); + // level { inference::TraceSettingResponse::SettingValue level_setting; @@ -1403,6 +1064,7 @@ CommonHandler::RegisterTrace() } (*response->mutable_settings())["trace_level"] = level_setting; } + (*response->mutable_settings())["trace_rate"].add_value( std::to_string(rate)); (*response->mutable_settings())["trace_count"].add_value( @@ -1434,54 +1096,55 @@ CommonHandler::RegisterTrace() } } } + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support trace"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::TRACE); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::TraceSettingRequest, inference::TraceSettingResponse>( - "Trace", 0, OnRegisterTrace, OnExecuteTrace, false /* async */, cq_, - restricted_kv, response_delay_); -} -void -CommonHandler::RegisterLogging() -{ - auto OnRegisterLogging = - [this]( - ::grpc::ServerContext* ctx, inference::LogSettingsRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestLogSettings( - 
ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteLogging = [this]( - inference::LogSettingsRequest& request, - inference::LogSettingsResponse* response, - ::grpc::Status* status) { + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* LogSettings( + ::grpc::CallbackServerContext* context, + const inference::LogSettingsRequest* request, + inference::LogSettingsResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::LOGGING); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + // Core business logic - kept same as original #ifdef TRITON_ENABLE_LOGGING TRITONSERVER_Error* err = nullptr; // Update log settings // Server and Core repos do not have the same Logger object // Each update must be applied to both server and core repo versions - if (!request.settings().empty()) { + if (!request->settings().empty()) { { static std::string setting_name = "log_file"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "log file location can not be updated through network protocol"); @@ -1490,8 +1153,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_info"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& 
log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1510,8 +1173,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_warning"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1530,8 +1193,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_error"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1550,8 +1213,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_verbose_level"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1570,8 +1233,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_format"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1610,6 +1273,7 @@ CommonHandler::RegisterLogging() } GOTO_IF_ERR(err, earlyexit); } + 
(*response->mutable_settings())["log_file"].set_string_param(LOG_FILE); (*response->mutable_settings())["log_info"].set_bool_param(LOG_INFO_IS_ON); (*response->mutable_settings())["log_warning"].set_bool_param( @@ -1620,628 +1284,679 @@ CommonHandler::RegisterLogging() LOG_VERBOSE_LEVEL); (*response->mutable_settings())["log_format"].set_string_param( LOG_FORMAT_STRING); + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support dynamic logging"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::LOGGING); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::LogSettingsRequest, inference::LogSettingsResponse>( - "Logging", 0, OnRegisterLogging, OnExecuteLogging, false /* async */, cq_, - restricted_kv, response_delay_); -} -void -CommonHandler::RegisterSystemSharedMemoryStatus() -{ - auto OnRegisterSystemSharedMemoryStatus = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryStatusRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryStatusResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryStatus( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryStatus = - [this]( - inference::SystemSharedMemoryStatusRequest& request, - inference::SystemSharedMemoryStatusResponse* response, - ::grpc::Status* status) { - triton::common::TritonJson::Value shm_status_json( - triton::common::TritonJson::ValueType::ARRAY); - TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); - GOTO_IF_ERR(err, earlyexit); + 
reactor->Finish(status); + return reactor; + } - for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value shm_region_json; - err = shm_status_json.IndexAsObject(idx, &shm_region_json); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* SystemSharedMemoryRegister( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryRegisterRequest* request, + inference::SystemSharedMemoryRegisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - const char* name; - size_t namelen; - err = shm_region_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + // Core business logic - kept same as original + TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( + request->name(), request->key(), request->offset(), + request->byte_size()); - const char* key; - size_t keylen; - err = shm_region_json.MemberAsString("key", &key, &keylen); - GOTO_IF_ERR(err, earlyexit); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - uint64_t offset; - err = shm_region_json.MemberAsUInt("offset", &offset); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* SystemSharedMemoryStatus( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryStatusRequest* request, + inference::SystemSharedMemoryStatusResponse* response) override + { + auto* 
reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - uint64_t byte_size; - err = shm_region_json.MemberAsUInt("byte_size", &byte_size); - GOTO_IF_ERR(err, earlyexit); - - inference::SystemSharedMemoryStatusResponse::RegionStatus - region_status; - region_status.set_name(std::string(name, namelen)); - region_status.set_key(std::string(key, keylen)); - region_status.set_offset(offset); - region_status.set_byte_size(byte_size); + // Core business logic - kept same as original + triton::common::TritonJson::Value shm_status_json( + triton::common::TritonJson::ValueType::ARRAY); + TRITONSERVER_Error* err = shm_manager_->GetStatus( + request->name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); + GOTO_IF_ERR(err, earlyexit); - (*response->mutable_regions())[name] = region_status; - } + for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value shm_region_json; + err = shm_status_json.IndexAsObject(idx, &shm_region_json); + GOTO_IF_ERR(err, earlyexit); - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryStatusResponse>, - inference::SystemSharedMemoryStatusRequest, - inference::SystemSharedMemoryStatusResponse>( - "SystemSharedMemoryStatus", 0, OnRegisterSystemSharedMemoryStatus, - OnExecuteSystemSharedMemoryStatus, false /* 
async */, cq_, restricted_kv, - response_delay_); -} + const char* name; + size_t namelen; + err = shm_region_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterSystemSharedMemoryRegister() -{ - auto OnRegisterSystemSharedMemoryRegister = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryRegisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryRegisterResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryRegister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryRegister = - [this]( - inference::SystemSharedMemoryRegisterRequest& request, - inference::SystemSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( - request.name(), request.key(), request.offset(), - request.byte_size()); - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryRegisterResponse>, - inference::SystemSharedMemoryRegisterRequest, - inference::SystemSharedMemoryRegisterResponse>( - "SystemSharedMemoryRegister", 0, OnRegisterSystemSharedMemoryRegister, - OnExecuteSystemSharedMemoryRegister, false /* async */, cq_, - restricted_kv, response_delay_); -} + const char* key; + size_t keylen; + err = shm_region_json.MemberAsString("key", &key, &keylen); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterSystemSharedMemoryUnregister() -{ - auto OnRegisterSystemSharedMemoryUnregister = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryUnregisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryUnregisterResponse>* responder, - void* tag) { - 
this->service_->RequestSystemSharedMemoryUnregister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryUnregister = - [this]( - inference::SystemSharedMemoryUnregisterRequest& request, - inference::SystemSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); - } else { - err = - shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_CPU); - } + uint64_t offset; + err = shm_region_json.MemberAsUInt("offset", &offset); + GOTO_IF_ERR(err, earlyexit); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryUnregisterResponse>, - inference::SystemSharedMemoryUnregisterRequest, - inference::SystemSharedMemoryUnregisterResponse>( - "SystemSharedMemoryUnregister", 0, OnRegisterSystemSharedMemoryUnregister, - OnExecuteSystemSharedMemoryUnregister, false /* async */, cq_, - restricted_kv, response_delay_); -} + uint64_t byte_size; + err = shm_region_json.MemberAsUInt("byte_size", &byte_size); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterCudaSharedMemoryStatus() -{ - auto OnRegisterCudaSharedMemoryStatus = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryStatusRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryStatusResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryStatus( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - auto OnExecuteCudaSharedMemoryStatus = - [this]( - inference::CudaSharedMemoryStatusRequest& request, - inference::CudaSharedMemoryStatusResponse* response, - ::grpc::Status* status) { - triton::common::TritonJson::Value shm_status_json( - 
triton::common::TritonJson::ValueType::ARRAY); - TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); - GOTO_IF_ERR(err, earlyexit); + inference::SystemSharedMemoryStatusResponse::RegionStatus region_status; + region_status.set_name(std::string(name, namelen)); + region_status.set_key(std::string(key, keylen)); + region_status.set_offset(offset); + region_status.set_byte_size(byte_size); - for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value shm_region_json; - err = shm_status_json.IndexAsObject(idx, &shm_region_json); - GOTO_IF_ERR(err, earlyexit); + (*response->mutable_regions())[name] = region_status; + } - const char* name; - size_t namelen; - err = shm_region_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + earlyexit: + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - uint64_t device_id; - err = shm_region_json.MemberAsUInt("device_id", &device_id); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* CudaSharedMemoryRegister( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryRegisterRequest* request, + inference::CudaSharedMemoryRegisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - uint64_t byte_size; - err = shm_region_json.MemberAsUInt("byte_size", &byte_size); - GOTO_IF_ERR(err, earlyexit); + // Core business logic + TRITONSERVER_Error* err = nullptr; +#ifdef TRITON_ENABLE_GPU + err = shm_manager_->RegisterCUDASharedMemory( + request->name(), + reinterpret_cast( + request->raw_handle().c_str()), + request->byte_size(), request->device_id()); +#else + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to register CUDA shared memory region: '" + + request->name() + "', GPUs not supported") + .c_str()); +#endif // TRITON_ENABLE_GPU + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - inference::CudaSharedMemoryStatusResponse::RegionStatus region_status; - region_status.set_name(std::string(name, namelen)); - region_status.set_device_id(device_id); - region_status.set_byte_size(byte_size); + ::grpc::ServerUnaryReactor* CudaSharedMemoryStatus( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryStatusRequest* request, + inference::CudaSharedMemoryStatusResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - (*response->mutable_regions())[name] = region_status; - } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryStatusResponse>, - inference::CudaSharedMemoryStatusRequest, - inference::CudaSharedMemoryStatusResponse>( - "CudaSharedMemoryStatus", 0, OnRegisterCudaSharedMemoryStatus, - OnExecuteCudaSharedMemoryStatus, false /* async */, cq_, restricted_kv, - response_delay_); -} + // Core business logic - kept same as original + triton::common::TritonJson::Value shm_status_json( + triton::common::TritonJson::ValueType::ARRAY); + TRITONSERVER_Error* err = shm_manager_->GetStatus( + request->name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterCudaSharedMemoryRegister() -{ - auto OnRegisterCudaSharedMemoryRegister = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryRegisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryRegisterResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryRegister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteCudaSharedMemoryRegister = - [this]( - inference::CudaSharedMemoryRegisterRequest& request, - inference::CudaSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; 
-#ifdef TRITON_ENABLE_GPU - err = shm_manager_->RegisterCUDASharedMemory( - request.name(), - reinterpret_cast( - request.raw_handle().c_str()), - request.byte_size(), request.device_id()); -#else - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "failed to register CUDA shared memory region: '" + - request.name() + "', GPUs not supported") - .c_str()); -#endif // TRITON_ENABLE_GPU + for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value shm_region_json; + err = shm_status_json.IndexAsObject(idx, &shm_region_json); + GOTO_IF_ERR(err, earlyexit); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryRegisterResponse>, - inference::CudaSharedMemoryRegisterRequest, - inference::CudaSharedMemoryRegisterResponse>( - "CudaSharedMemoryRegister", 0, OnRegisterCudaSharedMemoryRegister, - OnExecuteCudaSharedMemoryRegister, false /* async */, cq_, restricted_kv, - response_delay_); -} + const char* name; + size_t namelen; + err = shm_region_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterCudaSharedMemoryUnregister() -{ - auto OnRegisterCudaSharedMemoryUnregister = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryUnregisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryUnregisterResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryUnregister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteCudaSharedMemoryUnregister = - [this]( - inference::CudaSharedMemoryUnregisterRequest& request, - inference::CudaSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if 
(request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); - } else { - err = - shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_GPU); - } + uint64_t device_id; + err = shm_region_json.MemberAsUInt("device_id", &device_id); + GOTO_IF_ERR(err, earlyexit); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryUnregisterResponse>, - inference::CudaSharedMemoryUnregisterRequest, - inference::CudaSharedMemoryUnregisterResponse>( - "CudaSharedMemoryUnregister", 0, OnRegisterCudaSharedMemoryUnregister, - OnExecuteCudaSharedMemoryUnregister, false /* async */, cq_, - restricted_kv, response_delay_); -} + uint64_t byte_size; + err = shm_region_json.MemberAsUInt("byte_size", &byte_size); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterRepositoryIndex() -{ - auto OnRegisterRepositoryIndex = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryIndexRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestRepositoryIndex( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryIndex = - [this]( - inference::RepositoryIndexRequest& request, - inference::RepositoryIndexResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - uint32_t flags = 0; - if (request.ready()) { - flags |= TRITONSERVER_INDEX_FLAG_READY; - } + inference::CudaSharedMemoryStatusResponse::RegionStatus region_status; + region_status.set_name(std::string(name, namelen)); + region_status.set_device_id(device_id); + region_status.set_byte_size(byte_size); - TRITONSERVER_Message* model_index_message = nullptr; - err = TRITONSERVER_ServerModelIndex( - tritonserver_.get(), flags, 
&model_index_message); - GOTO_IF_ERR(err, earlyexit); + (*response->mutable_regions())[name] = region_status; + } - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_index_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + earlyexit: + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - triton::common::TritonJson::Value model_index_json; - err = model_index_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* SystemSharedMemoryUnregister( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryUnregisterRequest* request, + inference::SystemSharedMemoryUnregisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - err = model_index_json.AssertType( - triton::common::TritonJson::ValueType::ARRAY); - GOTO_IF_ERR(err, earlyexit); + // Core business logic - kept same as original + TRITONSERVER_Error* err = nullptr; + if (request->name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); + } else { + err = shm_manager_->Unregister(request->name(), TRITONSERVER_MEMORY_CPU); + } - for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value index_json; - err = model_index_json.IndexAsObject(idx, &index_json); - GOTO_IF_ERR(err, earlyexit); + ::grpc::Status status; + 
GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - auto model_index = response->add_models(); + // Add here + ::grpc::ServerUnaryReactor* CudaSharedMemoryUnregister( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryUnregisterRequest* request, + inference::CudaSharedMemoryUnregisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - const char* name; - size_t namelen; - err = index_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_name(std::string(name, namelen)); + // Core business logic - kept same as original + TRITONSERVER_Error* err = nullptr; + if (request->name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); + } else { + err = shm_manager_->Unregister(request->name(), TRITONSERVER_MEMORY_GPU); + } - if (index_json.Find("version")) { - const char* version; - size_t versionlen; - err = index_json.MemberAsString("version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_version(std::string(version, versionlen)); - } - if (index_json.Find("state")) { - const char* state; - size_t statelen; - err = index_json.MemberAsString("state", &state, &statelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_state(std::string(state, statelen)); - } - if (index_json.Find("reason")) { - const char* reason; - size_t reasonlen; - err = 
index_json.MemberAsString("reason", &reason, &reasonlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_reason(std::string(reason, reasonlen)); - } - } + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - TRITONSERVER_MessageDelete(model_index_message); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); - } + ::grpc::ServerUnaryReactor* RepositoryIndex( + ::grpc::CallbackServerContext* context, + const inference::RepositoryIndexRequest* request, + inference::RepositoryIndexResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::RepositoryIndexRequest, inference::RepositoryIndexResponse>( - "RepositoryIndex", 0, OnRegisterRepositoryIndex, OnExecuteRepositoryIndex, - false /* async */, cq_, restricted_kv, response_delay_); -} + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + uint32_t flags = 0; + if (request->ready()) { + flags |= TRITONSERVER_INDEX_FLAG_READY; + } -void -CommonHandler::RegisterRepositoryModelLoad() -{ - auto OnRegisterRepositoryModelLoad 
= - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryModelLoadRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelLoadResponse>* responder, - void* tag) { - this->service_->RequestRepositoryModelLoad( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryModelLoad = - [this]( - inference::RepositoryModelLoadRequest& request, - inference::RepositoryModelLoadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - std::vector params; - // WAR for the const-ness check - std::vector const_params; - for (const auto& param_proto : request.parameters()) { - if (param_proto.first == "config") { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kStringParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected string_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterNew( - param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, - param_proto.second.string_param().c_str()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); + TRITONSERVER_Message* model_index_message = nullptr; + err = TRITONSERVER_ServerModelIndex( + tritonserver_.get(), flags, &model_index_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_index_message, &buffer, &byte_size); + if (err == nullptr) { + triton::common::TritonJson::Value model_index_json; + err = model_index_json.Parse(buffer, byte_size); + if (err == nullptr) { + err = model_index_json.AssertType( + triton::common::TritonJson::ValueType::ARRAY); + if 
(err == nullptr) { + for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value index_json; + err = model_index_json.IndexAsObject(idx, &index_json); + if (err != nullptr) { break; } - } - } else if (param_proto.first.rfind("file:", 0) == 0) { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBytesParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected bytes_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterBytesNew( - param_proto.first.c_str(), - param_proto.second.bytes_param().data(), - param_proto.second.bytes_param().length()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); + + auto model_index = response->add_models(); + + const char* name; + size_t namelen; + err = index_json.MemberAsString("name", &name, &namelen); + if (err != nullptr) { break; } + model_index->set_name(std::string(name, namelen)); + + if (index_json.Find("version")) { + const char* version; + size_t versionlen; + err = index_json.MemberAsString( + "version", &version, &versionlen); + if (err != nullptr) { + break; + } + model_index->set_version(std::string(version, versionlen)); + } + if (index_json.Find("state")) { + const char* state; + size_t statelen; + err = index_json.MemberAsString("state", &state, &statelen); + if (err != nullptr) { + break; + } + model_index->set_state(std::string(state, statelen)); + } + if (index_json.Find("reason")) { + const char* reason; + size_t reasonlen; + err = + index_json.MemberAsString("reason", &reason, &reasonlen); + if (err != nullptr) { + break; + } + model_index->set_reason(std::string(reason, reasonlen)); + } } - } else { - err = 
TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("unrecognized load parameter '") + - param_proto.first + "'.") - .c_str()); - break; } } - if (err == nullptr) { - err = TRITONSERVER_ServerLoadModelWithParameters( - tritonserver_.get(), request.model_name().c_str(), - const_params.data(), const_params.size()); - } - // Assumes no further 'params' access after load API returns - for (auto& param : params) { - TRITONSERVER_ParameterDelete(param); - } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); } + TRITONSERVER_MessageDelete(model_index_message); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::RepositoryModelLoadRequest, - inference::RepositoryModelLoadResponse>( - "RepositoryModelLoad", 0, OnRegisterRepositoryModelLoad, - OnExecuteRepositoryModelLoad, true /* async */, cq_, restricted_kv, - response_delay_); -} + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterRepositoryModelUnload() -{ - auto OnRegisterRepositoryModelUnload = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryModelUnloadRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelUnloadResponse>* responder, - void* tag) { - this->service_->RequestRepositoryModelUnload( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryModelUnload = - [this]( - inference::RepositoryModelUnloadRequest& request, - inference::RepositoryModelUnloadResponse* response, - ::grpc::Status* 
status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - // Check if the dependent models should be removed - bool unload_dependents = false; - for (auto param : request.parameters()) { - if (param.first.compare("unload_dependents") == 0) { - const auto& unload_param = param.second; - if (unload_param.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBoolParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - "invalid value type for 'unload_dependents' parameter, " - "expected " - "bool_param."); - } - unload_dependents = unload_param.bool_param(); + ::grpc::ServerUnaryReactor* RepositoryModelLoad( + ::grpc::CallbackServerContext* context, + const inference::RepositoryModelLoadRequest* request, + inference::RepositoryModelLoadResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + std::vector params; + // WAR for the const-ness check + std::vector const_params; + + for (const auto& param_proto : request->parameters()) { + if (param_proto.first == "config") { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kStringParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', 
expected string_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterNew( + param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, + param_proto.second.string_param().c_str()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); break; } } - if (err == nullptr) { - if (unload_dependents) { - err = TRITONSERVER_ServerUnloadModelAndDependents( - tritonserver_.get(), request.model_name().c_str()); + } else if (param_proto.first.rfind("file:", 0) == 0) { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBytesParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected bytes_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterBytesNew( + param_proto.first.c_str(), + param_proto.second.bytes_param().data(), + param_proto.second.bytes_param().length()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); } else { - err = TRITONSERVER_ServerUnloadModel( - tritonserver_.get(), request.model_name().c_str()); + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); + break; } } } else { err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unrecognized load parameter '") + + param_proto.first + "'.") + .c_str()); + break; + } + } + + if (err == nullptr) { + err = TRITONSERVER_ServerLoadModelWithParameters( + tritonserver_.get(), request->model_name().c_str(), + const_params.data(), const_params.size()); + } + + // Assumes no further 'params' access after load API returns + for 
(auto& param : params) { + TRITONSERVER_ParameterDelete(param); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* RepositoryModelUnload( + ::grpc::CallbackServerContext* context, + const inference::RepositoryModelUnloadRequest* request, + inference::RepositoryModelUnloadResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + // Check if the dependent models should be removed + bool unload_dependents = false; + for (const auto& param : request->parameters()) { + if (param.first.compare("unload_dependents") == 0) { + const auto& unload_param = param.second; + if (unload_param.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBoolParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "invalid value type for 'unload_dependents' parameter, " + "expected bool_param."); + } + unload_dependents = unload_param.bool_param(); + break; + } + } + + if (err == nullptr) { + if (unload_dependents) { + err = TRITONSERVER_ServerUnloadModelAndDependents( + tritonserver_.get(), request->model_name().c_str()); + } else { + err = 
TRITONSERVER_ServerUnloadModel( + tritonserver_.get(), request->model_name().c_str()); } + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + private: + std::shared_ptr tritonserver_; + std::shared_ptr shm_manager_; + RestrictedFeatures restricted_keys_; +}; + +// +// CommonHandler +// +// A common handler for all non-inference requests. +// +class CommonHandler : public HandlerBase { + public: + CommonHandler( + const std::string& name, + const std::shared_ptr& tritonserver, + const std::shared_ptr& shm_manager, + TraceManager* trace_manager, + inference::GRPCInferenceService::AsyncService* service, + ::grpc::health::v1::Health::AsyncService* health_service, + inference::GRPCInferenceServiceCallback::CallbackService* + non_inference_callback_service, + const RestrictedFeatures& restricted_keys, const uint64_t response_delay); + + // Implement pure virtual functions + void Start() override {} // No-op for callback implementation + void Stop() override {} // No-op for callback implementation + + // Descriptive name of of the handler. 
+ const std::string& Name() const { return name_; } - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelUnloadResponse>, - inference::RepositoryModelUnloadRequest, - inference::RepositoryModelUnloadResponse>( - "RepositoryModelUnload", 0, OnRegisterRepositoryModelUnload, - OnExecuteRepositoryModelUnload, true /* async */, cq_, restricted_kv, - response_delay_); + void CreateCallbackServices(); + + // Add methods to return the callback services + inference::GRPCInferenceServiceCallback::CallbackService* + GetUnifiedCallbackService() + { + return non_inference_callback_service_; + } + + ::grpc::health::v1::Health::CallbackService* GetHealthCallbackService() + { + return health_callback_service_; + } + + private: + const std::string name_; + std::shared_ptr tritonserver_; + std::shared_ptr shm_manager_; + TraceManager* trace_manager_; + inference::GRPCInferenceService::AsyncService* service_; + ::grpc::health::v1::Health::AsyncService* health_service_; + inference::GRPCInferenceServiceCallback::CallbackService* + non_inference_callback_service_; + ::grpc::health::v1::Health::CallbackService* health_callback_service_; + std::unique_ptr thread_; + RestrictedFeatures restricted_keys_; + const uint64_t response_delay_; +}; + +CommonHandler::CommonHandler( + const std::string& name, + const std::shared_ptr& tritonserver, + const std::shared_ptr& shm_manager, + TraceManager* trace_manager, + inference::GRPCInferenceService::AsyncService* service, + ::grpc::health::v1::Health::AsyncService* health_service, + inference::GRPCInferenceServiceCallback::CallbackService* + non_inference_callback_service, + const RestrictedFeatures& restricted_keys, const uint64_t response_delay) + : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), + 
trace_manager_(trace_manager), service_(service), + health_service_(health_service), + non_inference_callback_service_(non_inference_callback_service), + health_callback_service_(nullptr), restricted_keys_(restricted_keys), + response_delay_(response_delay) +{ + CreateCallbackServices(); +} + +void +CommonHandler::CreateCallbackServices() +{ + // Create the unified callback service for non-inference operations + non_inference_callback_service_ = + new UnifiedCallbackService(tritonserver_, shm_manager_, restricted_keys_); + + // Create the health callback service + health_callback_service_ = + new HealthCallbackService(tritonserver_, restricted_keys_); } } // namespace @@ -2284,7 +1999,6 @@ Server::Server( builder_.AddListeningPort(server_addr_, credentials, &bound_port_); builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); builder_.RegisterService(&service_); - builder_.RegisterService(&health_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); @@ -2370,7 +2084,6 @@ Server::Server( LOG_TABLE_VERBOSE(1, table_printer); } - common_cq_ = builder_.AddCompletionQueue(); model_infer_cq_ = builder_.AddCompletionQueue(); model_stream_infer_cq_ = builder_.AddCompletionQueue(); @@ -2383,8 +2096,15 @@ Server::Server( // A common Handler for other non-inference requests common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, common_cq_.get(), options.restricted_protocols_, - response_delay)); + &health_service_, nullptr /* non_inference_callback_service */, + options.restricted_protocols_, response_delay)); + // Use common_handler_ and register services + auto* handler = dynamic_cast(common_handler_.get()); + if (handler != nullptr) { + // Register both the unified service and health service + builder_.RegisterService(handler->GetUnifiedCallbackService()); + builder_.RegisterService(handler->GetHealthCallbackService()); + } // [FIXME] "register" logic is 
different for infer // Handler for model inference requests. @@ -2548,7 +2268,7 @@ Server::Start() (std::string("Socket '") + server_addr_ + "' already in use ").c_str()); } - common_handler_->Start(); + // Remove this for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Start(); } @@ -2572,13 +2292,13 @@ Server::Stop() // Always shutdown the completion queue after the server. server_->Shutdown(); - common_cq_->Shutdown(); + // common_cq_->Shutdown(); model_infer_cq_->Shutdown(); model_stream_infer_cq_->Shutdown(); // Must stop all handlers explicitly to wait for all the handler // threads to join since they are referencing completion queue, etc. - common_handler_->Stop(); + // common_handler_->Stop(); for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Stop(); } diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index f5ec5f87cd..983a212ffe 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -36,6 +36,7 @@ #include "grpc_handler.h" #include "grpc_service.grpc.pb.h" #include "grpc_utils.h" +#include "grpccallback_service.grpc.pb.h" #include "health.grpc.pb.h" #include "infer_handler.h" #include "stream_infer_handler.h" @@ -140,10 +141,13 @@ class Server { inference::GRPCInferenceService::AsyncService service_; ::grpc::health::v1::Health::AsyncService health_service_; + inference::GRPCInferenceServiceCallback::CallbackService + non_inference_callback_service_; + ::grpc::health::v1::Health::CallbackService* health_callback_service_; std::unique_ptr<::grpc::Server> server_; - std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; + // std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_infer_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_stream_infer_cq_; diff --git a/test_grpc_callbacks.sh b/test_grpc_callbacks.sh new file mode 100755 index 0000000000..8058af5cdd --- /dev/null +++ b/test_grpc_callbacks.sh @@ -0,0 
+1,148 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# Smoke test for the callback-based (non-inference) gRPC endpoints.
#
# Note: Before running this script, start Triton server in explicit model
# control mode:
#   tritonserver --model-repository=/path/to/model/repository --model-control-mode=explicit
#
# Usage:
#   test_grpc_callbacks.sh [SERVER_URL]
#
# Environment overrides (all optional):
#   PROTO_PATH      directory containing grpccallback_service.proto and health.proto
#   MODEL_NAME      model used for the load / unload / ready / metadata calls
#   LOAD_WAIT_SECS  seconds to wait after a load request before continuing
#
# Exit status: 0 when every RPC succeeded, 1 otherwise.

# Default server URL
SERVER_URL="${1:-localhost:8001}"
PROTO_PATH="${PROTO_PATH:-/mnt/builddir/triton-server/_deps/repo-common-src/protobuf}"
PROTO_FILE="${PROTO_PATH}/grpccallback_service.proto"
HEALTH_PROTO="${PROTO_PATH}/health.proto"
MODEL_NAME="${MODEL_NAME:-simple}"
LOAD_WAIT_SECS="${LOAD_WAIT_SECS:-2}"

# Fully-qualified service prefix shared by every non-health RPC.
SERVICE="inference.GRPCInferenceServiceCallback"

# Colors for output
GREEN='\033[0;32m'
RED='\033[0;31m'
NC='\033[0m' # No Color
BOLD='\033[1m'

# Running tallies consumed by the summary and the exit status.
TOTAL=0
FAILED=0

# Print PASSED/FAILED for one test and record the outcome.
#   $1 - human readable test name
#   $2 - exit status of the command under test
print_result() {
    local test_name=$1
    local result=$2
    TOTAL=$((TOTAL + 1))
    if [ "${result}" -eq 0 ]; then
        printf '%s: %bPASSED%b\n' "${test_name}" "${GREEN}" "${NC}"
    else
        FAILED=$((FAILED + 1))
        printf '%s: %bFAILED%b\n' "${test_name}" "${RED}" "${NC}"
    fi
}

# Invoke one RPC through grpcurl and report the result.
#   $1 - human readable test name
#   $2 - proto file describing the service
#   $3 - fully-qualified method (package.Service/Method)
#   $4 - optional JSON request body
run_rpc() {
    local test_name=$1
    local proto=$2
    local method=$3
    local data=${4-}
    local -a args=(-proto "${proto}" --import-path "${PROTO_PATH}" -plaintext)

    # Only pass -d when a request body was supplied.
    if [ -n "${data}" ]; then
        args+=(-d "${data}")
    fi

    printf '\n%bTesting %s:%b\n' "${BOLD}" "${test_name}" "${NC}"
    grpcurl "${args[@]}" "${SERVER_URL}" "${method}"
    print_result "${test_name}" $?
}

# Fail fast with a clear message instead of reporting every RPC as FAILED.
if ! command -v grpcurl >/dev/null 2>&1; then
    echo "error: grpcurl not found in PATH" >&2
    exit 1
fi
for required_proto in "${PROTO_FILE}" "${HEALTH_PROTO}"; do
    if [ ! -f "${required_proto}" ]; then
        echo "error: proto file not found: ${required_proto} (set PROTO_PATH)" >&2
        exit 1
    fi
done

printf '\n%bTesting gRPC Callback RPCs against %s%b\n\n' "${BOLD}" "${SERVER_URL}" "${NC}"

run_rpc "Health Check" "${HEALTH_PROTO}" "grpc.health.v1.Health/Check"
run_rpc "Repository Index" "${PROTO_FILE}" "${SERVICE}/RepositoryIndex"

run_rpc "Model Load" "${PROTO_FILE}" "${SERVICE}/RepositoryModelLoad" \
    "{\"model_name\": \"${MODEL_NAME}\"}"
# Wait for model to load
sleep "${LOAD_WAIT_SECS}"

run_rpc "Model Unload" "${PROTO_FILE}" "${SERVICE}/RepositoryModelUnload" \
    "{\"model_name\": \"${MODEL_NAME}\"}"

run_rpc "Server Live" "${PROTO_FILE}" "${SERVICE}/ServerLive"
run_rpc "Server Ready" "${PROTO_FILE}" "${SERVICE}/ServerReady"

# Load the model again so the Model Ready / Model Metadata checks have a
# loaded model to query.
run_rpc "Model Reload" "${PROTO_FILE}" "${SERVICE}/RepositoryModelLoad" \
    "{\"model_name\": \"${MODEL_NAME}\"}"
# Wait for model to load
sleep "${LOAD_WAIT_SECS}"

run_rpc "Model Ready" "${PROTO_FILE}" "${SERVICE}/ModelReady" \
    "{\"name\": \"${MODEL_NAME}\"}"
run_rpc "Server Metadata" "${PROTO_FILE}" "${SERVICE}/ServerMetadata"
run_rpc "Model Metadata" "${PROTO_FILE}" "${SERVICE}/ModelMetadata" \
    "{\"name\": \"${MODEL_NAME}\"}"

printf '\n%bTest Summary:%b\n' "${BOLD}" "${NC}"
echo "----------------------------------------"
printf '%d of %d tests passed\n' "$((TOTAL - FAILED))" "${TOTAL}"

# Propagate failure to the caller so CI can detect it.
if [ "${FAILED}" -ne 0 ]; then
    exit 1
fi
exit 0