diff --git a/proxy/bridge.py b/proxy/bridge.py index b038f543..3a862baa 100644 --- a/proxy/bridge.py +++ b/proxy/bridge.py @@ -263,6 +263,23 @@ async def _tcp_fallback(reader, writer, dst, port, relay_init, label, ctx: Crypt return True +async def _ws_keepalive(ws, interval: float): + """Send periodic WS PING frames to keep the upstream flow warm. + + A non-positive interval disables keepalive. The loop exits on send + failure so a dead upstream is detected promptly instead of lingering + until the next client packet (see issue #646). + """ + if interval <= 0: + return + try: + while True: + await asyncio.sleep(interval) + await ws.send_ping() + except (asyncio.CancelledError, ConnectionError, OSError): + return + + async def bridge_ws_reencrypt(reader, writer, ws: RawWebSocket, label, ctx: CryptoCtx, dc=None, is_media=False, @@ -334,12 +351,15 @@ async def ws_to_tcp(): tasks = [asyncio.create_task(tcp_to_ws()), asyncio.create_task(ws_to_tcp())] + keepalive = asyncio.ensure_future( + _ws_keepalive(ws, proxy_config.ws_keepalive_interval)) try: await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) finally: + keepalive.cancel() for t in tasks: t.cancel() - for t in tasks: + for t in (*tasks, keepalive): try: await t except BaseException: diff --git a/proxy/config.py b/proxy/config.py index 28afe924..0da705fb 100644 --- a/proxy/config.py +++ b/proxy/config.py @@ -62,6 +62,7 @@ class ProxyConfig: cfproxy_worker_domains: List[str] = field(default_factory=list) fake_tls_domain: str = '' proxy_protocol: bool = False + ws_keepalive_interval: float = 30.0 proxy_config = ProxyConfig() diff --git a/proxy/raw_websocket.py b/proxy/raw_websocket.py index 30d07e9d..9e7a2483 100644 --- a/proxy/raw_websocket.py +++ b/proxy/raw_websocket.py @@ -154,6 +154,13 @@ async def send_batch(self, parts: List[bytes]): self._build_frame(self.OP_BINARY, part, mask=True)) await self.writer.drain() + async def send_ping(self, payload: bytes = b''): + if self._closed: + raise ConnectionError("WebSocket closed") + frame = self._build_frame(self.OP_PING, payload, mask=True) + self.writer.write(frame) + await self.writer.drain() + async def recv(self) -> Optional[bytes]: while not self._closed: opcode, payload = await self._read_frame() diff --git a/proxy/tg_ws_proxy.py b/proxy/tg_ws_proxy.py index daee01e5..be8e0206 100644 --- a/proxy/tg_ws_proxy.py +++ b/proxy/tg_ws_proxy.py @@ -592,6 +592,10 @@ def main(): ap.add_argument('--proxy-protocol', action='store_true', help='Accept PROXY protocol v1 header ' '(for use behind nginx/haproxy with proxy_protocol on)') + ap.add_argument('--ws-keepalive', type=float, default=30.0, metavar='SEC', + help='Seconds between WebSocket keepalive PINGs to the ' + 'upstream (default 30, 0 to disable). Keeps idle ' + 'sessions alive through NAT/firewall timeouts.') args = ap.parse_args() if not args.dc_ip: @@ -628,6 +632,7 @@ def main(): proxy_config.cfproxy_worker_domains = coerce_domain_list(args.cfproxy_worker_domain) proxy_config.fake_tls_domain = args.fake_tls_domain.strip() proxy_config.proxy_protocol = args.proxy_protocol + proxy_config.ws_keepalive_interval = max(0.0, args.ws_keepalive) log_level = logging.DEBUG if args.verbose else logging.INFO log_fmt = logging.Formatter('%(asctime)s %(levelname)-5s %(message)s',