Skip to content

Commit 93e0bb7

Browse files
committed
native inference billing
1 parent 8d79861 commit 93e0bb7

6 files changed

Lines changed: 192 additions & 12 deletions

File tree

apis/cloudflare/src/billing.ts

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import { type BillingEvent } from "@braintrust/proxy";
2+
3+
// Ingest endpoint used by sendBillingTelemetryEvent when the caller passes no
// (or an empty) `telemetryUrl`.
const DEFAULT_BILLING_TELEMETRY_URL =
4+
"https://api.braintrust.dev/billing/telemetry/ingest";
5+
6+
/**
 * Converts a proxy billing event into the record shape posted to the billing
 * telemetry endpoint.
 *
 * Every call mints a fresh random `idempotency_key` and stamps the current
 * time, so two calls with the same input yield distinct records.
 *
 * @param event - Usage event emitted by the proxy.
 * @returns The ingest record, or `null` when the event has no `org_id` to
 *   attribute the usage to.
 */
function buildPayloadEvent(event: BillingEvent) {
  const { org_id: orgId } = event;
  if (!orgId) {
    return null;
  }

  // Keep this order: the key is generated first, then the timestamp.
  const idempotencyKey = crypto.randomUUID();
  const emittedAt = new Date().toISOString();

  // Token counts are forwarded as-is; fields that are undefined drop out when
  // the record is serialised with JSON.stringify.
  const properties = {
    model: event.model,
    resolved_model: event.resolved_model,
    org_id: orgId,
    input_tokens: event.input_tokens,
    output_tokens: event.output_tokens,
    cached_input_tokens: event.cached_input_tokens,
    cache_write_input_tokens: event.cache_write_input_tokens,
  };

  return {
    event_name: "NativeInferenceTokenUsageEvent",
    external_customer_id: orgId,
    timestamp: emittedAt,
    idempotency_key: idempotencyKey,
    properties,
  };
}
29+
30+
/**
 * Posts one usage event to the billing telemetry endpoint.
 *
 * Best-effort by design: every failure (missing org id, network error,
 * non-2xx response) is logged with `console.warn` and swallowed, so the
 * returned promise never rejects and billing cannot break the request path.
 *
 * @param telemetryUrl - Optional endpoint override. When unset or empty,
 *   DEFAULT_BILLING_TELEMETRY_URL is used.
 * @param event - Usage event to report. Its `auth_token` is sent as the
 *   Bearer credential; events without an `org_id` are skipped.
 * @returns A promise that resolves once the event has been sent, skipped, or
 *   its failure has been logged.
 */
export async function sendBillingTelemetryEvent({
  telemetryUrl,
  event,
}: {
  telemetryUrl?: string;
  event: BillingEvent;
}): Promise<void> {
  // Upper bound on how much of an error response body is copied into the log
  // line, so a large error page cannot flood the logs.
  const maxLoggedBodyLength = 2048;

  try {
    const payloadEvent = buildPayloadEvent(event);
    if (!payloadEvent) {
      console.warn("billing event skipped: missing org_id");
      return;
    }

    // `||` rather than `??` on purpose: an empty-string override (e.g. a blank
    // env var) should also fall back to the default endpoint.
    const destination = telemetryUrl || DEFAULT_BILLING_TELEMETRY_URL;
    const response = await fetch(destination, {
      method: "POST",
      headers: {
        Authorization: `Bearer ${event.auth_token}`,
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        events: [payloadEvent],
      }),
    });

    if (response.ok) {
      // The success body is not needed. Cancel it explicitly so the
      // connection is released instead of waiting on an unread stream.
      try {
        await response.body?.cancel();
      } catch {
        // Nothing to release; the event itself was delivered.
      }
      return;
    }

    const responseBody = await response.text();
    const loggedBody =
      responseBody.length > maxLoggedBodyLength
        ? `${responseBody.slice(0, maxLoggedBodyLength)}... (truncated)`
        : responseBody;
    console.warn(
      `billing event failed: ${response.status} ${response.statusText} ${loggedBody}`,
    );
  } catch (error) {
    console.warn("billing event threw an error", error);
  }
}

apis/cloudflare/src/env.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ declare global {
44
BRAINTRUST_APP_URL: string;
55
WHITELISTED_ORIGINS?: string;
66
METRICS_LICENSE_KEY?: string;
7+
BILLING_TELEMETRY_URL?: string;
78
NATIVE_INFERENCE_SECRET_KEY?: string;
89
}
910
}

apis/cloudflare/src/proxy.ts

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import { BT_PARENT, resolveParentHeader } from "braintrust/util";
1919
import { cachedLogin, makeProxySpanLogger } from "./tracing";
2020
import { MeterProvider } from "@opentelemetry/sdk-metrics";
2121
import { Meter, Attributes, Histogram } from "@opentelemetry/api";
22+
import { sendBillingTelemetryEvent } from "./billing";
2223

2324
export type LogHistogramFn = (args: {
2425
name: string;
@@ -117,6 +118,30 @@ export async function handleProxyV1(
117118
let span: Span | undefined;
118119
let spanId: string | undefined;
119120
let spanExport: string | undefined;
121+
let billingOrgId: string | undefined;
122+
const orgName = request.headers.get(ORG_NAME_HEADER) ?? undefined;
123+
const apiKey =
124+
parseAuthHeader({
125+
authorization: request.headers.get("authorization") ?? undefined,
126+
}) ?? undefined;
127+
128+
const getLoginState = async () =>
129+
cachedLogin({
130+
appUrl: braintrustAppUrl(env).toString(),
131+
apiKey,
132+
orgName,
133+
cache: credentialsCache,
134+
});
135+
136+
if (apiKey) {
137+
try {
138+
const loginState = await getLoginState();
139+
billingOrgId = loginState.orgId ?? undefined;
140+
} catch (error) {
141+
console.warn("Failed to resolve billing org id", error);
142+
}
143+
}
144+
120145
const parentHeader = request.headers.get(BT_PARENT);
121146
if (parentHeader) {
122147
let parent;
@@ -131,19 +156,11 @@ export async function handleProxyV1(
131156
);
132157
}
133158

134-
const orgName = request.headers.get(ORG_NAME_HEADER) ?? undefined;
135-
const apiKey =
136-
parseAuthHeader({
137-
authorization: request.headers.get("authorization") ?? undefined,
138-
}) ?? undefined;
159+
const loginState = await getLoginState();
160+
billingOrgId = loginState.orgId ?? undefined;
139161

140162
span = startSpan({
141-
state: await cachedLogin({
142-
appUrl: braintrustAppUrl(env).toString(),
143-
apiKey,
144-
orgName,
145-
cache: credentialsCache,
146-
}),
163+
state: loginState,
147164
type: "llm",
148165
name: "LLM",
149166
parent: parent.toStr(),
@@ -199,6 +216,17 @@ export async function handleProxyV1(
199216
spanLogger,
200217
spanId,
201218
spanExport,
219+
billingOrgId,
220+
onBillingEvent: (event) => {
221+
ctx.waitUntil(
222+
sendBillingTelemetryEvent({
223+
telemetryUrl: env.BILLING_TELEMETRY_URL,
224+
event,
225+
}).catch((error) => {
226+
console.warn("billing waitUntil task failed", error);
227+
}),
228+
);
229+
},
202230
nativeInferenceSecretKey: env.NATIVE_INFERENCE_SECRET_KEY,
203231
};
204232

apis/cloudflare/wrangler-template.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ head_sampling_rate = 0.2
2828
# You should not need to edit this
2929
BRAINTRUST_APP_URL = "https://www.braintrust.dev"
3030
METRICS_LICENSE_KEY="<YOUR_METRICS_LICENSE_KEY>"
31+
BILLING_TELEMETRY_URL="https://api.braintrust.dev/billing/telemetry/ingest"
3132

3233
[env.staging.vars]
3334
BRAINTRUST_APP_URL = "https://www.braintrust.dev"
3435
METRICS_LICENSE_KEY="<YOUR_METRICS_LICENSE_KEY>"
36+
BILLING_TELEMETRY_URL="https://api.braintrust.dev/billing/telemetry/ingest"
3537

3638
[env.staging]
3739
kv_namespaces = [

packages/proxy/edge/index.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { DEFAULT_BRAINTRUST_APP_URL } from "@lib/constants";
22
import { flushMetrics } from "@lib/metrics";
3-
import { proxyV1, SpanLogger, LogHistogramFn } from "@lib/proxy";
3+
import { proxyV1, SpanLogger, LogHistogramFn, BillingEvent } from "@lib/proxy";
44
import { isEmpty } from "@lib/util";
55
import { MeterProvider } from "@opentelemetry/sdk-metrics";
66

@@ -36,6 +36,8 @@ export interface ProxyOpts {
3636
logHistogram?: LogHistogramFn;
3737
whitelist?: (string | RegExp)[];
3838
spanLogger?: SpanLogger;
39+
billingOrgId?: string;
40+
onBillingEvent?: (event: BillingEvent) => void;
3941
spanId?: string;
4042
spanExport?: string;
4143
nativeInferenceSecretKey?: string;
@@ -398,6 +400,8 @@ export function EdgeProxyV1(opts: ProxyOpts) {
398400
digest: digestMessage,
399401
logHistogram: opts.logHistogram,
400402
spanLogger: opts.spanLogger,
403+
billingOrgId: opts.billingOrgId,
404+
onBillingEvent: opts.onBillingEvent,
401405
});
402406
} catch (e) {
403407
return new Response(`${e}`, {

packages/proxy/src/proxy.ts

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,19 @@ export interface SpanLogger {
190190
reportProgress: (progress: string) => void;
191191
}
192192

193+
/**
 * Token-usage record handed to the `onBillingEvent` callback of `proxyV1`
 * after a response has been fully processed.
 *
 * All token counts are optional: `proxyV1` emits the event as soon as at
 * least one of them is known.
 */
export type BillingEvent = {
194+
// Fixed tag identifying the kind of billing event.
event_name: "NativeInferenceTokenUsageEvent";
195+
// Credential of the request being billed (`authToken` in proxyV1).
// NOTE(review): consumers appear to forward this as a Bearer token — treat
// the whole event as sensitive and avoid logging it.
auth_token: string;
196+
// Organization to bill, taken from the `billingOrgId` option; may be absent
// when the caller could not resolve it.
org_id?: string;
197+
// The `model` value proxyV1 was working with for this request —
// presumably the model name the caller asked for; confirm against proxyV1.
model: string;
198+
// The `model` field reported in the upstream response (or stream chunk).
resolved_model: string;
199+
// Org name from the request's org-name header, or from the first matching
// secret when the header is absent.
org_name?: string;
200+
// `prompt_tokens` (chat/completion) or `usage.input_tokens` (responses API).
input_tokens?: number;
201+
// `completion_tokens` (chat/completion) or `usage.output_tokens` (responses API).
output_tokens?: number;
202+
// `cached_tokens` from the prompt/input token details, when reported.
cached_input_tokens?: number;
203+
// `prompt_tokens_details.cache_creation_tokens`; only populated on the
// chat/completion paths, not for the responses API.
cache_write_input_tokens?: number;
204+
};
205+
193206
// This is an isomorphic implementation of proxyV1, which is used by both edge functions
194207
// in CloudFlare and by the node proxy (locally and in lambda).
195208
export async function proxyV1({
@@ -208,6 +221,8 @@ export async function proxyV1({
208221
cacheKeyOptions = {},
209222
decompressFetch = false,
210223
spanLogger,
224+
billingOrgId,
225+
onBillingEvent,
211226
signal,
212227
fetch = globalThis.fetch,
213228
}: {
@@ -237,6 +252,8 @@ export async function proxyV1({
237252
cacheKeyOptions?: CacheKeyOptions;
238253
decompressFetch?: boolean;
239254
spanLogger?: SpanLogger;
255+
billingOrgId?: string;
256+
onBillingEvent?: (event: BillingEvent) => void;
240257
signal?: AbortSignal;
241258
fetch?: FetchFn;
242259
}): Promise<void> {
@@ -299,6 +316,7 @@ export async function proxyV1({
299316
);
300317

301318
let orgName: string | undefined = proxyHeaders[ORG_NAME_HEADER] ?? undefined;
319+
let resolvedOrgName: string | undefined = orgName;
302320
const projectId: string | undefined =
303321
proxyHeaders[PROJECT_ID_HEADER] ?? undefined;
304322

@@ -649,6 +667,7 @@ export async function proxyV1({
649667

650668
if (secrets.length > 0 && !orgName && secrets[0].org_name) {
651669
baseAttributes.org_name = secrets[0].org_name;
670+
resolvedOrgName = secrets[0].org_name;
652671
}
653672
logRequest();
654673

@@ -759,6 +778,11 @@ export async function proxyV1({
759778
if (stream) {
760779
let first = true;
761780
const allChunks: Uint8Array[] = [];
781+
let resolvedModel: string | undefined = undefined;
782+
let inputTokens: number | undefined = undefined;
783+
let outputTokens: number | undefined = undefined;
784+
let cachedInputTokens: number | undefined = undefined;
785+
let cacheWriteInputTokens: number | undefined = undefined;
762786

763787
// These parameters are for the streaming case
764788
let reasoning: OpenAIReasoning[] | undefined = undefined;
@@ -787,10 +811,20 @@ export async function proxyV1({
787811
| OpenAIChatCompletionChunk
788812
| undefined;
789813
if (result) {
814+
if (typeof result.model === "string" && result.model) {
815+
resolvedModel = result.model;
816+
}
790817
const extendedUsage = completionUsageSchema.safeParse(
791818
result.usage,
792819
);
793820
if (extendedUsage.success) {
821+
inputTokens = extendedUsage.data.prompt_tokens;
822+
outputTokens = extendedUsage.data.completion_tokens;
823+
cachedInputTokens =
824+
extendedUsage.data.prompt_tokens_details?.cached_tokens;
825+
cacheWriteInputTokens =
826+
extendedUsage.data.prompt_tokens_details
827+
?.cache_creation_tokens;
794828
spanLogger?.log({
795829
metrics: {
796830
tokens: extendedUsage.data.total_tokens,
@@ -978,10 +1012,20 @@ export async function proxyV1({
9781012
case "chat":
9791013
case "completion": {
9801014
const data = dataRaw as ChatCompletion;
1015+
if (typeof data.model === "string" && data.model) {
1016+
resolvedModel = data.model;
1017+
}
9811018
const extendedUsage = completionUsageSchema.safeParse(
9821019
data.usage,
9831020
);
9841021
if (extendedUsage.success) {
1022+
inputTokens = extendedUsage.data.prompt_tokens;
1023+
outputTokens = extendedUsage.data.completion_tokens;
1024+
cachedInputTokens =
1025+
extendedUsage.data.prompt_tokens_details?.cached_tokens;
1026+
cacheWriteInputTokens =
1027+
extendedUsage.data.prompt_tokens_details
1028+
?.cache_creation_tokens;
9851029
spanLogger?.log({
9861030
output: data.choices,
9871031
metrics: {
@@ -1041,6 +1085,15 @@ export async function proxyV1({
10411085
}
10421086
case "response": {
10431087
const data = dataRaw as OpenAIResponse;
1088+
if (typeof data.model === "string" && data.model) {
1089+
resolvedModel = data.model;
1090+
}
1091+
if (data.usage) {
1092+
inputTokens = data.usage.input_tokens;
1093+
outputTokens = data.usage.output_tokens;
1094+
cachedInputTokens =
1095+
data.usage.input_tokens_details?.cached_tokens;
1096+
}
10441097
spanLogger?.log({
10451098
output: data.output,
10461099
metrics: {
@@ -1089,6 +1142,33 @@ export async function proxyV1({
10891142
});
10901143

10911144
spanLogger?.end();
1145+
if (
1146+
!responseFailed &&
1147+
model &&
1148+
onBillingEvent &&
1149+
resolvedModel &&
1150+
(inputTokens !== undefined ||
1151+
outputTokens !== undefined ||
1152+
cachedInputTokens !== undefined ||
1153+
cacheWriteInputTokens !== undefined)
1154+
) {
1155+
try {
1156+
onBillingEvent({
1157+
event_name: "NativeInferenceTokenUsageEvent",
1158+
auth_token: authToken,
1159+
org_id: billingOrgId,
1160+
model,
1161+
resolved_model: resolvedModel,
1162+
org_name: resolvedOrgName,
1163+
input_tokens: inputTokens,
1164+
output_tokens: outputTokens,
1165+
cached_input_tokens: cachedInputTokens,
1166+
cache_write_input_tokens: cacheWriteInputTokens,
1167+
});
1168+
} catch (error) {
1169+
console.warn("billing callback failed", error);
1170+
}
1171+
}
10921172
controller.terminate();
10931173
},
10941174
});

0 commit comments

Comments
 (0)