Skip to content

Commit 523d358

Browse files
committed
wip: use mcumgr params for serial transports
1 parent 0576bcc commit 523d358

File tree

4 files changed

+175
-42
lines changed

4 files changed

+175
-42
lines changed

examples/usb/upgrade.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,10 @@ async def main() -> None:
108108
print("Connecting to SMP DUT...", end="", flush=True)
109109
async with SMPClient(
110110
SMPSerialTransport(
111-
max_smp_encoded_frame_size=max_smp_encoded_frame_size,
112-
line_length=line_length,
113-
line_buffers=line_buffers,
111+
fragmentation_strategy=SMPSerialTransport.BufferParams(
112+
line_length=line_length,
113+
line_buffers=line_buffers,
114+
)
114115
),
115116
port_a.device,
116117
) as client:
@@ -187,9 +188,10 @@ async def ensure_request(request: SMPRequest[TRep, TEr1, TEr2]) -> TRep:
187188
print("Connecting to B SMP DUT...", end="", flush=True)
188189
async with SMPClient(
189190
SMPSerialTransport(
190-
max_smp_encoded_frame_size=max_smp_encoded_frame_size,
191-
line_length=line_length,
192-
line_buffers=line_buffers,
191+
fragmentation_strategy=SMPSerialTransport.BufferParams(
192+
line_length=line_length,
193+
line_buffers=line_buffers,
194+
)
193195
),
194196
port_b.device,
195197
) as client:

smpclient/transport/serial.py

Lines changed: 84 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
import math
1111
import time
1212
from enum import IntEnum, unique
13-
from functools import cached_property
14-
from typing import Final
13+
from typing import Final, NamedTuple
1514

1615
from serial import Serial, SerialException
1716
from smp import packet as smppacket
@@ -69,11 +68,29 @@ def __init__(self) -> None:
6968
self.state = SMPSerialTransport._ReadBuffer.State.SER
7069
"""The state of the read buffer."""
7170

71+
class Auto(NamedTuple):
72+
"""Automatically determine buffer parameters from the SMP server.
73+
74+
On connect, queries the server's MCUMGR_PARAM for `buf_size`
75+
(CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE) and calculates:
76+
- line_length: 127 (standard MTU for uart/usb/shell)
77+
- line_buffers: buf_size / line_length
78+
79+
Falls back to BufferParams() if server doesn't support MCUMGR_PARAM.
80+
"""
81+
82+
class BufferParams(NamedTuple):
83+
"""Buffer parameters for the serial transport."""
84+
85+
line_length: int = 127
86+
"""The maximum SMP packet size."""
87+
88+
line_buffers: int = 1
89+
"""The number of line buffers in the serial buffer."""
90+
7291
def __init__( # noqa: DOC301
7392
self,
74-
max_smp_encoded_frame_size: int = 256,
75-
line_length: int = 128,
76-
line_buffers: int = 2,
93+
fragmentation_strategy: Auto | BufferParams = Auto(),
7794
baudrate: int = 115200,
7895
bytesize: int = 8,
7996
parity: str = "N",
@@ -89,11 +106,10 @@ def __init__( # noqa: DOC301
89106
"""Initialize the serial transport.
90107
91108
Args:
92-
max_smp_encoded_frame_size: The maximum size of an encoded SMP
93-
frame. The SMP server needs to have a buffer large enough to
94-
receive the encoded frame packets and to store the decoded frame.
95-
line_length: The maximum SMP packet size.
96-
line_buffers: The number of line buffers in the serial buffer.
109+
fragmentation_strategy: The fragmentation strategy to use. Either
110+
`SMPSerialTransport.Auto()` to automatically determine buffer
111+
parameters from the SMP server, or `SMPSerialTransport.BufferParams`
112+
to manually specify buffer parameters.
97113
baudrate: The baudrate of the serial connection. OK to ignore for
98114
USB CDC ACM.
99115
bytesize: The number of data bits.
@@ -108,18 +124,7 @@ def __init__( # noqa: DOC301
108124
exclusive: The exclusive access timeout.
109125
110126
"""
111-
if max_smp_encoded_frame_size < line_length * line_buffers:
112-
logger.error(
113-
f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!"
114-
)
115-
elif max_smp_encoded_frame_size != line_length * line_buffers:
116-
logger.warning(
117-
f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!"
118-
)
119-
120-
self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size
121-
self._line_length: Final = line_length
122-
self._line_buffers: Final = line_buffers
127+
self._fragmentation_strategy: Final = fragmentation_strategy
123128
self._conn: Final = Serial(
124129
baudrate=baudrate,
125130
bytesize=bytesize,
@@ -136,6 +141,62 @@ def __init__( # noqa: DOC301
136141
self._buffer = SMPSerialTransport._ReadBuffer()
137142
logger.debug(f"Initialized {self.__class__.__name__}")
138143

144+
@property
145+
def _line_length(self) -> int:
146+
"""The maximum SMP packet size."""
147+
if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto):
148+
return self.BufferParams().line_length
149+
else:
150+
return self._fragmentation_strategy.line_length
151+
152+
@property
153+
def _line_buffers(self) -> int:
154+
"""The number of line buffers."""
155+
if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto):
156+
if self._smp_server_transport_buffer_size is not None:
157+
return self._smp_server_transport_buffer_size // self.BufferParams().line_length
158+
return self.BufferParams().line_buffers
159+
else:
160+
return self._fragmentation_strategy.line_buffers
161+
162+
@property
163+
def _max_smp_encoded_frame_size(self) -> int:
164+
"""The maximum encoded frame size (line_length * line_buffers)."""
165+
if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto):
166+
if self._smp_server_transport_buffer_size is not None:
167+
return self._smp_server_transport_buffer_size
168+
return self._line_length * self._line_buffers
169+
else:
170+
return (
171+
self._fragmentation_strategy.line_length * self._fragmentation_strategy.line_buffers
172+
)
173+
174+
@override
175+
def initialize(self, smp_server_transport_buffer_size: int) -> None:
176+
"""Initialize with the server's buffer size from MCUMGR_PARAM.
177+
178+
Args:
179+
smp_server_transport_buffer_size: The server's CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE
180+
"""
181+
super().initialize(smp_server_transport_buffer_size)
182+
183+
if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto):
184+
logger.info(
185+
f"Auto-configured from server: {self._line_length=}, "
186+
f"{self._line_buffers=}, mtu={self._max_smp_encoded_frame_size}"
187+
)
188+
else:
189+
# Validate user's BufferParams against server capabilities
190+
calculated_size = (
191+
self._fragmentation_strategy.line_length * self._fragmentation_strategy.line_buffers
192+
)
193+
if calculated_size > smp_server_transport_buffer_size:
194+
logger.warning(
195+
f"BufferParams (line_length={self._fragmentation_strategy.line_length} * "
196+
f"line_buffers={self._fragmentation_strategy.line_buffers} = {calculated_size}) " # noqa: E501
197+
f"exceeds server buffer size ({smp_server_transport_buffer_size})"
198+
)
199+
139200
@override
140201
async def connect(self, address: str, timeout_s: float) -> None:
141202
self._conn.port = address
@@ -309,7 +370,7 @@ def mtu(self) -> int:
309370
return self._max_smp_encoded_frame_size
310371

311372
@override
312-
@cached_property
373+
@property
313374
def max_unencoded_size(self) -> int:
314375
"""The serial transport encodes each packet instead of sending SMP messages as raw bytes."""
315376

tests/test_smp_client.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,14 @@ def aiter(iterable: Any) -> Any:
6363
class SMPMockTransport:
6464
"""Satisfies the `SMPTransport` `Protocol`."""
6565

66+
mtu = PropertyMock()
67+
max_unencoded_size = PropertyMock()
68+
6669
def __init__(self) -> None:
6770
self.connect = AsyncMock()
6871
self.disconnect = AsyncMock()
6972
self.send = AsyncMock()
7073
self.receive = AsyncMock()
71-
self.mtu = PropertyMock()
72-
self.max_unencoded_size = PropertyMock()
7374
self._smp_server_transport_buffer_size: int | None = None
7475
self.initialize = AsyncMock()
7576

@@ -334,12 +335,16 @@ async def test_upload_hello_world_bin_encoded(
334335
pytest.skip("The line buffer size is too small")
335336

336337
m = SMPSerialTransport(
337-
max_smp_encoded_frame_size=max_smp_encoded_frame_size,
338-
line_length=line_length,
339-
line_buffers=line_buffers,
338+
fragmentation_strategy=SMPSerialTransport.BufferParams(
339+
line_length=line_length,
340+
line_buffers=line_buffers,
341+
)
340342
)
341343
s = SMPClient(m, "address")
342-
assert s._transport.mtu == max_smp_encoded_frame_size
344+
# MTU is line_length * line_buffers, which may be <= max_smp_encoded_frame_size
345+
# due to integer division
346+
assert s._transport.mtu == line_length * line_buffers
347+
assert s._transport.mtu <= max_smp_encoded_frame_size
343348

344349
packets: List[bytes] = []
345350

@@ -565,12 +570,16 @@ async def test_file_upload_test_encoded(max_smp_encoded_frame_size: int, line_bu
565570
pytest.skip("The line buffer size is too small")
566571

567572
m = SMPSerialTransport(
568-
max_smp_encoded_frame_size=max_smp_encoded_frame_size,
569-
line_length=line_length,
570-
line_buffers=line_buffers,
573+
fragmentation_strategy=SMPSerialTransport.BufferParams(
574+
line_length=line_length,
575+
line_buffers=line_buffers,
576+
)
571577
)
572578
s = SMPClient(m, "address")
573-
assert s._transport.mtu == max_smp_encoded_frame_size
579+
# MTU is line_length * line_buffers, which may be <= max_smp_encoded_frame_size
580+
# due to integer division
581+
assert s._transport.mtu == line_length * line_buffers
582+
assert s._transport.mtu <= max_smp_encoded_frame_size
574583

575584
packets: List[bytes] = []
576585

tests/test_smp_serial_transport.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,23 @@
1414

1515

1616
def test_constructor() -> None:
17+
# Test with Auto() (default)
1718
t = SMPSerialTransport()
1819
assert isinstance(t._conn, Serial)
19-
20-
t = SMPSerialTransport(max_smp_encoded_frame_size=512, line_length=128, line_buffers=4)
20+
assert t.mtu == 127 # Default for Auto without initialize
21+
assert t._line_length == 127
22+
assert t._line_buffers == 1
23+
assert t._max_smp_encoded_frame_size == 127
24+
25+
# Test with BufferParams
26+
t = SMPSerialTransport(
27+
fragmentation_strategy=SMPSerialTransport.BufferParams(line_length=128, line_buffers=4)
28+
)
2129
assert isinstance(t._conn, Serial)
22-
assert t.mtu == 512
30+
assert t.mtu == 512 # 128 * 4
31+
assert t._line_length == 128
32+
assert t._line_buffers == 4
33+
assert t._max_smp_encoded_frame_size == 512
2334
assert t.max_unencoded_size < 512
2435

2536

@@ -174,3 +185,53 @@ async def test_send_and_receive() -> None:
174185

175186
t.send.assert_awaited_once_with(b"some data")
176187
t.receive.assert_awaited_once_with()
188+
189+
190+
def test_initialize_with_auto() -> None:
191+
"""Test that Auto mode updates parameters based on server's buffer size."""
192+
t = SMPSerialTransport() # Uses Auto() by default
193+
194+
# Before initialize, uses conservative defaults
195+
assert t._line_length == 127
196+
assert t._line_buffers == 1
197+
assert t._max_smp_encoded_frame_size == 127
198+
199+
# After initialize with server buffer size
200+
t.initialize(512)
201+
assert t._line_length == 127
202+
assert t._line_buffers == 512 // 127 # 4
203+
assert t._max_smp_encoded_frame_size == 512
204+
assert t.mtu == 512
205+
206+
207+
def test_initialize_with_buffer_params() -> None:
208+
"""Test that BufferParams mode doesn't change user-specified parameters."""
209+
t = SMPSerialTransport(
210+
fragmentation_strategy=SMPSerialTransport.BufferParams(line_length=128, line_buffers=2)
211+
)
212+
213+
# Before initialize
214+
assert t._line_length == 128
215+
assert t._line_buffers == 2
216+
assert t._max_smp_encoded_frame_size == 256 # 128 * 2
217+
218+
# After initialize - parameters should NOT change
219+
t.initialize(512)
220+
assert t._line_length == 128
221+
assert t._line_buffers == 2
222+
assert t._max_smp_encoded_frame_size == 256
223+
assert t.mtu == 256
224+
225+
226+
def test_initialize_with_buffer_params_warning(caplog: pytest.LogCaptureFixture) -> None:
227+
"""Test that a warning is logged when user's params exceed server buffer size."""
228+
t = SMPSerialTransport(
229+
fragmentation_strategy=SMPSerialTransport.BufferParams(
230+
line_length=128, line_buffers=4 # 128 * 4 = 512
231+
)
232+
)
233+
234+
with caplog.at_level(logging.WARNING):
235+
t.initialize(256) # Server buffer (256) is smaller than calculated size (512)
236+
237+
assert any("exceeds server buffer size" in record.message for record in caplog.records)

0 commit comments

Comments
 (0)