diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8ee27b5..dd42693 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest pycryptodome + pip install pytest pytest-asyncio pycryptodome - name: Run tests run: PYTHONPATH=. pytest \ No newline at end of file diff --git a/sfs2x/transport/__init__.py b/sfs2x/transport/__init__.py new file mode 100644 index 0000000..286d51e --- /dev/null +++ b/sfs2x/transport/__init__.py @@ -0,0 +1,12 @@ +from sfs2x.transport.base import Acceptor, Transport # noqa: I001 +from sfs2x.transport.tcp import TCPAcceptor, TCPTransport +from sfs2x.transport.factory import client_from_url, server_from_url + +__all__ = [ + "Acceptor", + "TCPAcceptor", + "TCPTransport", + "Transport", + "client_from_url", + "server_from_url", +] diff --git a/sfs2x/transport/base.py b/sfs2x/transport/base.py new file mode 100644 index 0000000..f7e5e41 --- /dev/null +++ b/sfs2x/transport/base.py @@ -0,0 +1,68 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import Protocol + +from sfs2x.core import Buffer +from sfs2x.protocol import Message, decode, encode + + +class Transport(ABC): + """Abstract base class for transports.""" + + _closed: bool + + def __init__(self) -> None: + self._closed = True + + async def open(self) -> "Transport": + await self._open() + self._closed = False + return self + + async def send(self, msg: Message) -> None: + if self._closed: + err_msg = "Connection closed by remote host" + raise ConnectionError(err_msg) + await self._send_raw(encode(msg)) + + async def recv(self) -> Message: + if self._closed: + msg = "Connection closed by remote host" + raise ConnectionError(msg) + raw = await self._recv_raw() + return decode(Buffer(raw)) + + async def close(self) -> None: + if not self._closed: + await self._close_impl() + self._closed = True + + @abstractmethod + async def _open(self) -> None: + ... + + @abstractmethod + async def _send_raw(self, raw: bytes) -> None: + ... + + @abstractmethod + async def _recv_raw(self) -> bytes: + ... + + @abstractmethod + async def _close_impl(self) -> None: + ... + + @abstractmethod + def host(self) -> str: + ... + + @abstractmethod + def port(self) -> int: + ... + + +class Acceptor(Protocol): + """Async listener for server.""" + + async def __aiter__(self) -> AsyncIterator[Transport]: ... # noqa: D105 diff --git a/sfs2x/transport/factory.py b/sfs2x/transport/factory.py new file mode 100644 index 0000000..66493df --- /dev/null +++ b/sfs2x/transport/factory.py @@ -0,0 +1,37 @@ +from urllib.parse import urlparse + +from sfs2x.transport import Acceptor, TCPAcceptor, TCPTransport, Transport + + +def client_from_url(url: str) -> Transport: + """ + Create transport from url. + + * ``tcp://host:port`` + * ``ws://host:port/path`` + * ``http://host:port/path + """ + u = urlparse(url) + scheme = (u.scheme or "tcp").lower() + + if scheme == "tcp": + port = u.port or 9933 + return TCPTransport(u.hostname or "localhost", port) + raise NotImplementedError + + +def server_from_url(url: str) -> TCPAcceptor | Acceptor: + """ + Create acceptor from url. + + * ``tcp://host:port`` + * ``ws://host:port/path`` + * ``http://host:port/path + """ + u = urlparse(url) + scheme = u.scheme.lower() + + if scheme == "tcp": + port = u.port or 9933 + return TCPAcceptor(u.hostname or "localhost", port) + raise NotImplementedError diff --git a/sfs2x/transport/tcp.py b/sfs2x/transport/tcp.py new file mode 100644 index 0000000..1e0986f --- /dev/null +++ b/sfs2x/transport/tcp.py @@ -0,0 +1,113 @@ +import asyncio +import logging +from asyncio import AbstractServer, IncompleteReadError, StreamReader, StreamWriter, get_running_loop, start_server +from collections.abc import AsyncIterator + +from sfs2x.protocol import Flag +from sfs2x.transport import Acceptor, Transport + +logger = logging.getLogger("SFS2X/TCPTransport") + + +class TCPTransport(Transport): + """SmartFox Transport realisation with Async Streams.""" + + def __init__(self, host: str, port: int) -> None: + super().__init__() + self._host = host + self._port = port + self._reader: StreamReader | None = None + self._writer: StreamWriter | None = None + + @property + def host(self) -> str: + return self._host + + @property + def port(self) -> int: + return self._port + + async def _open(self) -> None: + self._reader, self._writer = await asyncio.open_connection(self._host, self._port) + logger.info("Opened connection to %s:%s", self._host, self._port) + + async def _send_raw(self, raw: bytes) -> None: + if not self._writer: + msg = "Connection closed by remote host" + raise ConnectionError(msg) + + self._writer.write(raw) + await self._writer.drain() + logger.info("Sent %s bytes", {len(raw)}) + + async def _recv_raw(self) -> bytes: + if not self._reader: + msg = "Connection closed by remote host" + raise ConnectionError(msg) + + try: + _flags = await self._reader.readexactly(1) + flags = Flag(_flags[0]) + if not flags & Flag.BINARY: + msg = "Invalid packet type" + raise RuntimeWarning(msg) + + len_bytes = await self._reader.readexactly(2) + if flags & Flag.BIG_SIZE: + len_bytes += await self._reader.readexactly(2) + + length = int.from_bytes(len_bytes, byteorder="big", signed=False) + body = await self._reader.readexactly(length) + except IncompleteReadError as e: + msg = "Connection closed by remote host" + raise ConnectionError(msg) from e + + + logger.info("Received %s bytes from %s:%s", length, self._host, self._port) + + return _flags + len_bytes + body + + async def _close_impl(self) -> None: + if self._writer: + self._writer.close() + await self._writer.wait_closed() + logger.info("Closed connection to %s:%s", self._host, self._port) + + +class TCPAcceptor(Acceptor): + """Server-Side implementation of the TCP Acceptor.""" + + def __init__(self, host: str, port: int) -> None: + super().__init__() + self._host = host + self._port = port + self._server: AbstractServer | None = None + + async def __aiter__(self) -> AsyncIterator[Transport]: # type: ignore # noqa: PGH003 + """Iterate all new connections.""" + loop = get_running_loop() + self._server = await start_server(self._on_conn, self._host, self._port) + logger.info("Started server on %s:%s", self._host, self._port) + + self._queue: asyncio.Queue[TCPTransport] = asyncio.Queue() + + async def producer() -> None: + async with self._server: # type: ignore # noqa: PGH003 + await self._server.serve_forever() # type: ignore # noqa: PGH003 + + loop.create_task(producer()) # noqa: RUF006 + + try: + while True: + yield await self._queue.get() + finally: + self._server.close() + + async def _on_conn(self, reader: StreamReader, writer: StreamWriter) -> None: + host, port = writer.get_extra_info("peername") + logger.info("Connection from %s:%s", host, port) + transport = TCPTransport(host, port) + transport._reader = reader # noqa: SLF001 + transport._writer = writer # noqa: SLF001 + transport._closed = False # noqa: SLF001 + await self._queue.put(transport) diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 0000000..d791c08 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,73 @@ +import asyncio +import pytest +import pytest_asyncio + +from sfs2x.core import Float, UtfString, Int, Double +from sfs2x.transport import client_from_url, server_from_url, TCPTransport +from sfs2x.protocol import Message, ControllerID, SysAction +from sfs2x.core.types.containers import SFSObject + +@pytest_asyncio.fixture +async def echo_server(event_loop): + server_task = event_loop.create_task(run_echo_server()) + await asyncio.sleep(0.2) + + yield + + server_task.cancel() + with pytest.raises(asyncio.CancelledError): + await server_task + +async def run_echo_server(): + async for conn in server_from_url("tcp://0.0.0.0:9000"): + asyncio.create_task(some_handler(conn)) + +async def some_handler(conn: TCPTransport): + try: + while True: + msg = await conn.recv() + obj = msg.payload.value.get('input') + obj.value *= 2 + msg.payload['resp'] = obj + await conn.send(msg) + except ConnectionError: + await conn.close() + +@pytest.mark.asyncio +async def test_tcp_echo_roundtrip(echo_server): + conn = await client_from_url("tcp://localhost:9000").open() + for value in [UtfString('Friday, '), Int(8), Double(123.12)]: + test_msg = Message(ControllerID.SYSTEM, SysAction.PING_PONG, SFSObject({'input': value})) + await conn.send(test_msg) + + answer = await conn.recv() + assert answer.controller == test_msg.controller + assert answer.action == test_msg.action + assert answer.payload.get('resp') == value.value * 2 + await conn.close() + +@pytest.mark.asyncio +async def test_msm_server(): + conn = await client_from_url("tcp://107.20.67.227").open() + + session_info = SFSObject() + session_info.put_utf_string("api", "1.0.3") + session_info.put_utf_string("cl", "UnityPlayer::") + session_info.put_bool("bin", True) + + await conn.send(Message(ControllerID.SYSTEM, SysAction.HANDSHAKE, session_info)) + + handshake = await conn.recv() + assert handshake.controller == ControllerID.SYSTEM + assert handshake.action == SysAction.HANDSHAKE + + auth_info = SFSObject() + auth_info.put_utf_string("zn", "MySingingPenis") + auth_info.put_utf_string("un", "") + auth_info.put_utf_string("pw", "") + auth_info.put_sfs_object("p", SFSObject()) + + await conn.send(Message(ControllerID.SYSTEM, SysAction.LOGIN, auth_info)) + + resp = await conn.recv() + assert resp.payload['ec'] == 1 \ No newline at end of file