From fd3371bc9b041853557d9cbdcf160f9e498c9329 Mon Sep 17 00:00:00 2001 From: f4rceful Date: Sun, 31 May 2026 10:26:08 -0700 Subject: [PATCH] fix: add WebSocket keepalive pings to prevent idle disconnects (#646) Idle sessions could be silently torn down by NAT/firewall timeouts on the upstream WS connection: with no traffic and no keepalive, the bridge only noticed the dead connection on the next client packet, by which point Telegram had already dropped the session. Add an optional WS keepalive: RawWebSocket.send_ping() emits a masked PING frame, and _ws_keepalive() runs alongside the bridge pumps, pinging every ws_keepalive_interval seconds (--ws-keepalive, default 30, 0 to disable). The loop exits on send failure so a dead upstream is detected promptly. --- proxy/bridge.py | 22 +++++++++++++++++++++- proxy/config.py | 1 + proxy/raw_websocket.py | 7 +++++++ proxy/tg_ws_proxy.py | 5 +++++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/proxy/bridge.py b/proxy/bridge.py index b038f543..3a862baa 100644 --- a/proxy/bridge.py +++ b/proxy/bridge.py @@ -263,6 +263,23 @@ async def _tcp_fallback(reader, writer, dst, port, relay_init, label, ctx: Crypt return True +async def _ws_keepalive(ws, interval: float): + """Send periodic WS PING frames to keep the upstream flow warm. + + A non-positive interval disables keepalive. The loop exits on send + failure so a dead upstream is detected promptly instead of lingering + until the next client packet (see issue #646). + """ + if interval <= 0: + return + try: + while True: + await asyncio.sleep(interval) + await ws.send_ping() + except (asyncio.CancelledError, ConnectionError, OSError): + return + + async def bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, ctx: CryptoCtx, dc=None, is_media=False, @@ -334,12 +351,15 @@ async def ws_to_tcp(): tasks = [asyncio.create_task(tcp_to_ws()), asyncio.create_task(ws_to_tcp())] + keepalive = asyncio.ensure_future( + _ws_keepalive(ws, proxy_config.ws_keepalive_interval)) try: await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) finally: + keepalive.cancel() for t in tasks: t.cancel() - for t in tasks: + for t in (*tasks, keepalive): try: await t except BaseException: diff --git a/proxy/config.py b/proxy/config.py index 28afe924..0da705fb 100644 --- a/proxy/config.py +++ b/proxy/config.py @@ -62,6 +62,7 @@ class ProxyConfig: cfproxy_worker_domains: List[str] = field(default_factory=list) fake_tls_domain: str = '' proxy_protocol: bool = False + ws_keepalive_interval: float = 30.0 proxy_config = ProxyConfig() diff --git a/proxy/raw_websocket.py b/proxy/raw_websocket.py index 30d07e9d..9e7a2483 100644 --- a/proxy/raw_websocket.py +++ b/proxy/raw_websocket.py @@ -154,6 +154,13 @@ async def send_batch(self, parts: List[bytes]): self._build_frame(self.OP_BINARY, part, mask=True)) await self.writer.drain() + async def send_ping(self, payload: bytes = b''): + if self._closed: + raise ConnectionError("WebSocket closed") + frame = self._build_frame(self.OP_PING, payload, mask=True) + self.writer.write(frame) + await self.writer.drain() + async def recv(self) -> Optional[bytes]: while not self._closed: opcode, payload = await self._read_frame() diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index daee01e5..be8e0206 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -592,6 +592,10 @@ def main(): ap.add_argument('--proxy-protocol', action='store_true', help='Accept PROXY protocol v1 header ' '(for use behind nginx/haproxy with proxy_protocol on)') + ap.add_argument('--ws-keepalive', type=float, default=30.0, metavar='SEC', + help='Seconds between WebSocket keepalive PINGs to the ' + 'upstream (default 30, 0 to disable). Keeps idle ' + 'sessions alive through NAT/firewall timeouts.') args = ap.parse_args() if not args.dc_ip: @@ -628,6 +632,7 @@ def main(): proxy_config.cfproxy_worker_domains = coerce_domain_list(args.cfproxy_worker_domain) proxy_config.fake_tls_domain = args.fake_tls_domain.strip() proxy_config.proxy_protocol = args.proxy_protocol + proxy_config.ws_keepalive_interval = max(0.0, args.ws_keepalive) log_level = logging.DEBUG if args.verbose else logging.INFO log_fmt = logging.Formatter('%(asctime)s %(levelname)-5s %(message)s',