Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _codeql_detected_source_root
49 changes: 49 additions & 0 deletions src/test/input_byte_size_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,55 @@ TEST_F(InputByteSizeTest, BatchSizeOverflowInt32)
TRITONSERVER_InferenceRequestDelete(irequest_),
"deleting inference request");
}

TEST_F(InputByteSizeTest, GetByteSizeOverflow)
{
const char* model_name = "onnx_zero_1_float32";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
irequest_, InferRequestComplete, nullptr /* request_release_userp */),
"setting request release callback");

// Define a shape whose total byte size overflows int64_t:
// element_count = INT64_MAX/sizeof(float) + 1 is a valid int64_t value,
// but element_count * sizeof(float) exceeds INT64_MAX.
int64_t large_dim =
static_cast<int64_t>(INT64_MAX / sizeof(float)) + 1;
std::vector<int64_t> shape{1, large_dim};

// Provide a minimal data buffer; the overflow is detected before the
// byte size comparison, so no large allocation is needed.
std::vector<float> input_data(1, 0.0f);
const auto input0_byte_size = sizeof(input_data[0]) * input_data.size();

// Set input for the request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestAddInput(
irequest_, "INPUT0", TRITONSERVER_TYPE_FP32, shape.data(),
shape.size()),
"setting input for the request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestAppendInputData(
irequest_, "INPUT0", input_data.data(), input0_byte_size,
TRITONSERVER_MEMORY_CPU, 0),
"assigning INPUT data");

// Run inference and expect the byte size overflow error
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with byte size overflow",
"' causes total byte size to exceed maximum size of ");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestDelete(irequest_),
"deleting inference request");
}
} // namespace

int
Expand Down