204 changes: 203 additions & 1 deletion docs/proposals/00210_response-based-token-counting.md
@@ -1,7 +1,7 @@
---
issue: https://github.com/praxis-proxy/praxis/issues/210
discussion: https://github.com/praxis-proxy/praxis/issues/210
-status: proposed
+status: accepted
authors:
- mkoushni
graduation_criteria:
@@ -152,3 +152,205 @@
chunks as well as the final chunk. Should the filter accumulate and
sum these, or only use the final chunk's usage payload? Summing
intermediate chunks could double-count if the final chunk already
contains the total.

## How?

### Requirements

1. A new `token_count` filter in `filter/src/builtins/http/ai/token_count.rs`.
2. The filter accepts a single required YAML key `provider` that selects
the extraction strategy.
3. For non-streaming responses the full body is buffered and parsed once
at end-of-stream.
Comment on lines +163 to +164

Member

This is pretty random, but I'm curious: has anyone considered an optional alternative to full token counting where you estimate tokens based on the average tokens for any given byte total, and then apply that average to the payload size for an "estimate"?

I'm curious if some users would opt for an estimate to avoid the additional buffering and payload processing here (estimate instead of count, and then stream through without any memory buffering).

🤔

4. For SSE streaming responses all chunks are buffered and the assembled
event stream is scanned for token fields once the stream closes.
5. For Bedrock InvokeModel, token counts are read from HTTP response
headers in `on_response`; no body parsing is performed.
6. Counts are written to `FilterContext` via `ctx.set_token_usage(input,
output, total)` so that downstream filters (header injection, logging,
rate limiting) can read them without parsing provider JSON themselves.
7. The filter delegates all provider-specific JSON parsing to the existing
`token_usage` library (`filter/src/builtins/http/ai/token_usage/`).

### Answering the Open Questions

#### Provider identification

Provider identity is supplied explicitly via the `provider:` YAML key.
Auto-detection is not implemented. This keeps the filter stateless and
unambiguous — Azure and OpenAI share the same JSON schema, so
auto-detection would be unreliable for those two.

Supported values: `openai`, `anthropic`, `google`, `bedrock`,
`bedrock_invoke_model`, `azure`.
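
As an illustration of explicit selection, the accepted values can be mapped onto an enum with no fallback or auto-detection. This is a minimal sketch using only the standard library: the `Provider` type and its `FromStr` impl are hypothetical stand-ins, since the proposal deserializes `ProviderKind` through serde instead.

```rust
use std::str::FromStr;

/// Hypothetical stand-in for the proposal's `ProviderKind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Provider {
    OpenAi,
    Anthropic,
    Google,
    Bedrock,
    BedrockInvokeModel,
    Azure,
}

impl FromStr for Provider {
    type Err = String;

    /// Accepts exactly the six documented values; anything else is an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "openai" => Ok(Self::OpenAi),
            "anthropic" => Ok(Self::Anthropic),
            "google" => Ok(Self::Google),
            "bedrock" => Ok(Self::Bedrock),
            "bedrock_invoke_model" => Ok(Self::BedrockInvokeModel),
            "azure" => Ok(Self::Azure),
            other => Err(format!("unknown provider `{other}`")),
        }
    }
}
```

Rejecting unknown values at configuration time means a typo surfaces as a startup error rather than as silently missing token counts.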

#### Streaming completion signal

The filter uses `BodyMode::StreamBuffer` — the proxy buffers all response
body bytes and calls `on_response_body` once with `end_of_stream: true`.
No per-chunk inspection is needed. Stream close (connection end) is the
authoritative trigger, which correctly handles both providers that emit
`[DONE]` (OpenAI) and providers that do not (Google Gemini).
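
The buffering behaviour described above can be pictured with a small accumulator. This is a sketch only: the `StreamBuffer` struct and `Push` enum below are hypothetical and are not the proxy's real types, and refusing to parse an over-limit body is an assumption about how a size cap might be enforced.

```rust
/// Hypothetical accumulator that mimics the described buffering:
/// collect every chunk, hand the body over once at end of stream.
struct StreamBuffer {
    buf: Vec<u8>,
    max_bytes: usize,
}

#[derive(Debug, PartialEq, Eq)]
enum Push {
    /// More chunks are expected; nothing to parse yet.
    Pending,
    /// The stream closed; the full body is ready for a single parse.
    Complete(Vec<u8>),
    /// The configured limit would be exceeded (assumed policy: do not parse).
    TooLarge,
}

impl StreamBuffer {
    fn new(max_bytes: usize) -> Self {
        Self { buf: Vec::new(), max_bytes }
    }

    /// Appends one chunk; only the call with `end_of_stream == true`
    /// yields the assembled body.
    fn push(&mut self, chunk: &[u8], end_of_stream: bool) -> Push {
        if self.buf.len() + chunk.len() > self.max_bytes {
            return Push::TooLarge;
        }
        self.buf.extend_from_slice(chunk);
        if end_of_stream {
            Push::Complete(std::mem::take(&mut self.buf))
        } else {
            Push::Pending
        }
    }
}
```

Because completion is keyed on `end_of_stream` rather than on a payload sentinel, the same path serves streams that end with `[DONE]` and streams that simply close.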

#### Streaming token accumulation strategy per provider

| Provider | Strategy |
|---|---|
| OpenAI / Azure / Bedrock Converse | Scan all `data:` lines; return the last one that parses successfully — usage appears once in the terminal chunk |
| Google (Gemini) | Same as above — usage is in `usageMetadata` on the final chunk; no `[DONE]` sentinel |
| Anthropic | Two-event scan: read `input_tokens` from the `message_start` event and `output_tokens` from the `message_delta` event; combine at end |
| Bedrock InvokeModel | Header-only; no SSE parsing |
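
The "keep the last `data:` line that parses" strategy from the table might look roughly like the sketch below. Everything here is illustrative: `Usage`, `u64_field`, `parse_openai_usage`, and `last_usage_from_sse` are hypothetical names, and `u64_field` is a naive substring scan standing in for the real JSON parsing that the `token_usage` library performs. The field names `prompt_tokens`, `completion_tokens`, and `total_tokens` follow the OpenAI usage object.

```rust
/// Provider-neutral counts (hypothetical mirror of the library's `TokenUsage`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Usage {
    input: u64,
    output: u64,
    total: u64,
}

/// Naive stand-in for JSON parsing: finds `"name":` and reads the digits after it.
fn u64_field(json: &str, name: &str) -> Option<u64> {
    let key = format!("\"{name}\":");
    let start = json.find(key.as_str())? + key.len();
    let rest = json[start..].trim_start();
    let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
    rest[..end].parse().ok()
}

/// Reads an OpenAI-style usage object; falls back to input + output
/// when no total is supplied.
fn parse_openai_usage(payload: &str) -> Option<Usage> {
    let input = u64_field(payload, "prompt_tokens")?;
    let output = u64_field(payload, "completion_tokens")?;
    let total = u64_field(payload, "total_tokens").unwrap_or(input + output);
    Some(Usage { input, output, total })
}

/// Scans every `data:` line and keeps the last payload that `parse` accepts.
/// The `[DONE]` sentinel is skipped, so streams with and without it behave alike.
fn last_usage_from_sse<F>(body: &str, parse: F) -> Option<Usage>
where
    F: Fn(&str) -> Option<Usage>,
{
    body.lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(str::trim)
        .filter(|payload| *payload != "[DONE]")
        .filter_map(|payload| parse(payload))
        .last()
}
```

Passing the parser as a closure keeps the scan itself provider-agnostic, so the OpenAI, Azure, Bedrock Converse, and Gemini rows could share it with different parsers.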

#### Bedrock InvokeModel extraction path

Handled directly by this filter in `on_response`. When `provider:
bedrock_invoke_model` is configured, the filter reads
`x-amzn-bedrock-input-token-count` and `x-amzn-bedrock-output-token-count`
from the upstream response headers and calls `ctx.set_token_usage` there.
`response_body_access` returns `BodyAccess::None` for this provider so no
body buffering occurs. No changes are required to the `token_usage` library.
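
A header-only extraction along these lines could be sketched as follows. The header names come from the text above; the `(name, value)` slice used to represent headers, the helper names, and the choice to compute the total as input plus output are assumptions made for illustration.

```rust
const INPUT_HEADER: &str = "x-amzn-bedrock-input-token-count";
const OUTPUT_HEADER: &str = "x-amzn-bedrock-output-token-count";

/// Case-insensitive header lookup over a simple (name, value) slice.
/// The slice is a hypothetical stand-in for the proxy's header map.
fn header_u64(headers: &[(&str, &str)], name: &str) -> Option<u64> {
    headers
        .iter()
        .find(|(k, _)| k.eq_ignore_ascii_case(name))
        .and_then(|(_, v)| v.trim().parse().ok())
}

/// Returns (input, output, total) only when both headers are present and
/// numeric. The total is assumed to be the sum of the two counts.
fn bedrock_invoke_model_usage(headers: &[(&str, &str)]) -> Option<(u64, u64, u64)> {
    let input = header_u64(headers, INPUT_HEADER)?;
    let output = header_u64(headers, OUTPUT_HEADER)?;
    Some((input, output, input + output))
}
```

Returning `None` when either header is missing or malformed lets the caller skip `ctx.set_token_usage` instead of recording a partial count.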

#### Partial usage data

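The buffered payload is parsed once, after the stream closes, rather than
chunk by chunk. For providers handled by the last-valid-chunk strategy,
usage fields in intermediate SSE chunks are superseded by the final chunk
and never summed. For providers that report counts across multiple events
(e.g., Anthropic), each relevant event is read exactly once and the values
are combined, not accumulated across chunks, which avoids double-counting.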

### File Layout

```
filter/src/builtins/http/ai/
token_count.rs ← new: TokenCountFilter
token_usage/ ← existing: TokenUsage, TokenUsageProvider, extract_token_usage
mod.rs
providers.rs
tests.rs
```

### Filter Struct and Config

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct TokenCountConfig {
    provider: ProviderKind,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ProviderKind {
    /// Named `Openai` so that serde's `snake_case` renaming yields the YAML
    /// value `openai` (`OpenAi` would be renamed to `open_ai`). Note that
    /// `TokenUsageProvider` in `token_usage/mod.rs` uses `OpenAi`
    /// (capital I). A private `to_library_provider` conversion function
    /// maps `ProviderKind → TokenUsageProvider` at the call site.
    Openai,
    Anthropic,
    Google,
    Bedrock,
    BedrockInvokeModel,
    Azure,
}

pub struct TokenCountFilter {
    provider: ProviderKind,
}
```

### HttpFilter Trait Implementation

| Hook | Behaviour |
|---|---|
| `on_request` | No-op; returns `Continue` |
| `on_response` | Detects `text/event-stream` content type and stores an `is_sse` flag in `FilterContext` metadata. For `bedrock_invoke_model`, reads token headers and calls `ctx.set_token_usage` |
| `response_body_access` | `BodyAccess::None` for `bedrock_invoke_model`; `BodyAccess::ReadOnly` for all others |
| `response_body_mode` | `BodyMode::Stream` for `bedrock_invoke_model` (no buffering needed); `BodyMode::StreamBuffer { max_bytes: Some(8 MiB) }` for all others |
| `on_response_body` | Triggered once at `end_of_stream`. Reads the `is_sse` flag; calls `extract_from_sse` for SSE or `extract_token_usage` for JSON; writes result via `ctx.set_token_usage` |
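
The per-provider decisions in the table can be sketched as plain functions. The enums below are local, hypothetical stand-ins that only echo the names used in this proposal (`BodyAccess`, `BodyMode`); the real trait signatures are not shown in this document, and the `is_sse` helper is an assumption about how the content-type check might be written.

```rust
/// Hypothetical stand-in for the proposal's `ProviderKind`.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Provider {
    OpenAi,
    Anthropic,
    Google,
    Bedrock,
    BedrockInvokeModel,
    Azure,
}

/// Local stand-ins for the proxy's body access and body mode types.
#[derive(Debug, PartialEq, Eq)]
enum BodyAccess {
    None,
    ReadOnly,
}

#[derive(Debug, PartialEq, Eq)]
enum BodyMode {
    Stream,
    StreamBuffer { max_bytes: Option<usize> },
}

/// 8 MiB buffering cap, as stated in the table.
const MAX_BODY_BYTES: usize = 8 * 1024 * 1024;

fn response_body_access(provider: Provider) -> BodyAccess {
    match provider {
        // Counts come from headers, so the body is never inspected.
        Provider::BedrockInvokeModel => BodyAccess::None,
        _ => BodyAccess::ReadOnly,
    }
}

fn response_body_mode(provider: Provider) -> BodyMode {
    match provider {
        Provider::BedrockInvokeModel => BodyMode::Stream,
        _ => BodyMode::StreamBuffer { max_bytes: Some(MAX_BODY_BYTES) },
    }
}

/// True when the media type is `text/event-stream`, ignoring parameters
/// such as `charset` and ignoring ASCII case.
fn is_sse(content_type: Option<&str>) -> bool {
    content_type
        .and_then(|value| value.split(';').next())
        .map(|mime| mime.trim().eq_ignore_ascii_case("text/event-stream"))
        .unwrap_or(false)
}
```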

### SSE Extraction Detail

```
extract_from_sse(provider, data)
├── Anthropic → extract_anthropic_sse (two-event scan)
└── all others → extract_last_usage_from_sse (scan data: lines, keep last valid)
```

`extract_anthropic_sse` scans the full buffered text, accumulates
`input_tokens` from `message_start` and `output_tokens` from
`message_delta`, then constructs a `TokenUsage` only when both are present.
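
A rough sketch of such a two-event scan is shown below. It is not the proposed implementation: `scan_anthropic_events`, `Usage`, and `u64_field` are hypothetical names, `u64_field` is a naive substring scan standing in for real JSON parsing, and the sketch assumes each `data:` line is preceded by an `event:` line naming its type. If several `message_delta` events carry usage, the latest value replaces the earlier one instead of being added to it.

```rust
/// Provider-neutral counts (hypothetical mirror of the library's `TokenUsage`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Usage {
    input: u64,
    output: u64,
    total: u64,
}

/// Naive stand-in for JSON parsing: finds `"name":` and reads the digits after it.
fn u64_field(json: &str, name: &str) -> Option<u64> {
    let key = format!("\"{name}\":");
    let start = json.find(key.as_str())? + key.len();
    let rest = json[start..].trim_start();
    let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
    rest[..end].parse().ok()
}

/// Walks the buffered SSE text once, tracking the current event name and
/// picking one count from each of the two relevant event types.
fn scan_anthropic_events(body: &str) -> Option<Usage> {
    let mut event = "";
    let mut input: Option<u64> = None;
    let mut output: Option<u64> = None;

    for line in body.lines() {
        if let Some(name) = line.strip_prefix("event:") {
            event = name.trim();
        } else if let Some(payload) = line.strip_prefix("data:") {
            match event {
                "message_start" => input = u64_field(payload, "input_tokens").or(input),
                // Latest value wins; counts are replaced, never summed.
                "message_delta" => output = u64_field(payload, "output_tokens").or(output),
                _ => {}
            }
        }
    }

    // A usage record is produced only when both halves were seen.
    let (input, output) = (input?, output?);
    Some(Usage { input, output, total: input + output })
}
```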

### FilterContext Metadata

Token counts are stored under three keys accessible to downstream filters:

| Key | Value |
|---|---|
| `token.input` | Input/prompt token count as `u64` |
| `token.output` | Output/completion token count as `u64` |
| `token.total` | Sum (or provider-supplied total) as `u64` |

Written via `ctx.set_token_usage(input, output, total)`.
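
To show how a downstream filter might consume these keys, here is a minimal sketch. The `Ctx` struct, its `token` accessor, and the `over_budget` check are hypothetical; only the key names and the `set_token_usage(input, output, total)` shape are taken from this proposal, and the real `FilterContext` storage may differ.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for `FilterContext`, holding only token metadata.
#[derive(Default)]
struct Ctx {
    metadata: HashMap<&'static str, u64>,
}

impl Ctx {
    /// Stores the three counts under the keys listed in the table.
    fn set_token_usage(&mut self, input: u64, output: u64, total: u64) {
        self.metadata.insert("token.input", input);
        self.metadata.insert("token.output", output);
        self.metadata.insert("token.total", total);
    }

    /// Hypothetical read accessor for downstream filters.
    fn token(&self, key: &str) -> Option<u64> {
        self.metadata.get(key).copied()
    }
}

/// Example downstream check, such as a rate limiter might perform.
/// Missing counts are treated as "not over budget".
fn over_budget(ctx: &Ctx, budget: u64) -> bool {
    ctx.token("token.total").map_or(false, |total| total > budget)
}
```

The downstream check never touches provider JSON; it depends only on the three keys.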

### Module Registration

Add to `filter/src/builtins/http/ai/mod.rs`:

```rust
#[cfg(feature = "ai-inference")]
mod token_count;

#[cfg(feature = "ai-inference")]
pub use token_count::TokenCountFilter;
```

Add to `filter/src/builtins/http/mod.rs`:

```rust
#[cfg(feature = "ai-inference")]
pub use ai::TokenCountFilter;
```

Add to `filter/src/builtins/mod.rs`:

```rust
#[cfg(feature = "ai-inference")]
pub use http::TokenCountFilter;
```

Register in `filter/src/registry.rs`:

```rust
registry.register_http("token_count", TokenCountFilter::from_config);
```

### YAML Configuration Example

```yaml
listeners:
  - name: gateway
    address: "127.0.0.1:8080"
    filter_chains:
      - token-counting

filter_chains:
  - name: token-counting
    filters:
      - filter: token_count
        provider: openai # openai | anthropic | google | bedrock | bedrock_invoke_model | azure
        # Review comment (Member): We would link in count for llm-d, or raw vllm here?
      - filter: access_log
      - filter: router
        routes:
          - path_prefix: "/"
            cluster: ai-provider
      - filter: load_balancer
        clusters:
          - name: ai-provider
            endpoints:
              - "127.0.0.1:8000"
```

[#20]: https://github.com/praxis-proxy/praxis/issues/20
[#211]: https://github.com/praxis-proxy/praxis/issues/211
[#212]: https://github.com/praxis-proxy/praxis/issues/212
[#214]: https://github.com/praxis-proxy/praxis/issues/214
[#216]: https://github.com/praxis-proxy/praxis/issues/216