diff --git a/_codeql_detected_source_root b/_codeql_detected_source_root new file mode 120000 index 000000000..945c9b46d --- /dev/null +++ b/_codeql_detected_source_root @@ -0,0 +1 @@ +. \ No newline at end of file diff --git a/src/test/input_byte_size_test.cc b/src/test/input_byte_size_test.cc index 80d80bea0..42994e711 100644 --- a/src/test/input_byte_size_test.cc +++ b/src/test/input_byte_size_test.cc @@ -810,6 +810,55 @@ TEST_F(InputByteSizeTest, BatchSizeOverflowInt32) TRITONSERVER_InferenceRequestDelete(irequest_), "deleting inference request"); } + +TEST_F(InputByteSizeTest, GetByteSizeOverflow) +{ + const char* model_name = "onnx_zero_1_float32"; + // Create an inference request + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestNew( + &irequest_, server_, model_name, -1 /* model_version */), + "creating inference request"); + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestSetReleaseCallback( + irequest_, InferRequestComplete, nullptr /* request_release_userp */), + "setting request release callback"); + + // Define a shape whose total byte size overflows int64_t: + // element_count = INT64_MAX/sizeof(float) + 1 is a valid int64_t value, + // but element_count * sizeof(float) exceeds INT64_MAX. + int64_t large_dim = + static_cast(INT64_MAX / sizeof(float)) + 1; + std::vector shape{1, large_dim}; + + // Provide a minimal data buffer; the overflow is detected before the + // byte size comparison, so no large allocation is needed. + std::vector input_data(1, 0.0f); + const auto input0_byte_size = sizeof(input_data[0]) * input_data.size(); + + // Set input for the request + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestAddInput( + irequest_, "INPUT0", TRITONSERVER_TYPE_FP32, shape.data(), + shape.size()), + "setting input for the request"); + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestAppendInputData( + irequest_, "INPUT0", input_data.data(), input0_byte_size, + TRITONSERVER_MEMORY_CPU, 0), + "assigning INPUT data"); + + // Run inference and expect the byte size overflow error + FAIL_TEST_IF_SUCCESS( + TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */), + "expect error with byte size overflow", + "' causes total byte size to exceed maximum size of "); + + // Need to manually delete request, otherwise server will not shut down. + FAIL_TEST_IF_ERR( + TRITONSERVER_InferenceRequestDelete(irequest_), + "deleting inference request"); +} } // namespace int