Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions journalpump/senders/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@
except ImportError:
zstd = None

KAFKA_CONN_ERRORS = tuple(errors.RETRY_ERROR_TYPES) + (
errors.UnknownError,
socket.timeout,
TimeoutError,
)

logging.getLogger("kafka").setLevel(logging.CRITICAL) # remove client-internal tracebacks from logging output


Expand Down Expand Up @@ -83,7 +77,11 @@ def _init_kafka(self) -> None:

try:
kafka_producer = KafkaProducer(**producer_config)
except KAFKA_CONN_ERRORS as ex:
except (errors.KafkaError, socket.timeout, TimeoutError) as ex:
if isinstance(ex, errors.KafkaError):
# Reraise exceptions that are fatal
if not ex.retriable:
raise
self.mark_disconnected(ex)
self.log.warning(
"Retriable error during Kafka initialization: %s: %s",
Expand Down Expand Up @@ -131,7 +129,11 @@ def send_messages(self, *, messages, cursor):
result_future.get(timeout=1)
self.mark_sent(messages=messages, cursor=cursor)
return True
except KAFKA_CONN_ERRORS as ex:
except (errors.KafkaError, socket.timeout, TimeoutError) as ex:
if isinstance(ex, errors.KafkaError):
# Reraise exceptions that are fatal
if not ex.retriable:
raise
self.mark_disconnected(ex)
self.log.info(
"Kafka retriable error during send: %s: %s, waiting",
Expand Down
89 changes: 68 additions & 21 deletions journalpump/senders/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from journalpump import __version__
from journalpump.types import StrEnum
from journalpump.util import ExponentialBackoff
from packaging.version import Version
from threading import Thread
from urllib.parse import urlparse

import asyncio
import contextlib
import enum
import logging
import random
import snappy # pylint: disable=import-error
import socket
import ssl
Expand Down Expand Up @@ -163,6 +165,8 @@ async def websocket_connect_coro(self):

sock = None
url_parsed = urlparse(self.websocket_uri)
preferred_host = None

if self.socks5_proxy:
socks_url_parsed = urlparse(self.socks5_proxy_url)
self.log.info(
Expand All @@ -176,18 +180,61 @@ async def websocket_connect_coro(self):
socks_url_parsed.hostname,
socks_url_parsed.port,
)
else:
# Resolve the hostname and pick one address at random.
# websockets.connect() and the underlying asyncio loop.create_connection() have some overlapping timeouts
# that make it difficult to work through all the addresses returned by getaddrinfo.
# We pick one at random here, and rely on our own outer loops to handle retries.
present_addrs = [
# getaddrinfo returns 5-tuple (family, type, proto, canonname, sockaddr)
# sockaddr is family dependent, but leads with address for both the IPv6 and IPv4 families
sockaddr[0]
for _, _, _, _, sockaddr in await self.websocket_loop.getaddrinfo(
url_parsed.hostname, 0, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP
)
]
if present_addrs:
preferred_host = random.choice(present_addrs)
else:
# We couldn't resolve a suitable address; fall back to the name handling in asyncio's loop.create_connection()
preferred_host = url_parsed.hostname

ws_compr = None if self.websocket_compression == WebsocketCompression.none else str(self.websocket_compression)
return await websockets.connect( # pylint:disable=no-member
self.websocket_uri,
ssl=ssl_context,
compression=ws_compr,
extra_headers=headers,
sock=sock,
server_hostname=url_parsed.hostname if self.ssl_enabled else None,
close_timeout=20,
max_size=MAX_KAFKA_MESSAGE_SIZE * 2,
)
# To support the version transition in websockets, we generate the kwargs dynamically
connect_kwargs = {
"close_timeout": 20,
"max_size": MAX_KAFKA_MESSAGE_SIZE * 2,
"ssl": ssl_context,
}

if self.websocket_compression != WebsocketCompression.none:
connect_kwargs["compression"] = str(self.websocket_compression)

if self.ssl_enabled:
connect_kwargs["server_hostname"] = url_parsed.hostname

if sock:
connect_kwargs["sock"] = sock

# Versions 13.0 and up switched to additional_headers; older versions expect extra_headers
# Versions 15.0 and up introduce a separate user_agent_header
if headers:
websockets_version = Version(websockets.__version__)
if websockets_version >= Version("13.0"):
if websockets_version >= Version("15.0"):
user_agent = headers.pop("User-Agent", None)
if user_agent:
connect_kwargs["user_agent_header"] = user_agent
if headers:
connect_kwargs["additional_headers"] = headers
else:
connect_kwargs["additional_headers"] = headers
else:
connect_kwargs["extra_headers"] = headers

if preferred_host:
connect_kwargs["host"] = preferred_host

return await websockets.connect(self.websocket_uri, **connect_kwargs)

async def websocket_connect(self, *, timeout=30):
connect_task = asyncio.create_task(self.websocket_connect_coro())
Expand Down Expand Up @@ -244,24 +291,24 @@ async def comms_channel_round(self):

for task in pending:
task.cancel()
except ConnectionRefusedError as ex:
self.log.warning("Websocket connection refused: %r. Retrying.", ex)
except (ConnectionTimeoutError, asyncio.TimeoutError, CancelledError) as ex:
self.log.warning("Websocket connection timed out: %r. Retrying.", ex)
except socket.gaierror as ex:
self.log.error(
"DNS lookup for websocket endpoint or SOCKS5 proxy failed: %r. Retrying.",
ex,
)
except websockets.exceptions.InvalidStatusCode as ex:
except ConnectionRefusedError as ex:
self.log.warning("Websocket connection refused: %r. Retrying.", ex)
except ssl.SSLCertVerificationError as ex:
self.log.error("Websocket certificate verification error: %r. Retrying.", ex)
except (ProxyError, ProxyConnectionError, ProxyTimeoutError) as ex:
self.log.warning("SOCKS5 proxy connection error: %r. Retrying.", ex)
except (ConnectionTimeoutError, asyncio.TimeoutError, CancelledError) as ex:
self.log.warning("Websocket connection timed out: %r. Retrying.", ex)
except websockets.exceptions.InvalidHandshake as ex:
self.log.error(
"Websocket server rejected connection with HTTP status code: %r. Retrying.",
"Websocket handshake failed: %r. Retrying.",
ex,
)
except (ProxyError, ProxyConnectionError, ProxyTimeoutError) as ex:
self.log.warning("SOCKS5 proxy connection error: %r. Retrying.", ex)
except ssl.SSLCertVerificationError as ex:
self.log.error("Websocket certificate verification error: %r. Retrying.", ex)
except OSError as ex: # Network unreachable, etc, may happen sporadically
self.log.warning("Websocket connection error: %r. Retrying.", ex)
except Exception as ex: # pylint:disable=broad-except
Expand Down
41 changes: 23 additions & 18 deletions test/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,14 @@


class WebsocketMockServer(threading.Thread):
def __init__(
self,
*,
port,
):
def __init__(self):
super().__init__()
self.daemon = True
self.log = logging.getLogger(self.__class__.__name__)
self.port = port
self.in_queue = deque()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.start_event = None
self.stop_event = asyncio.Event()
self.running = False
self.websocket_server = None
Expand All @@ -35,8 +31,8 @@ async def handle_incoming_websocket_message(self, *, connection):
self.log.info("WS: Received message: %r", message)
self.in_queue.append(message)

async def process_connection(self, websocket, path):
self.log.info("WS: Client connection accepted on %s", path)
async def process_connection(self, websocket):
self.log.info("WS: Client connection accepted")
pending = set()

try:
Expand Down Expand Up @@ -66,17 +62,20 @@ async def run_websocket_server(self):
# ctx.load_verify_locations(self.ca_certs)
# ctx.verify_mode = ssl.CERT_REQUIRED

self.start_event = asyncio.Event()

# websockets uses lazy_import and pylint doesn't quite get it
async with websockets.serve( # pylint:disable=no-member
self.process_connection,
"127.0.0.1",
self.port,
loop=self.loop,
None,
ssl=ctx,
close_timeout=10,
) as server:
self.websocket_server = server
self.log.info("WS: Started serving websocket connections")
self.start_event.set()
self.start_event = None
await self.stop_event.wait()

self.log.info("WS: Stopped serving websocket connections")
Expand Down Expand Up @@ -106,6 +105,14 @@ def stop(self):
task.cancel()
self.log.info("WS: stopped")

# Coroutine: return the port number of the first listening socket of the websocket server.
# While self.start_event is set (not None), wait on it before reading the socket.
# NOTE(review): if self.start_event is None and self.websocket_server has not been assigned yet,
# the attribute access below would fail -- confirm the scheduling order against the thread's run().
async def get_port_task(self):
if self.start_event:
await self.start_event.wait()
# getsockname() on the first server socket; index 1 of the returned address is taken as the port.
return self.websocket_server.sockets[0].getsockname()[1]

# Synchronous wrapper around get_port_task(): submits the coroutine to self.loop with
# asyncio.run_coroutine_threadsafe() and blocks on .result() until it completes.
# NOTE(review): presumably meant to be called from a thread other than the one running self.loop
# (calling .result() from the loop's own thread would block it) -- confirm against callers.
def get_port(self):
return asyncio.run_coroutine_threadsafe(self.get_port_task(), self.loop).result()


def assert_msgs_found(ws_server, *, messages, timeout):
# Check that all of these messages were sent to the websocket server.
Expand Down Expand Up @@ -157,16 +164,15 @@ def setup_pump(tmpdir, sender_config):

def test_producer_nobatch(caplog, tmpdir):
caplog.set_level(logging.INFO)
ws_server = WebsocketMockServer(
port=10111,
)
ws_server = WebsocketMockServer()
ws_server.start()
port = ws_server.get_port()

pump, sender = setup_pump(
tmpdir,
{
"output_type": "websocket",
"websocket_uri": "ws://127.0.0.1:10111/pump-pump",
"websocket_uri": f"ws://127.0.0.1:{port}/pump-pump",
"compression": "none",
"max_batch_size": 0,
},
Expand All @@ -186,16 +192,15 @@ def test_producer_nobatch(caplog, tmpdir):

def test_producer_batch(caplog, tmpdir):
caplog.set_level(logging.INFO)
ws_server = WebsocketMockServer(
port=10111,
)
ws_server = WebsocketMockServer()
ws_server.start()
port = ws_server.get_port()

pump, sender = setup_pump(
tmpdir,
{
"output_type": "websocket",
"websocket_uri": "ws://127.0.0.1:10111/pump-pump",
"websocket_uri": f"ws://127.0.0.1:{port}/pump-pump",
"compression": "snappy",
"max_batch_size": 1024,
},
Expand Down
Loading