-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
63 lines (53 loc) · 2.14 KB
/
server.py
File metadata and controls
63 lines (53 loc) · 2.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import json

import brotli
import fastapi
import requests

from bucket import Bucket
from gate import AIBudgetGate
def create_app(gate: AIBudgetGate) -> fastapi.FastAPI:
    """Build the FastAPI reverse proxy that enforces ``gate``'s budget.

    Every request to ``/{base_uri}/{path}`` is matched to the bucket whose
    ``base_uri`` equals ``"/" + base_uri``. Unknown buckets get a 404, buckets
    that exceed ``gate.budget_threshold`` get a 429, and everything else is
    forwarded to ``{bucket.upstream_url}/{path}``. When the upstream reply is
    JSON and carries an integer ``usage.total_tokens``, that amount is added
    to the bucket via ``bucket.increment_usage``.

    Args:
        gate: Supplies ``buckets`` (each with ``base_uri`` and
            ``upstream_url``) and ``budget_threshold``.

    Returns:
        The configured ``fastapi.FastAPI`` application.
    """
    # Headers that describe one specific connection; a proxy must not relay
    # them to the next hop.
    hop_by_hop_headers = frozenset({
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    })
    # Encodings that `requests` decodes transparently when `.content` is read
    # ("br" only when a Brotli library is importable, hence the fallback below).
    auto_decoded_encodings = frozenset({"gzip", "deflate", "br"})

    app = fastapi.FastAPI()
    session = requests.Session()
    base_uri_to_bucket = {bucket.base_uri: bucket for bucket in gate.buckets}

    def update_bucket_usage(bucket: Bucket, payload: bytes) -> None:
        """Add ``usage.total_tokens`` from a JSON ``payload`` to ``bucket``.

        Payloads that are not valid JSON, or that lack an integer
        ``usage.total_tokens``, are ignored.
        """
        try:
            data = json.loads(payload)
        except ValueError:
            # Malformed JSON (or undecodable bytes) must not turn a proxied
            # reply into a 500; there is simply no usage to record.
            return
        usage = data.get("usage") if isinstance(data, dict) else None
        total_tokens = usage.get("total_tokens") if isinstance(usage, dict) else None
        # bool is a subclass of int; JSON `true` is not a token count.
        if isinstance(total_tokens, int) and not isinstance(total_tokens, bool):
            bucket.increment_usage(total_tokens)

    @app.api_route(
        "/{base_uri}/{path:path}",
        methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS", "TRACE"],
    )
    async def api_proxy(base_uri: str, path: str, request: fastapi.Request) -> fastapi.Response:
        """Forward one request to the bucket's upstream and relay the reply."""
        bucket = base_uri_to_bucket.get("/" + base_uri)
        if bucket is None:
            return fastapi.Response(status_code=404)
        if bucket.exceeds_threshold(gate.budget_threshold):
            return fastapi.Response(status_code=429)

        # `Host` must name the upstream, so let `requests` derive it from the URL.
        request_headers = {
            name: value
            for name, value in request.headers.items()
            if name.lower() != "host" and name.lower() not in hop_by_hop_headers
        }
        # NOTE(review): `session.request` is a blocking call made from an async
        # handler, so the event loop stalls while the upstream responds, and no
        # timeout is set. Moving it to a worker thread needs a decision on
        # sharing `session` across threads - confirm before changing.
        response = session.request(
            method=request.method.upper(),
            url=f"{bucket.upstream_url}/{path}",
            params=request.url.query,
            data=await request.body(),
            headers=request_headers,
        )

        body = response.content
        encoding = response.headers.get("content-encoding", "").strip().lower()
        dropped_headers = set(hop_by_hop_headers)
        if body:
            # The relayed payload may differ in size from what upstream sent
            # (it may have been decoded), so the upstream Content-Length is
            # dropped and fastapi.Response recomputes it from `body`.
            dropped_headers.add("content-length")
            if encoding in auto_decoded_encodings:
                if encoding == "br":
                    try:
                        body = brotli.decompress(body)
                    except brotli.error:
                        # `requests` already decoded the payload; keep it as is.
                        pass
                # The payload is now plain, so the header no longer applies.
                dropped_headers.add("content-encoding")
        response_headers = {
            name: value
            for name, value in response.headers.items()
            if name.lower() not in dropped_headers
        }

        # Compare the media type only, so parameters such as
        # "; charset=utf-8" do not prevent usage tracking.
        media_type = response.headers.get("content-type", "").split(";", 1)[0].strip().lower()
        if media_type == "application/json":
            update_bucket_usage(bucket, body)

        return fastapi.Response(
            content=body,
            status_code=response.status_code,
            headers=response_headers,
        )

    return app