Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion proxy/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,23 @@ async def _tcp_fallback(reader, writer, dst, port, relay_init, label, ctx: Crypt
return True


async def _ws_keepalive(ws, interval: float):
"""Send periodic WS PING frames to keep the upstream flow warm.

A non-positive interval disables keepalive. The loop exits on send
failure so a dead upstream is detected promptly instead of lingering
until the next client packet (see issue #646).
"""
if interval <= 0:
return
try:
while True:
await asyncio.sleep(interval)
await ws.send_ping()
except (asyncio.CancelledError, ConnectionError, OSError):
return


async def bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label,
ctx: CryptoCtx,
dc=None, is_media=False,
Expand Down Expand Up @@ -334,12 +351,15 @@ async def ws_to_tcp():

tasks = [asyncio.create_task(tcp_to_ws()),
asyncio.create_task(ws_to_tcp())]
keepalive = asyncio.ensure_future(
_ws_keepalive(ws, proxy_config.ws_keepalive_interval))
try:
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
finally:
keepalive.cancel()
for t in tasks:
t.cancel()
for t in tasks:
for t in (*tasks, keepalive):
try:
await t
except BaseException:
Expand Down
1 change: 1 addition & 0 deletions proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class ProxyConfig:
cfproxy_worker_domains: List[str] = field(default_factory=list)
fake_tls_domain: str = ''
proxy_protocol: bool = False
ws_keepalive_interval: float = 30.0


proxy_config = ProxyConfig()
Expand Down
7 changes: 7 additions & 0 deletions proxy/raw_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ async def send_batch(self, parts: List[bytes]):
self._build_frame(self.OP_BINARY, part, mask=True))
await self.writer.drain()

async def send_ping(self, payload: bytes = b''):
if self._closed:
raise ConnectionError("WebSocket closed")
frame = self._build_frame(self.OP_PING, payload, mask=True)
self.writer.write(frame)
await self.writer.drain()

async def recv(self) -> Optional[bytes]:
while not self._closed:
opcode, payload = await self._read_frame()
Expand Down
5 changes: 5 additions & 0 deletions proxy/tg_ws_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,10 @@ def main():
ap.add_argument('--proxy-protocol', action='store_true',
help='Accept PROXY protocol v1 header '
'(for use behind nginx/haproxy with proxy_protocol on)')
ap.add_argument('--ws-keepalive', type=float, default=30.0, metavar='SEC',
help='Seconds between WebSocket keepalive PINGs to the '
'upstream (default 30, 0 to disable). Keeps idle '
'sessions alive through NAT/firewall timeouts.')
args = ap.parse_args()

if not args.dc_ip:
Expand Down Expand Up @@ -628,6 +632,7 @@ def main():
proxy_config.cfproxy_worker_domains = coerce_domain_list(args.cfproxy_worker_domain)
proxy_config.fake_tls_domain = args.fake_tls_domain.strip()
proxy_config.proxy_protocol = args.proxy_protocol
proxy_config.ws_keepalive_interval = max(0.0, args.ws_keepalive)

log_level = logging.DEBUG if args.verbose else logging.INFO
log_fmt = logging.Formatter('%(asctime)s %(levelname)-5s %(message)s',
Expand Down