Skip to content

Commit 74ac9d3

Browse files
authored
Merge pull request #86 from gudnimg/generic-types-gudni
2 parents 625e1f0 + 42aad0b commit 74ac9d3

File tree

10 files changed

+27
-29
lines changed

10 files changed

+27
-29
lines changed

docs/_generate_requests_docstrings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import importlib.util
1414
import inspect
1515
import os
16-
from typing import Any, List, Optional, Type
16+
from typing import Any, Optional, Type
1717

1818
from pydantic import BaseModel
1919

@@ -71,11 +71,11 @@ def format_type(annotation: Type[Any] | None) -> str:
7171
if annotation is None:
7272
raise ValueError("Annotation cannot be None")
7373
if hasattr(annotation, '__name__'): # Handles regular types like `int`, `str`, etc.
74-
# get the annotations like List[str] for example
74+
# get the annotations like list[str] for example
7575
if hasattr(annotation, '__args__'):
7676
return f"{annotation.__name__}[{format_type(annotation.__args__[0])}]"
7777
return f"{annotation.__name__}"
78-
elif hasattr(annotation, '__origin__'): # Handles generic types like List[str], Optional[int]
78+
elif hasattr(annotation, '__origin__'): # Handles generic types like list[str], Optional[int]
7979
return f"{annotation.__origin__.__module__}.{annotation.__origin__.__name__}"
8080
return str(annotation) # Fallback for other types
8181

@@ -108,7 +108,7 @@ def get_pydantic_fields(cls: Type[BaseModel]) -> str:
108108
return args
109109

110110

111-
def parse_file(file_path: str) -> List[ClassInfo]:
111+
def parse_file(file_path: str) -> list[ClassInfo]:
112112
"""Parse the file and extract class definitions."""
113113
with open(file_path, 'r') as file:
114114
lines = file.readlines()

examples/usb/upgrade.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import subprocess
88
import time
99
from pathlib import Path
10-
from typing import Final, Tuple
10+
from typing import Final
1111

1212
from serial.tools.list_ports import comports
1313
from smp import error as smperr
@@ -215,7 +215,7 @@ async def ensure_request(request: SMPRequest[TRep, TEr1, TEr2]) -> TRep:
215215
raise SystemExit(f"Unknown response: {images}")
216216

217217

218-
def get_runner_command(board: str, hex_path: Path) -> Tuple[str, ...]:
218+
def get_runner_command(board: str, hex_path: Path) -> tuple[str, ...]:
219219
if "nrf" in board:
220220
print("Using the nrfjprog runner")
221221
return ("nrfjprog", "--recover", "--reset", "--verify", "--program", str(hex_path))

smpclient/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import traceback
3131
from hashlib import sha256
3232
from types import TracebackType
33-
from typing import AsyncIterator, Final, Tuple, Type
33+
from typing import AsyncIterator, Final, Type
3434

3535
from pydantic import ValidationError
3636
from smp import header as smpheader
@@ -417,7 +417,7 @@ def _cbor_integer_size(integer: int) -> int:
417417
# https://datatracker.ietf.org/doc/html/rfc8949#name-core-deterministic-encoding
418418
return 0 if integer < 24 else 1 if integer <= 0xFF else 2 if integer <= 0xFFFF else 4
419419

420-
def _get_max_cbor_and_data_size(self, request: smpmsg.WriteRequest) -> Tuple[int, int]:
420+
def _get_max_cbor_and_data_size(self, request: smpmsg.WriteRequest) -> tuple[int, int]:
421421
"""Given an `ImageUploadWrite`, return the maximum CBOR size and data size."""
422422

423423
# given empty data in the request, how many bytes are available for the data?

smpclient/mcuboot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from enum import IntEnum, IntFlag, unique
1212
from functools import cached_property
1313
from io import BufferedReader, BytesIO
14-
from typing import Annotated, Any, Dict, Final, List, Union
14+
from typing import Annotated, Any, Final, Union
1515

1616
from intelhex import hex2bin # type: ignore
1717
from pydantic import Field, GetCoreSchemaHandler
@@ -295,7 +295,7 @@ class ImageInfo:
295295

296296
header: ImageHeader
297297
tlv_info: ImageTLVInfo
298-
tlvs: List[ImageTLVValue]
298+
tlvs: list[ImageTLVValue]
299299
file: str | None = None
300300

301301
def get_tlv(self, tlv: ImageTLVType) -> ImageTLVValue:
@@ -332,15 +332,15 @@ def load_file(path: str) -> 'ImageInfo':
332332
f.seek(tlv_offset) # move to the start of the TLV area
333333
tlv_info = ImageTLVInfo.load_from(f)
334334

335-
tlvs: List[ImageTLVValue] = []
335+
tlvs: list[ImageTLVValue] = []
336336
while f.tell() < tlv_offset + tlv_info.tlv_tot:
337337
tlv_header = ImageTLV.load_from(f)
338338
tlvs.append(ImageTLVValue(header=tlv_header, value=f.read(tlv_header.len)))
339339

340340
return ImageInfo(file=path, header=image_header, tlv_info=tlv_info, tlvs=tlvs)
341341

342342
@cached_property
343-
def _map_tlv_type_to_value(self) -> Dict[int, ImageTLVValue]:
343+
def _map_tlv_type_to_value(self) -> dict[int, ImageTLVValue]:
344344
return {tlv.header.type: tlv for tlv in self.tlvs}
345345

346346
def __str__(self) -> str:

smpclient/transport/_udp_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import asyncio
66
import logging
7-
from typing import Any, Final, NamedTuple, Tuple
7+
from typing import Any, Final, NamedTuple
88

99
from typing_extensions import override
1010

@@ -79,7 +79,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
7979
logger.debug(f"Connection made, {transport=}")
8080

8181
@override
82-
def datagram_received(self, data: bytes, addr: Tuple[str | Any, int]) -> None:
82+
def datagram_received(self, data: bytes, addr: tuple[str | Any, int]) -> None:
8383
logger.debug(f"{len(data)} B datagram received from {addr}")
8484
self._receive_queue.put_nowait(data)
8585

smpclient/transport/ble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import re
88
import sys
9-
from typing import Final, List, Protocol
9+
from typing import Final, Protocol
1010
from uuid import UUID
1111

1212
from bleak import BleakClient, BleakGATTCharacteristic, BleakScanner
@@ -198,7 +198,7 @@ def mtu(self) -> int:
198198
return self._max_write_without_response_size
199199

200200
@staticmethod
201-
async def scan(timeout: int = 5) -> List[BLEDevice]:
201+
async def scan(timeout: int = 5) -> list[BLEDevice]:
202202
"""Scan for BLE devices."""
203203
logger.debug(f"Scanning for BLE devices for {timeout} seconds")
204204
devices: Final = await BleakScanner(service_uuids=[str(SMP_SERVICE_UUID)]).discover(

tests/extensions/test_intercreate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Test the Intercreate extensions."""
22

33
from pathlib import Path
4-
from typing import List
54
from unittest.mock import PropertyMock, patch
65

76
import pytest
@@ -29,7 +28,7 @@ async def test_upload_hello_world_bin_encoded(mock_mtu: PropertyMock) -> None:
2928
assert s._transport.mtu == 127
3029
assert s._transport.max_unencoded_size < 127
3130

32-
packets: List[bytes] = []
31+
packets: list[bytes] = []
3332

3433
def mock_write(data: bytes) -> int:
3534
"""Accumulate the raw packets in the global `packets`."""

tests/test_requests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Tuple, Type
5+
from typing import Type
66

77
import pytest
88
from smp import enumeration_management as smpem
@@ -296,7 +296,7 @@
296296
),
297297
)
298298
def test_requests(
299-
test_tuple: Tuple[
299+
test_tuple: tuple[
300300
smpmsg.Request,
301301
SMPRequest[TRep, TEr1, TEr2],
302302
Type[smpmsg.Response],

tests/test_smp_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import sys
66
from hashlib import sha256
77
from pathlib import Path
8-
from typing import List
98
from unittest.mock import AsyncMock, PropertyMock, call, patch
109

1110
import pytest
@@ -341,7 +340,7 @@ async def test_upload_hello_world_bin_encoded(
341340
s = SMPClient(m, "address")
342341
assert s._transport.mtu == max_smp_encoded_frame_size
343342

344-
packets: List[bytes] = []
343+
packets: list[bytes] = []
345344

346345
def mock_write(data: bytes) -> int:
347346
"""Accumulate the raw packets in the global `packets`."""
@@ -572,7 +571,7 @@ async def test_file_upload_test_encoded(max_smp_encoded_frame_size: int, line_bu
572571
s = SMPClient(m, "address")
573572
assert s._transport.mtu == max_smp_encoded_frame_size
574573

575-
packets: List[bytes] = []
574+
packets: list[bytes] = []
576575

577576
def mock_write(data: bytes) -> int:
578577
"""Accumulate the raw packets in the global `packets`."""

tests/test_udp_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Test the generic UDP client implementation."""
22

33
import asyncio
4-
from typing import List, Tuple, cast
4+
from typing import cast
55
from unittest.mock import AsyncMock, MagicMock, patch
66

77
import pytest
@@ -82,14 +82,14 @@ class _ServerProtocol(asyncio.DatagramProtocol):
8282
"""A mock SMP server protocol for unit testing."""
8383

8484
def __init__(self) -> None:
85-
self.datagrams_recieved: List[bytes] = []
85+
self.datagrams_recieved: list[bytes] = []
8686

87-
def datagram_received(self, data: bytes, addr: Tuple[str, int]) -> None:
87+
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
8888
self.datagrams_recieved.append(data)
8989

9090

9191
@pytest_asyncio.fixture
92-
async def udp_server() -> AsyncGenerator[Tuple[asyncio.DatagramTransport, _ServerProtocol], None]:
92+
async def udp_server() -> AsyncGenerator[tuple[asyncio.DatagramTransport, _ServerProtocol], None]:
9393
transport, protocol = await asyncio.get_running_loop().create_datagram_endpoint(
9494
lambda: _ServerProtocol(), local_addr=("127.0.0.1", 1337)
9595
)
@@ -100,7 +100,7 @@ async def udp_server() -> AsyncGenerator[Tuple[asyncio.DatagramTransport, _Serve
100100

101101

102102
@pytest.mark.asyncio
103-
async def test_send(udp_server: Tuple[asyncio.DatagramTransport, _ServerProtocol]) -> None:
103+
async def test_send(udp_server: tuple[asyncio.DatagramTransport, _ServerProtocol]) -> None:
104104
_, p = udp_server
105105

106106
c = UDPClient()
@@ -113,7 +113,7 @@ async def test_send(udp_server: Tuple[asyncio.DatagramTransport, _ServerProtocol
113113

114114

115115
@pytest.mark.asyncio
116-
async def test_receive(udp_server: Tuple[asyncio.DatagramTransport, _ServerProtocol]) -> None:
116+
async def test_receive(udp_server: tuple[asyncio.DatagramTransport, _ServerProtocol]) -> None:
117117
t, _ = udp_server
118118

119119
CLIENT_ADDR = Addr("127.0.0.1", 1338)

0 commit comments

Comments
 (0)