Skip to content

Commit 0576bcc

Browse files
authored
Merge pull request #80 from intercreate/fix/#77/udp-mss-instead-of-mtu
Fix/#77/udp mss instead of mtu
2 parents 9089081 + 3e4c274 commit 0576bcc

4 files changed

Lines changed: 90 additions & 5 deletions

File tree

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"editor.defaultFormatter": "ms-python.black-formatter",
88
"editor.formatOnSave": true,
99
"editor.codeActionsOnSave": {
10-
"source.organizeImports": "explicit"
10+
"source.organizeImports.isort": "explicit"
1111
},
1212
},
1313
"python.testing.pytestArgs": [

envr-default

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ PYTHON_VENV=.venv
77
[ADD_TO_PATH]
88

99
[ALIASES]
10-
lint=black --check . && isort --check-only . && flake8 . && pydoclint smpclient && mypy .
10+
lint=black --check . && isort --check-only --diff . && flake8 . && pydoclint smpclient && mypy .
1111
test=coverage erase && pytest --cov --maxfail=1

smpclient/transport/udp.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import logging
5+
from socket import AF_INET6
56
from typing import Final
67

78
from smp import header as smphdr
@@ -13,22 +14,51 @@
1314

1415
logger = logging.getLogger(__name__)
1516

17+
IPV4_HEADER_SIZE: Final = 20
18+
"""Minimum IPv4 header size in bytes."""
19+
20+
IPV6_HEADER_SIZE: Final = 40
21+
"""IPv6 header size in bytes."""
22+
23+
UDP_HEADER_SIZE: Final = 8
24+
"""UDP header size in bytes."""
25+
26+
IPV4_UDP_OVERHEAD: Final = IPV4_HEADER_SIZE + UDP_HEADER_SIZE
27+
"""Total overhead (28 bytes) to subtract from MTU to get maximum UDP payload (MSS) for IPv4.
28+
29+
Per RFC 8085 section 3.2, applications must subtract IP and UDP header sizes from the
30+
PMTU to avoid fragmentation."""
31+
32+
IPV6_UDP_OVERHEAD: Final = IPV6_HEADER_SIZE + UDP_HEADER_SIZE
33+
"""Total overhead (48 bytes) to subtract from MTU to get maximum UDP payload (MSS) for IPv6.
34+
35+
Per RFC 8085 section 3.2, applications must subtract IP and UDP header sizes from the
36+
PMTU to avoid fragmentation."""
37+
1638

1739
class SMPUDPTransport(SMPTransport):
1840
def __init__(self, mtu: int = 1500) -> None:
1941
"""Initialize the SMP UDP transport.
2042
2143
Args:
22-
mtu: The Maximum Transmission Unit (MTU) in 8-bit bytes.
44+
mtu: The Maximum Transmission Unit (MTU) of the link layer in bytes.
45+
IP and UDP header overhead will be subtracted to calculate the maximum
46+
UDP payload size (MSS) to avoid fragmentation per RFC 8085 section 3.2.
2347
"""
2448
self._mtu = mtu
49+
self._is_ipv6 = False
2550

2651
self._client: Final = UDPClient()
2752

2853
@override
2954
async def connect(self, address: str, timeout_s: float, port: int = 1337) -> None:
3055
logger.debug(f"Connecting to {address=} {port=}")
3156
await asyncio.wait_for(self._client.connect(Addr(host=address, port=port)), timeout_s)
57+
58+
if sock := self._client._transport.get_extra_info('socket'):
59+
self._is_ipv6 = sock.family == AF_INET6
60+
logger.debug(f"Detected {'IPv6' if self._is_ipv6 else 'IPv4'} connection")
61+
3262
logger.info(f"Connected to {address=} {port=}")
3363

3464
@override
@@ -104,4 +134,10 @@ def mtu(self) -> int:
104134
@override
105135
@property
106136
def max_unencoded_size(self) -> int:
107-
return self._mtu
137+
"""Maximum UDP payload size (MSS) to avoid fragmentation.
138+
139+
Subtracts IPv4/IPv6 and UDP header overhead from MTU per RFC 8085 section 3.2.
140+
The IP version is auto-detected after connection.
141+
"""
142+
overhead = IPV6_UDP_OVERHEAD if self._is_ipv6 else IPV4_UDP_OVERHEAD
143+
return self._mtu - overhead

tests/test_smp_udp_transport.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from smpclient.exceptions import SMPClientException
1010
from smpclient.requests.os_management import EchoWrite
1111
from smpclient.transport._udp_client import Addr, UDPClient
12-
from smpclient.transport.udp import SMPUDPTransport
12+
from smpclient.transport.udp import IPV4_UDP_OVERHEAD, IPV6_UDP_OVERHEAD, SMPUDPTransport
1313

1414

1515
def test_init() -> None:
@@ -27,6 +27,10 @@ async def test_connect(_: MagicMock) -> None:
2727
t = SMPUDPTransport()
2828
t._client = cast(MagicMock, t._client) # type: ignore
2929

30+
# Mock _transport for IPv4/IPv6 detection
31+
t._client._transport = MagicMock()
32+
t._client._transport.get_extra_info.return_value = None
33+
3034
await t.connect("192.168.0.1", 0.001)
3135
t._client.connect.assert_awaited_once_with(Addr(host="192.168.0.1", port=1337))
3236

@@ -110,3 +114,48 @@ async def test_send_and_receive() -> None:
110114
await t.send_and_receive(message)
111115
send_mock.assert_awaited_once_with(message)
112116
receive_mock.assert_awaited_once()
117+
118+
119+
def test_max_unencoded_size_ipv4() -> None:
120+
"""Test MSS calculation for IPv4 (default)."""
121+
t = SMPUDPTransport(mtu=1500)
122+
# Before connection, defaults to IPv4
123+
assert t.max_unencoded_size == 1500 - IPV4_UDP_OVERHEAD
124+
assert t.max_unencoded_size == 1472
125+
126+
127+
def test_max_unencoded_size_custom_mtu() -> None:
128+
"""Test MSS calculation with custom MTU."""
129+
t = SMPUDPTransport(mtu=512)
130+
assert t.max_unencoded_size == 512 - IPV4_UDP_OVERHEAD
131+
assert t.max_unencoded_size == 484
132+
133+
134+
@pytest.mark.asyncio
135+
async def test_ipv4_detection_real_socket() -> None:
136+
"""Test IPv4 auto-detection with real socket connection."""
137+
t = SMPUDPTransport(mtu=1500)
138+
139+
# Create a real UDP connection to localhost IPv4
140+
await t.connect("127.0.0.1", 1.0)
141+
142+
assert t._is_ipv6 is False
143+
assert t.max_unencoded_size == 1500 - IPV4_UDP_OVERHEAD
144+
assert t.max_unencoded_size == 1472
145+
146+
await t.disconnect()
147+
148+
149+
@pytest.mark.asyncio
150+
async def test_ipv6_detection_real_socket() -> None:
151+
"""Test IPv6 auto-detection with real socket connection."""
152+
t = SMPUDPTransport(mtu=1500)
153+
154+
# Create a real UDP connection to localhost IPv6
155+
await t.connect("::1", 1.0)
156+
157+
assert t._is_ipv6 is True
158+
assert t.max_unencoded_size == 1500 - IPV6_UDP_OVERHEAD
159+
assert t.max_unencoded_size == 1452
160+
161+
await t.disconnect()

0 commit comments

Comments
 (0)