Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions contrib/starrocks-python-client/starrocks/engine/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def mv_name(self, value: str) -> None:


@add_cached_str_clause
@dataclasses.dataclass
@dataclasses.dataclass(unsafe_hash=True)
class ReflectedRefreshInfo:
"""Stores structured reflection information about a materialized view's refresh scheme."""
moment: Optional[str] = None
Expand Down Expand Up @@ -217,7 +217,7 @@ class ReflectedCKInfo(TypedDict):


@add_cached_str_clause
@dataclasses.dataclass(**dict(kw_only=True) if 'KW_ONLY' in dataclasses.__all__ else {})
@dataclasses.dataclass(**(dict(kw_only=True) if 'KW_ONLY' in dataclasses.__all__ else {}))
class ReflectedTableKeyInfo:
"""
Stores structed reflection information about a table' key/type.
Expand All @@ -230,20 +230,23 @@ class ReflectedTableKeyInfo:
type: str
columns: Optional[Union[List[str], str]]

def __hash__(self) -> int:
cols = tuple(self.columns) if isinstance(self.columns, list) else self.columns
return hash((self.type, cols))

def __str__(self) -> str:
self.type = self.type.upper() if self.type else self.type
if self.columns:
self.columns = self.columns.strip()
if isinstance(self.columns, list):
return f"{self.type} ({', '.join(self.columns)})"
return f"{self.type} ({self.columns})"
type_str = self.type.upper() if self.type else self.type
columns_str = self.columns.strip() if isinstance(self.columns, str) else self.columns
if isinstance(columns_str, list):
return f"{type_str} ({', '.join(columns_str)})"
return f"{type_str} ({columns_str})"

def __repr__(self) -> str:
return repr(str(self))


@add_cached_str_clause
@dataclasses.dataclass(**dict(kw_only=True) if 'KW_ONLY' in dataclasses.__all__ else {})
@dataclasses.dataclass(unsafe_hash=True, **(dict(kw_only=True) if 'KW_ONLY' in dataclasses.__all__ else {}))
class ReflectedPartitionInfo:
"""
Stores structured reflection information about a table's partitioning scheme.
Expand All @@ -261,18 +264,17 @@ class ReflectedPartitionInfo:
pre_created_partitions: Optional[str] = None

def __str__(self) -> str:
self.type = self.type.upper() if self.type else self.type
self.partition_method = self.partition_method.strip() if self.partition_method else self.partition_method
method_str = self.partition_method.strip() if self.partition_method else self.partition_method
if self.pre_created_partitions:
return f"{self.partition_method} {self.pre_created_partitions}"
return f"{self.partition_method}"
return f"{method_str} {self.pre_created_partitions}"
return f"{method_str}"

def __repr__(self) -> str:
return repr(str(self))


@add_cached_str_clause
@dataclasses.dataclass(**dict(kw_only=True) if 'KW_ONLY' in dataclasses.__all__ else {})
@dataclasses.dataclass(**(dict(kw_only=True) if 'KW_ONLY' in dataclasses.__all__ else {}))
class ReflectedDistributionInfo:
"""Stores reflection information about a view."""
type: Union[str, None]
Expand All @@ -285,14 +287,19 @@ class ReflectedDistributionInfo:
buckets: Union[int, None]
"""The buckets count."""

def __hash__(self) -> int:
cols = tuple(self.columns) if isinstance(self.columns, list) else self.columns
return hash((self.type, cols, self.distribution_method, self.buckets))

def __str__(self) -> str:
"""Convert to string representation of distribution option."""
buckets_str = f' BUCKETS {self.buckets}' if self.buckets and str(self.buckets) != "0" else ""
if not self.distribution_method:
method = self.distribution_method
if not method:
distribution_cols = ', '.join(self.columns) if isinstance(self.columns, list) else self.columns
distribution_cols_str = f'({distribution_cols})' if distribution_cols else ""
self.distribution_method = f'{self.type}{distribution_cols_str}'
return f'{self.distribution_method}{buckets_str}'
method = f'{self.type}{distribution_cols_str}'
return f'{method}{buckets_str}'

def __repr__(self) -> str:
return repr(str(self))
130 changes: 130 additions & 0 deletions contrib/starrocks-python-client/test/unit/test_interfaces_hashable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Tests that reflection dataclasses are hashable and safe for use as dict keys."""
import pytest

from starrocks.engine.interfaces import (
ReflectedDistributionInfo,
ReflectedPartitionInfo,
ReflectedRefreshInfo,
ReflectedTableKeyInfo,
)


class TestReflectedRefreshInfoHashable:
def test_usable_as_dict_key(self):
info = ReflectedRefreshInfo(moment="ASYNC", type="FULL")
d = {info: "value"}
assert d[info] == "value"

def test_hash_stable_after_str(self):
info = ReflectedRefreshInfo(moment="ASYNC", type="FULL")
h_before = hash(info)
str(info)
assert hash(info) == h_before


class TestReflectedTableKeyInfoHashable:
def test_usable_as_dict_key(self):
info = ReflectedTableKeyInfo(type="PRIMARY KEY", columns="id")
d = {info: "value"}
assert d[info] == "value"

def test_usable_as_dict_key_with_list_columns(self):
info = ReflectedTableKeyInfo(type="PRIMARY KEY", columns=["id", "name"])
d = {info: "value"}
assert d[info] == "value"

def test_hash_stable_after_str(self):
info = ReflectedTableKeyInfo(type="primary key", columns=" id ")
h_before = hash(info)
str(info)
assert hash(info) == h_before

def test_hash_stable_with_list_columns(self):
info = ReflectedTableKeyInfo(type="PRIMARY KEY", columns=["id", "name"])
h_before = hash(info)
str(info)
assert hash(info) == h_before

def test_str_does_not_mutate_fields(self):
info = ReflectedTableKeyInfo(type="primary key", columns=" id ")
str(info)
assert info.type == "primary key"
assert info.columns == " id "


class TestReflectedPartitionInfoHashable:
def test_usable_as_dict_key(self):
info = ReflectedPartitionInfo(type="RANGE", partition_method="RANGE(dt)")
d = {info: "value"}
assert d[info] == "value"

def test_hash_stable_after_str(self):
info = ReflectedPartitionInfo(type="range", partition_method=" RANGE(dt) ")
h_before = hash(info)
str(info)
assert hash(info) == h_before

def test_str_does_not_mutate_fields(self):
info = ReflectedPartitionInfo(type="range", partition_method=" RANGE(dt) ")
str(info)
assert info.type == "range"
assert info.partition_method == " RANGE(dt) "


class TestReflectedDistributionInfoHashable:
def test_usable_as_dict_key(self):
info = ReflectedDistributionInfo(
type="HASH", columns="id", distribution_method="HASH(id)", buckets=4
)
d = {info: "value"}
assert d[info] == "value"

def test_usable_as_dict_key_with_list_columns(self):
info = ReflectedDistributionInfo(
type="HASH", columns=["id", "name"], distribution_method=None, buckets=4
)
d = {info: "value"}
assert d[info] == "value"

def test_hash_stable_after_str(self):
info = ReflectedDistributionInfo(
type="HASH", columns="id", distribution_method=None, buckets=4
)
h_before = hash(info)
str(info)
assert hash(info) == h_before

def test_hash_stable_with_list_columns(self):
info = ReflectedDistributionInfo(
type="HASH", columns=["id", "name"], distribution_method=None, buckets=4
)
h_before = hash(info)
str(info)
assert hash(info) == h_before

def test_str_does_not_mutate_fields(self):
info = ReflectedDistributionInfo(
type="HASH", columns="id", distribution_method=None, buckets=4
)
str(info)
assert info.distribution_method is None

def test_usable_in_set(self):
info1 = ReflectedDistributionInfo(
type="HASH", columns="id", distribution_method="HASH(id)", buckets=4
)
info2 = ReflectedDistributionInfo(
type="HASH", columns="id", distribution_method="HASH(id)", buckets=4
)
s = {info1, info2}
assert len(s) == 1

def test_usable_in_set_with_list_columns(self):
info1 = ReflectedDistributionInfo(
type="HASH", columns=["id", "name"], distribution_method=None, buckets=4
)
info2 = ReflectedDistributionInfo(
type="HASH", columns=["id", "name"], distribution_method=None, buckets=4
)
s = {info1, info2}
assert len(s) == 1
Loading