4 changes: 2 additions & 2 deletions docs/user_guide/decoupled_models.md
@@ -97,9 +97,9 @@ each time with a new response. You can take a look at [grpc_server.cc](https://g

### Using Decoupled Models in Ensembles

When using decoupled models within an [ensemble pipeline](ensemble_models.md), you may encounter unbounded memory growth if the decoupled model produces responses faster than downstream models can consume them.
When using decoupled models within an [ensemble pipeline](ensemble_models.md), you may experience unbounded memory growth if a decoupled model produces responses faster than downstream models can consume them.

To prevent unbounded memory growth in this scenario, consider using the `max_inflight_requests` configuration field. This field limits the maximum number of concurrent inflight requests permitted at each ensemble step for each inference request.
To prevent this, use the `max_inflight_requests` configuration field. This field limits the number of concurrent in-flight requests allowed at each ensemble step. The limit is shared across all active requests for that ensemble model, which keeps memory usage bounded.

For more details and examples, see [Limit Memory Growth in Ensemble Models](ensemble_models.md#limit-memory-growth-in-ensemble-models-beta).
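
For orientation, here is a minimal sketch of where the field lives. The step entries are elided, and the placement inside `ensemble_scheduling` is an assumption based on the example in the ensemble guide:

```
ensemble_scheduling {
  max_inflight_requests: 16
  step [
    ...
  ]
}
```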

12 changes: 5 additions & 7 deletions docs/user_guide/ensemble_models.md
@@ -190,18 +190,17 @@ When crafting the ensemble steps, it is useful to note the distinction between
connecting ensemble `input`/`output` to those on the composing model and between
composing models.

## Managing Memory Usage in Ensemble Models
## Limit Memory Growth in Ensemble Models (BETA)

An *in-flight request* is an intermediate request generated by an upstream model that is queued and held in memory until it is processed by a downstream model within an ensemble pipeline. When upstream models process requests significantly faster than downstream models, these in-flight requests can accumulate and potentially lead to unbounded memory growth. This problem occurs when there is a speed mismatch between steps in the pipeline, and it is particularly common with *decoupled models* that produce multiple responses per request faster than downstream models can consume them.

Consider an example ensemble model with two steps where the upstream model is 10× faster:
1. **Preprocessing model**: Produces 100 preprocessed requests/sec
2. **Inference model**: Consumes 10 requests/sec

Without backpressure, requests accumulate in the pipeline faster than they can be processed, eventually leading to out-of-memory errors.
Without backpressure, requests accumulate in the pipeline faster than they can be processed, which eventually leads to out-of-memory errors. At the rates above, the backlog between the two steps grows by 90 requests every second.

The `max_inflight_requests` field in the ensemble configuration sets a limit on the number of concurrent inflight requests permitted at each ensemble step for a single inference request.
When this limit is reached, faster upstream models are paused (blocked) until downstream models finish processing, effectively preventing unbounded memory growth.
The `max_inflight_requests` field in the ensemble configuration defines a limit on the number of concurrent in-flight requests allowed at each ensemble step. This limit is shared across all active requests for that ensemble model. When the limit is reached, new request scheduling for that step is paused until downstream models free up capacity. This prevents requests from accumulating indefinitely and keeps memory usage under control.
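
Conceptually, the mechanism behaves like a bounded producer/consumer queue. The sketch below is an analogy only, plain Python rather than Triton internals, with an arbitrary queue size of 4 and arbitrary item counts: a full queue blocks the fast producer until the slow consumer frees a slot.

```python
import queue
import threading
import time

# A bounded queue models the limiter: put() blocks once `maxsize`
# items are in flight, pausing the fast producer until the slow
# consumer frees a slot.
inflight = queue.Queue(maxsize=4)

NUM_ITEMS = 16

def fast_producer():
    for i in range(NUM_ITEMS):
        inflight.put(i)  # blocks while 4 items await the consumer
        print(f"produced {i}")

def slow_consumer():
    for _ in range(NUM_ITEMS):
        item = inflight.get()  # frees a slot, unblocking the producer
        time.sleep(0.1)        # simulate the slower downstream step
        print(f"consumed {item}")

producer = threading.Thread(target=fast_producer)
consumer = threading.Thread(target=slow_consumer)
producer.start()
consumer.start()
producer.join()
consumer.join()
```

In Triton itself, the limit is set in the ensemble configuration, as in the example below.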

```
ensemble_scheduling {
@@ -225,9 +224,8 @@
```

**Configuration:**
* **`max_inflight_requests: 16`**: For each ensemble request (not globally), at most 16 requests from `dali_preprocess`
can wait for `onnx_inference` to process. Once this per-step limit is reached, `dali_preprocess` is blocked until the downstream step completes a response.
* **Default (`0`)**: No limit - allows unlimited inflight requests (original behavior).
* **`max_inflight_requests: 16`**: Limits the number of concurrent in-flight requests at a given ensemble step to 16 (for example, requests from `dali_preprocess` waiting for `onnx_inference` to complete). This limit is shared across all active requests for that ensemble model. Once the limit is reached, scheduling new work for that step is paused until downstream capacity becomes available.
* **Default (`0`)**: No limit — allows an unlimited number of in-flight requests (original behavior).
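
The configuration block above is collapsed in this view; as a hedged reconstruction, such a configuration might look like the following. The `dali_preprocess` and `onnx_inference` names come from the bullets above; the tensor names in the `input_map`/`output_map` entries and the exact placement of `max_inflight_requests` inside `ensemble_scheduling` are illustrative assumptions, not copied from the PR.

```
ensemble_scheduling {
  max_inflight_requests: 16  # shared across all active requests for this ensemble
  step [
    {
      model_name: "dali_preprocess"
      model_version: -1
      input_map { key: "IN" value: "ensemble_in" }
      output_map { key: "OUT" value: "preprocessed" }
    },
    {
      model_name: "onnx_inference"
      model_version: -1
      input_map { key: "IN" value: "preprocessed" }
      output_map { key: "OUT" value: "ensemble_out" }
    }
  ]
}
```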

### When to Use This Feature

172 changes: 83 additions & 89 deletions qa/L0_simple_ensemble/ensemble_backpressure_test.py
@@ -45,9 +45,9 @@
SERVER_URL = "localhost:8001"
DEFAULT_RESPONSE_TIMEOUT = 60
EXPECTED_INFER_OUTPUT = 0.5
MODEL_ENSEMBLE_DISABLED = "ensemble_disabled_max_inflight_requests"
MODEL_ENSEMBLE_LIMIT_4 = "ensemble_max_inflight_requests_limit_4"
MODEL_ENSEMBLE_LIMIT_1 = "ensemble_max_inflight_requests_limit_1"

NUM_REQUESTS = 16
NUM_RESPONSES_PER_REQUEST = 8


class UserData:
@@ -62,11 +62,14 @@ def callback(user_data, result, error):
user_data._response_queue.put(result)


def prepare_infer_args(input_value):
def prepare_infer_args(input_value, enable_batching=False):
"""
Create InferInput/InferRequestedOutput lists
"""
input_data = np.array([input_value], dtype=np.int32)
if enable_batching:
input_data = np.array([[input_value]], dtype=np.int32)
else:
input_data = np.array([input_value], dtype=np.int32)
infer_input = [grpcclient.InferInput("IN", input_data.shape, "INT32")]
infer_input[0].set_data_from_numpy(input_data)
outputs = [grpcclient.InferRequestedOutput("OUT")]
@@ -87,7 +90,7 @@ def collect_responses(user_data):
f"No response received within {DEFAULT_RESPONSE_TIMEOUT} seconds."
)

if type(result) == InferenceServerException:
if isinstance(result, InferenceServerException):
errors.append(result)
# error responses are final - stream terminates
break
@@ -110,131 +113,122 @@ class EnsembleBackpressureTest(tu.TestResultCollector):
Tests for ensemble backpressure feature (max_inflight_requests).
"""

def _run_inference(self, model_name, expected_responses_count=32):
def _run_inference(
self, model_name, expected_responses_per_request, num_concurrent_requests=1
):
"""
Helper function to run inference and verify responses.
Send num_concurrent_requests streaming requests to model_name, each expecting
expected_responses_per_request responses. Verify all complete with correct data.
"""
user_data = UserData()
with grpcclient.InferenceServerClient(SERVER_URL) as triton_client:
try:
inputs, outputs = prepare_infer_args(expected_responses_count)
triton_client.start_stream(callback=partial(callback, user_data))
triton_client.async_stream_infer(
model_name=model_name, inputs=inputs, outputs=outputs
)

# Collect and verify responses
errors, responses = collect_responses(user_data)
self.assertEqual(
len(responses),
expected_responses_count,
f"Expected {expected_responses_count} responses, got {len(responses)}",
)
self.assertEqual(
len(errors),
0,
f"Expected no errors during inference, got {len(errors)} errors",
)

# Verify correctness of responses
for idx, resp in enumerate(responses):
output = resp.as_numpy("OUT")
self.assertAlmostEqual(
output[0],
EXPECTED_INFER_OUTPUT,
places=5,
msg=f"Response {idx} has incorrect value - {output[0]}",
)
finally:
triton_client.stop_stream()

def test_max_inflight_requests_limit_4(self):
"""
Test that max_inflight_requests correctly limits concurrent
responses.
"""
self._run_inference(model_name=MODEL_ENSEMBLE_LIMIT_4)

def test_max_inflight_requests_limit_1(self):
"""
Test edge case: max_inflight_requests=1.
"""
self._run_inference(model_name=MODEL_ENSEMBLE_LIMIT_1)

def test_max_inflight_requests_limit_disabled(self):
"""
Test that an ensemble model without max_inflight_requests parameter works correctly.
"""
self._run_inference(model_name=MODEL_ENSEMBLE_DISABLED)

def test_max_inflight_requests_limit_concurrent_requests(self):
"""
Test that backpressure works correctly with multiple concurrent requests.
Each request should have independent backpressure state.
"""
num_concurrent = 8
expected_per_request = 8
user_datas = [UserData() for _ in range(num_concurrent)]
user_datas = [UserData() for _ in range(num_concurrent_requests)]

with ExitStack() as stack:
clients = [
stack.enter_context(grpcclient.InferenceServerClient(SERVER_URL))
for _ in range(num_concurrent)
for _ in range(num_concurrent_requests)
]

inputs, outputs = prepare_infer_args(expected_per_request)
inputs, outputs = prepare_infer_args(expected_responses_per_request, True)

# Start all concurrent requests
for i in range(num_concurrent):
for i in range(num_concurrent_requests):
clients[i].start_stream(callback=partial(callback, user_datas[i]))
clients[i].async_stream_infer(
model_name=MODEL_ENSEMBLE_LIMIT_4, inputs=inputs, outputs=outputs
model_name=model_name, inputs=inputs, outputs=outputs
)

# Collect and verify responses for all requests
for i, ud in enumerate(user_datas):
errors, responses = collect_responses(ud)
self.assertEqual(
len(responses),
expected_per_request,
f"Request {i}: expected {expected_per_request} responses, got {len(responses)}",
expected_responses_per_request,
f"Request {i}: expected {expected_responses_per_request} responses, got {len(responses)}",
)
self.assertEqual(
len(errors),
0,
f"Request {i}: Expected no errors during inference, got {len(errors)} errors",
len(errors), 0, f"Request {i}: unexpected errors: {errors}"
)
# Verify correctness of responses
for idx, resp in enumerate(responses):
output = resp.as_numpy("OUT")
# output shape is [batch_size, 1]; extract scalar for comparison.
value = float(output[0][0])
self.assertAlmostEqual(
output[0],
value,
EXPECTED_INFER_OUTPUT,
places=5,
msg=f"Response {idx} for request {i} has incorrect value - {output[0]}",
msg=f"Request {i} response {idx}: expected "
f"{EXPECTED_INFER_OUTPUT}, got {value}",
)

# Stop all streams
for client in clients:
client.stop_stream()

def test_max_inflight_requests_limit_request_cancellation(self):
def test_single_request_with_different_limits(self):
"""
Single streaming request that produces 16 responses via a three-step ensemble pipeline
(decoupled_producer → consumer_high_delay → consumer_low_delay) under various
max_inflight_requests configurations.
"""
cases = [
("ensemble_limit_4", "max_inflight_requests=4"),
("ensemble_limit_1", "max_inflight_requests=1"),
("ensemble_disabled", "max_inflight_requests is disabled"),
]
for model_name, desc in cases:
with self.subTest(limit=desc):
self._run_inference(
model_name=model_name, expected_responses_per_request=16
)

def test_concurrent_requests_with_different_limits(self):
"""
NUM_REQUESTS concurrent streaming requests (NUM_RESPONSES_PER_REQUEST
responses each) exercise the max_inflight_requests limit.
Subtests cover: limit=4, limit=1, and the limit disabled.
"""
cases = [
("ensemble_limit_4", "max_inflight_requests=4"),
("ensemble_limit_1", "max_inflight_requests=1"),
("ensemble_disabled", "max_inflight_requests is disabled"),
]
for model_name, desc in cases:
with self.subTest(limit=desc):
self._run_inference(
model_name=model_name,
expected_responses_per_request=NUM_RESPONSES_PER_REQUEST,
num_concurrent_requests=NUM_REQUESTS,
)

def test_sequential_requests_limiter_resets_cleanly(self):
"""
Send NUM_REQUESTS requests one after another. If the limiter
leaks a slot on any request, subsequent requests will hang or time out.
"""
for seq_idx in range(NUM_REQUESTS):
with self.subTest(request=seq_idx):
self._run_inference(
model_name="ensemble_limit_4",
expected_responses_per_request=NUM_RESPONSES_PER_REQUEST,
)

def test_request_cancellation_under_backpressure(self):
"""
Test that cancellation unblocks producers waiting on backpressure and that
the client receives a cancellation error.
Start a long-running request (32 responses), cancel mid-stream,
and verify the server sends a CANCELLED status and only a partial set of
responses is received.
"""
# Use a large count to ensure the producer gets blocked by backpressure.
# The model is configured with max_inflight_requests = 4.
input_value = 32
user_data = UserData()

with grpcclient.InferenceServerClient(SERVER_URL) as triton_client:
inputs, outputs = prepare_infer_args(input_value)
inputs, outputs = prepare_infer_args(input_value, True)
triton_client.start_stream(callback=partial(callback, user_data))

# Start the request
triton_client.async_stream_infer(
model_name=MODEL_ENSEMBLE_LIMIT_4, inputs=inputs, outputs=outputs
model_name="ensemble_limit_4", inputs=inputs, outputs=outputs
)

responses = []
@@ -248,7 +242,7 @@ def test_max_inflight_requests_limit_request_cancellation(self):
except queue.Empty:
self.fail("Stream did not produce any response before cancellation.")

# Cancel the stream. This should unblock any waiting producers and result in a CANCELLED error.
# Cancel the stream - this unblocks any waiting producers and triggers a CANCELLED error.
triton_client.stop_stream(cancel_requests=True)

# Allow some time for cancellation
@@ -331,7 +325,7 @@ def _run_inference(self, model_name, expected_responses_count):
self.assertEqual(
len(errors),
1,
"Expected exactly one error when queue full terminates stream",
"Expected exactly one error when the queue is full and the stream terminates",
)

# Verify correctness of successful responses