-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
137 lines (110 loc) · 4.28 KB
/
server.py
File metadata and controls
137 lines (110 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from fastapi import FastAPI, Request, Response, HTTPException, WebSocket
from models import TunnelRequest, TunnelResponse
from auth import hash_secret, verify_hmac
import asyncio
import msgpack
import uuid
import os
import argparse
from rich.console import Console
app = FastAPI()
console = Console()

# Built-in fallbacks, used only when neither the CLI nor the environment supplies a value.
DEFAULT_HOST = "0.0.0.0"
DEFAULT_PORT = 5001

parser = argparse.ArgumentParser()
parser.add_argument(
    "--secret", help="Secret for authentication (can also use TUNNEL_SECRET env var)"
)
# No argparse defaults for --host/--port: a default would make "was this flag given?"
# undetectable, which is what forced the old env-over-CLI precedence.
parser.add_argument(
    "--host",
    help=f"Server host (can also use TUNNEL_HOST env var; default {DEFAULT_HOST})",
)
parser.add_argument(
    "--port",
    type=int,
    help=f"Server port (can also use TUNNEL_PORT env var; default {DEFAULT_PORT})",
)
# parse_known_args ignores unrecognised argv entries instead of exiting.
args, _ = parser.parse_known_args()

# Uniform precedence for every setting: command line > environment > built-in default.
secret = args.secret or os.environ.get("TUNNEL_SECRET")
host = args.host or os.environ.get("TUNNEL_HOST") or DEFAULT_HOST
port = (
    args.port
    if args.port is not None
    else int(os.environ.get("TUNNEL_PORT") or DEFAULT_PORT)
)
# Only the derived hash is kept; None disables the challenge/response handshake.
secret_hash = hash_secret(secret) if secret else None
# RFC 6455 caps a close frame's payload at 125 bytes; 2 go to the status code,
# leaving 123 bytes of UTF-8 for the reason text.
_MAX_CLOSE_REASON_BYTES = 123


def _close_reason(text: str) -> str:
    """Trim *text* so it fits in a WebSocket close frame's reason field."""
    return text.encode("utf-8")[:_MAX_CLOSE_REASON_BYTES].decode("utf-8", errors="ignore")


@app.websocket("/tunnel")
async def websocket_tunnel(websocket: WebSocket):
    """Accept a tunnel client, optionally authenticate it, then relay its responses.

    When ``secret_hash`` is set, the client must answer a msgpack
    ``{"type": "challenge", "challenge": <uuid>}`` message within 10 seconds with
    a msgpack dict ``{"type": "auth", "tag": ...}`` that ``verify_hmac`` accepts;
    every failure path closes the socket with code 4008.

    Once accepted, this socket becomes ``app.state.websocket`` and a fresh
    per-connection dict of pending futures is published as
    ``app.state.response_futures``. Each incoming ``{"type": "response", "id": ...}``
    message resolves the future registered under that id. When the connection
    ends, any still-pending futures are failed with ``ConnectionError`` so waiting
    HTTP requests return immediately instead of running into their timeout.
    """
    await websocket.accept()
    if secret_hash:
        try:
            challenge = str(uuid.uuid4())
            await websocket.send_bytes(
                msgpack.packb({"type": "challenge", "challenge": challenge})
            )
            response_bytes = await asyncio.wait_for(
                websocket.receive_bytes(), timeout=10.0
            )
            response_data = msgpack.unpackb(response_bytes)
            # Untrusted payload: anything other than a dict is rejected outright.
            if not isinstance(response_data, dict) or response_data.get("type") != "auth":
                await websocket.close(code=4008, reason="Expected auth message")
                return
            tag = response_data.get("tag")
            if not verify_hmac(secret_hash, challenge, tag):
                console.print("[red]Authentication failed: Invalid secret[/red]")
                await websocket.close(code=4008, reason="Invalid secret")
                return
            console.print("[green]Client authenticated successfully[/green]")
        except asyncio.TimeoutError:
            console.print("[yellow]Authentication timeout[/yellow]")
            await websocket.close(code=4008, reason="Auth timeout")
            return
        except Exception as e:
            import traceback

            # markup=False: exception text may contain "[...]" that Rich would
            # otherwise parse as markup (and possibly raise on).
            console.print(f"Authentication error: {e}", style="red", markup=False)
            console.print(traceback.format_exc(), style="red", markup=False)
            await websocket.close(code=4008, reason=_close_reason(f"Auth error: {e}"))
            return

    # Per-connection table; catch_all registers futures here via app.state.
    futures = {}
    app.state.websocket = websocket
    app.state.response_futures = futures
    try:
        while True:
            data = msgpack.unpackb(await websocket.receive_bytes())
            if data["type"] != "response":
                continue
            future = futures.get(data["id"])
            # The waiter may already have timed out, which cancels its future;
            # set_result on a done future would raise and kill the tunnel.
            if future is not None and not future.done():
                future.set_result(data)
    except Exception as e:
        console.print(f"Connection error: {e}", style="red", markup=False)
    finally:
        # Only clear the slot if a newer client has not already taken it over.
        if app.state.websocket is websocket:
            app.state.websocket = None
        # Wake every request still waiting on this connection.
        for future in futures.values():
            if not future.done():
                future.set_exception(ConnectionError("Tunnel client disconnected"))
# Seconds to wait for the tunnel client to answer a forwarded request.
TUNNEL_RESPONSE_TIMEOUT = 30


def _connected_tunnel() -> WebSocket:
    """Return the active tunnel websocket, or raise HTTP 503 if none is connected."""
    websocket = getattr(app.state, "websocket", None)
    if websocket is None:
        raise HTTPException(status_code=503, detail="No tunnel client connected")
    return websocket


@app.api_route(
    "/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
)
async def catch_all(path: str, request: Request):
    """Forward any incoming HTTP request through the tunnel and relay the reply.

    The request (method, path plus query string, headers, body) is sent to the
    tunnel client as a msgpack-encoded ``TunnelRequest``. The handler then waits
    for the matching response, which ``websocket_tunnel`` delivers by id.

    Raises:
        HTTPException: 503 if no tunnel client is connected, 504 if the client
            does not answer within ``TUNNEL_RESPONSE_TIMEOUT`` seconds, 502 if the
            tunnel connection drops while the request is pending.
    """
    _connected_tunnel()  # fail fast, before reading a possibly large body
    body = await request.body()
    # Re-fetch after the await: the client may have disconnected or been replaced.
    # No await between these two reads, so socket and futures table stay paired.
    websocket = _connected_tunnel()
    futures = app.state.response_futures

    # uuid4 instead of id(request): id() values are recycled once an object is
    # freed, so a late reply to a timed-out request could reach a newer request.
    request_id = uuid.uuid4().hex
    # {path:path} excludes the query string; forward it explicitly.
    # NOTE(review): assumes the client treats `path` as path-plus-query — confirm.
    forwarded_path = f"/{path}"
    if request.url.query:
        forwarded_path = f"{forwarded_path}?{request.url.query}"
    tunnel_request = TunnelRequest(
        id=request_id,
        method=request.method,
        path=forwarded_path,
        headers=dict(request.headers),
        body=body if body else None,
    )

    future = asyncio.get_running_loop().create_future()
    futures[request_id] = future
    try:
        await websocket.send_bytes(msgpack.packb(tunnel_request.model_dump()))
        response_data = await asyncio.wait_for(future, timeout=TUNNEL_RESPONSE_TIMEOUT)
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=504, detail="Tunnel client did not respond in time"
        ) from None
    except ConnectionError:
        raise HTTPException(status_code=502, detail="Tunnel client disconnected") from None
    finally:
        futures.pop(request_id, None)

    tunnel_response = TunnelResponse(**response_data)
    return Response(
        content=tunnel_response.body,
        status_code=tunnel_response.status_code,
        headers=tunnel_response.headers,
    )
if __name__ == "__main__":
    # uvicorn is only needed when this module is run directly as a script,
    # so it is imported here rather than at module load.
    from uvicorn import run as serve

    serve(app, host=host, port=port)