Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 52 additions & 28 deletions asyncua/ua/ua_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import functools
import logging
import struct
import sys
import typing
import uuid
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from dataclasses import Field, fields, is_dataclass
from enum import Enum, IntFlag
from io import BytesIO
Expand Down Expand Up @@ -46,16 +47,20 @@ def set_string_encoding(new_encoding: str) -> None:
_string_encoding.set(new_encoding)


def get_safe_type_hints(cls: type, extra_ns: dict[str, Any] | None = None) -> dict[str, Any]:
# Use globalns=None so that get_type_hints automatically resolves the
# module globals of cls (e.g. bare names like Byte).
# Pass extra_ns (e.g. {'ua': ua}) as localns so ua.Xxx annotations resolve too.
# Filter out properties from the class dict to avoid shadowing.
def get_safe_type_hints(cls: type[Any], extra_ns: Mapping[str, Any] | None = None) -> dict[str, Any]:
# Resolve annotations with explicit module globals for stable behavior
# across Python versions (notably 3.10 forward-reference handling).
module = sys.modules.get(cls.__module__)
globalns = vars(module).copy() if module is not None else {}
if extra_ns:
globalns.update(extra_ns)

# Keep class-local names available and avoid property shadowing.
localns = {k: v for k, v in cls.__dict__.items() if not isinstance(v, property)}
if extra_ns:
localns.update(extra_ns)

return typing.get_type_hints(cls, globalns=None, localns=localns)
return typing.get_type_hints(cls, globalns=globalns, localns=localns)


def test_bit(data: int, offset: int) -> int:
Expand Down Expand Up @@ -327,9 +332,35 @@ def resolve_uatype(ftype: Any) -> tuple[Any, bool]:
return ftype, is_optional


def _resolve_type_in_dataclass_context(ftype: Any, dataclazz: type) -> Any:
if not isinstance(ftype, str):
return ftype

module = sys.modules.get(getattr(dataclazz, "__module__", "")) if isinstance(dataclazz, type) else None
namespace = {
"ua": ua,
"typing": typing,
"list": list,
"List": list,
"Union": typing.Union,
"Optional": typing.Optional,
"Dict": dict,
}
if module is not None:
namespace.update(vars(module))
if isinstance(dataclazz, type):
namespace.update({k: v for k, v in dataclazz.__dict__.items() if not isinstance(v, property)})

try:
return eval(ftype, namespace)
except Exception:
return ftype


def field_serializer(uatype: Any, is_optional: bool, dataclazz: type) -> Callable[[Any], bytes]:
if type_is_list(uatype):
ft = type_from_list(uatype)
ft = _resolve_type_in_dataclass_context(ft, dataclazz)
if is_optional:
return lambda val: b"" if val is None else create_list_serializer(ft, ft == dataclazz)(val)
return create_list_serializer(ft, ft == dataclazz)
Expand Down Expand Up @@ -460,21 +491,15 @@ def create_list_serializer(uatype: type, recursive: bool = False) -> Callable[[S
data_type = getattr(Primitives1, uatype.__name__)
return data_type.pack_array
none_val = Primitives.Int32.pack(-1)
if recursive:

def recursive_serialize(val: Sequence[Any] | None) -> bytes:
if val is None:
return none_val
data_size = Primitives.Int32.pack(len(val))
return data_size + b"".join(create_type_serializer(uatype)(el) for el in val)

return recursive_serialize

type_serializer = create_type_serializer(uatype)
type_serializer = None

def serialize(val: Sequence[Any] | None) -> bytes:
nonlocal type_serializer
if val is None:
return none_val
if type_serializer is None:
type_serializer = create_type_serializer(uatype)
data_size = Primitives.Int32.pack(len(val))
return data_size + b"".join(type_serializer(el) for el in val)

Expand Down Expand Up @@ -662,25 +687,23 @@ def extensionobject_to_binary(obj: Any) -> bytes:


@functools.cache
def _create_list_deserializer(uatype: type, recursive: bool = False) -> Callable[[Buffer | IO], list[Any]]:
if recursive:

def _deserialize_recursive(data: Buffer | IO) -> list[Any]:
size = Primitives.Int32.unpack(data)
return [_create_type_deserializer(uatype, type(None))(data) for _ in range(size)]

return _deserialize_recursive
element_deserializer = _create_type_deserializer(uatype, type(None))
def _create_list_deserializer(uatype: Any, recursive: bool = False) -> Callable[[Buffer | IO[Any]], list[Any]]:
# Resolve the element decoder lazily so mutually-recursive dataclass lists
# do not recurse forever during deserializer construction.
element_deserializer = None

def _deserialize(data: Buffer | IO) -> list[Any]:
def _deserialize(data: Buffer | IO[Any]) -> list[Any]:
nonlocal element_deserializer
size = Primitives.Int32.unpack(data)
if element_deserializer is None:
element_deserializer = _create_type_deserializer(uatype, type(None))
return [element_deserializer(data) for _ in range(size)]

return _deserialize


@functools.cache
def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer | IO], Any]:
def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer | IO[Any]], Any]:
uatype, is_optional = resolve_uatype(uatype)

if not is_optional and type_is_union(uatype):
Expand All @@ -689,6 +712,7 @@ def _create_type_deserializer(uatype: Any, dataclazz: type) -> Callable[[Buffer
return _create_type_deserializer(uatype, uatype)
if type_is_list(uatype):
utype = type_from_list(uatype)
utype = _resolve_type_in_dataclass_context(utype, dataclazz)
if hasattr(ua.VariantType, utype.__name__):
vtype = getattr(ua.VariantType, utype.__name__)
return _create_uatype_array_deserializer(vtype)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@
EXAMPLE_BSD_PATH = Path(__file__).parent.absolute() / "example.bsd"


@dataclass
class _MutualRecursiveChild:
Name: ua.String = ""
Parents: list["_MutualRecursiveParent"] = field(default_factory=list)


@dataclass
class _MutualRecursiveParent:
Name: ua.String = ""
Children: list[_MutualRecursiveChild] = field(default_factory=list)


def test_variant_array_none():
v = ua.Variant(None, VariantType=ua.VariantType.Int32, is_array=True)
data = variant_to_binary(v)
Expand Down Expand Up @@ -910,6 +922,23 @@ class MyStruct:
assert m == m2


def test_struct_mutual_recursive_lists_roundtrip() -> None:
root = _MutualRecursiveParent(Name="root")
child = _MutualRecursiveChild(Name="leaf")
branch = _MutualRecursiveParent(Name="branch")
root.Children.append(child)
child.Parents.append(branch)

data = struct_to_binary(root)
decoded = struct_from_binary(_MutualRecursiveParent, ua.utils.Buffer(data))

assert decoded.Name == "root"
assert len(decoded.Children) == 1
assert decoded.Children[0].Name == "leaf"
assert len(decoded.Children[0].Parents) == 1
assert decoded.Children[0].Parents[0].Name == "branch"


def test_session_security_diagnostics_roundtrip():
"""Regression test: SessionSecurityDiagnosticsDataType has a bare
'Encoding: Byte' annotation (not quoted as 'ua.Byte'). With
Expand Down
Loading