diff --git a/pyproject.toml b/pyproject.toml index b00e64f3..cf05833f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,23 +75,10 @@ module = [ ] ignore_missing_imports = true -[[tool.mypy.overrides]] -module = "fixtures.*" -ignore_missing_imports = true -follow_imports = "skip" - [[tool.mypy.overrides]] module = [ # FIXME(stephenfin): We would like to remove all modules from this list # except tests (we're not sadists) - "testtools.assertions", - "testtools.compat", - "testtools.matchers.*", - "testtools.monkey", - "testtools.run", - "testtools.runtest", - "testtools.testcase", - "testtools.testresult.*", "testtools.twistedsupport.*", "tests.*", ] diff --git a/tests/matchers/helpers.py b/tests/matchers/helpers.py index 930bf70f..927d76c9 100644 --- a/tests/matchers/helpers.py +++ b/tests/matchers/helpers.py @@ -21,6 +21,8 @@ class MatcherTestProtocol(Protocol): class TestMatchersInterface: """Mixin class that provides test methods for matcher interfaces.""" + __test__ = False # Tell pytest not to collect this as a test class + def test_matches_match(self: MatcherTestProtocol) -> None: matcher = self.matches_matcher matches = self.matches_matches diff --git a/tests/test_testresult.py b/tests/test_testresult.py index 65b74919..67c2c0fc 100644 --- a/tests/test_testresult.py +++ b/tests/test_testresult.py @@ -146,6 +146,8 @@ def make_exception_info(exceptionFactory, *args, **kwargs): class TestControlContract: """Stopping test runs.""" + __test__ = False # Tell pytest not to collect this as a test class + # These are provided by the class that uses this mixin makeResult: Any assertFalse: Any @@ -573,6 +575,8 @@ def makeResult(self): class TestStreamResultContract: + __test__ = False # Tell pytest not to collect this as a test class + # These are provided by the class that uses this mixin addCleanup: Any diff --git a/testtools/assertions.py b/testtools/assertions.py index f4d6bab8..e25d3f81 100644 --- a/testtools/assertions.py +++ b/testtools/assertions.py @@ -2,13 +2,20 @@ """Assertion helpers.""" +from typing import TypeVar + from testtools.matchers import ( Annotate, + Matcher, MismatchError, ) +T = TypeVar("T") + -def assert_that(matchee, matcher, message="", verbose=False): +def assert_that( + matchee: T, matcher: Matcher[T], message: str = "", verbose: bool = False +) -> None: """Assert that matchee is matched by matcher. This should only be used when you need to use a function based diff --git a/testtools/compat.py b/testtools/compat.py index 5c85a71d..8964d316 100644 --- a/testtools/compat.py +++ b/testtools/compat.py @@ -14,9 +14,10 @@ import sys import unicodedata from io import BytesIO, StringIO # for backwards-compat +from typing import IO -def _slow_escape(text): +def _slow_escape(text: str) -> str: """Escape unicode ``text`` leaving printable characters unmodified The behaviour emulates the Python 3 implementation of repr, see @@ -26,7 +27,7 @@ def _slow_escape(text): does not handle astral characters correctly on Python builds with 16 bit rather than 32 bit unicode type. """ - output = [] + output: list[str | bytes] = [] for c in text: o = ord(c) if o < 256: @@ -43,14 +44,14 @@ def _slow_escape(text): output.append(c.encode("unicode-escape")) else: output.append(c) - return "".join(output) + return "".join(output) # type: ignore[arg-type] -def text_repr(text, multiline=None): +def text_repr(text: str | bytes, multiline: bool | None = None) -> str: """Rich repr for ``text`` returning unicode, triple quoted if ``multiline``.""" nl = (isinstance(text, bytes) and bytes((0xA,))) or "\n" if multiline is None: - multiline = nl in text + multiline = nl in text # type: ignore[operator] if not multiline: # Use normal repr for single line of unicode return repr(text) @@ -60,7 +61,7 @@ def text_repr(text, multiline=None): # making sure that quotes are not escaped. offset = len(prefix) + 1 lines = [] - for line in text.split(nl): + for line in text.split(nl): # type: ignore[arg-type] r = repr(line) q = r[-1] lines.append(r[offset:-1].replace("\\" + q, q)) @@ -87,7 +88,7 @@ def text_repr(text, multiline=None): return "".join([prefix, quote, escaped_text, quote]) -def unicode_output_stream(stream): +def unicode_output_stream(stream: IO[str]) -> IO[str]: """Get wrapper for given stream that writes any unicode without exception Characters that can't be coerced to the encoding of the stream, or 'ascii' @@ -103,21 +104,21 @@ def unicode_output_stream(stream): # attribute). return stream try: - writer = codecs.getwriter(stream.encoding or "") + writer = codecs.getwriter(stream.encoding or "") # type: ignore[attr-defined] except (AttributeError, LookupError): - return codecs.getwriter("ascii")(stream, "replace") + return codecs.getwriter("ascii")(stream, "replace") # type: ignore[arg-type, return-value] if writer.__module__.rsplit(".", 1)[1].startswith("utf"): # The current stream has a unicode encoding so no error handler is needed return stream # Python 3 doesn't seem to make this easy, handle a common case try: - return stream.__class__( - stream.buffer, - stream.encoding, + return stream.__class__( # type: ignore[call-arg, return-value] + stream.buffer, # type: ignore[attr-defined] + stream.encoding, # type: ignore[attr-defined] "replace", - stream.newlines, - stream.line_buffering, + stream.newlines, # type: ignore[attr-defined] + stream.line_buffering, # type: ignore[attr-defined] ) except AttributeError: pass - return writer(stream, "replace") + return writer(stream, "replace") # type: ignore[arg-type, return-value] diff --git a/testtools/content.py b/testtools/content.py index 002fbfd6..fe31e433 100644 --- a/testtools/content.py +++ b/testtools/content.py @@ -197,7 +197,8 @@ class TracebackContent(Content): def __init__( self, - err: tuple[type[BaseException], BaseException, types.TracebackType | None], + err: tuple[type[BaseException], BaseException, types.TracebackType | None] + | tuple[None, None, None], test: _TestCase | None, capture_locals: bool = False, ) -> None: @@ -211,6 +212,9 @@ def __init__( raise ValueError("err may not be None") exctype, value, tb = err + # Ensure we have a real exception, not the (None, None, None) variant + assert exctype is not None, "exctype must not be None" + assert value is not None, "value must not be None" # Skip test runner traceback levels if StackLinesContent.HIDE_INTERNAL_STACK: while tb and "__unittest" in tb.tb_frame.f_globals: diff --git a/testtools/matchers/_basic.py b/testtools/matchers/_basic.py index bd6d4ea0..d3650ef3 100644 --- a/testtools/matchers/_basic.py +++ b/testtools/matchers/_basic.py @@ -20,7 +20,7 @@ import re from collections.abc import Callable from pprint import pformat -from typing import Any +from typing import Any, Generic, TypeVar from ..compat import ( text_repr, @@ -35,52 +35,64 @@ Mismatch, ) +T = TypeVar("T") +U = TypeVar("U") -def _format(thing): + +def _format(thing: object) -> str: """Blocks of text with newlines are formatted as triple-quote strings. Everything else is pretty-printed. """ if isinstance(thing, (str, bytes)): - return text_repr(thing) - return pformat(thing) + result: str = text_repr(thing) + return result + pformat_result: str = pformat(thing) + return pformat_result -class _BinaryComparison: +class _BinaryComparison(Matcher[T]): """Matcher that compares an object to another object.""" mismatch_string: str # comparator is defined by subclasses - using Any to allow different signatures comparator: Callable[..., Any] - def __init__(self, expected): + def __init__(self, expected: T) -> None: self.expected = expected - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}({self.expected!r})" - def match(self, other): + def match(self, other: T) -> Mismatch | None: if self.comparator(other, self.expected): return None return _BinaryMismatch(other, self.mismatch_string, self.expected) -class _BinaryMismatch(Mismatch): +class _BinaryMismatch(Mismatch, Generic[T]): """Two things did not match.""" - def __init__(self, actual, mismatch_string, reference, reference_on_right=True): + def __init__( + self, + actual: T, + mismatch_string: str, + reference: T, + reference_on_right: bool = True, + ) -> None: self._actual = actual self._mismatch_string = mismatch_string self._reference = reference self._reference_on_right = reference_on_right - def describe(self): + def describe(self) -> str: # Special handling for set comparisons if ( self._mismatch_string == "!=" and isinstance(self._reference, set) and isinstance(self._actual, set) ): - return self._describe_set_difference() + result: str = self._describe_set_difference() + return result actual = repr(self._actual) reference = repr(self._reference) @@ -97,8 +109,12 @@ def describe(self): left, right = reference, actual return f"{left} {self._mismatch_string} {right}" - def _describe_set_difference(self): + def _describe_set_difference(self) -> str: """Describe the difference between two sets in a readable format.""" + # Type narrowing: we know these are sets from the isinstance check in describe() + assert isinstance(self._reference, set) + assert isinstance(self._actual, set) + reference_only = sorted( self._reference - self._actual, key=lambda x: (type(x).__name__, x) ) @@ -119,14 +135,14 @@ def _describe_set_difference(self): return "\n".join(lines) -class Equals(_BinaryComparison): +class Equals(_BinaryComparison[T]): """Matches if the items are equal.""" comparator = operator.eq mismatch_string = "!=" -class _FlippedEquals: +class _FlippedEquals(Matcher[T]): """Matches if the items are equal. Exactly like ``Equals`` except that the short mismatch message is " @@ -136,17 +152,20 @@ class _FlippedEquals: the assertion. """ - def __init__(self, expected): + def __init__(self, expected: T) -> None: self._expected = expected - def match(self, other): + def __str__(self) -> str: + return f"_FlippedEquals({self._expected!r})" + + def match(self, other: T) -> Mismatch | None: mismatch = Equals(self._expected).match(other) if not mismatch: return None return _BinaryMismatch(other, "!=", self._expected, False) -class NotEquals(_BinaryComparison): +class NotEquals(_BinaryComparison[T]): """Matches if the items are not equal. In most cases, this is equivalent to ``Not(Equals(foo))``. The difference @@ -157,38 +176,38 @@ class NotEquals(_BinaryComparison): mismatch_string = "==" -class Is(_BinaryComparison): +class Is(_BinaryComparison[T]): """Matches if the items are identical.""" comparator = operator.is_ mismatch_string = "is not" -class LessThan(_BinaryComparison): +class LessThan(_BinaryComparison[T]): """Matches if the item is less than the matchers reference object.""" comparator = operator.lt mismatch_string = ">=" -class GreaterThan(_BinaryComparison): +class GreaterThan(_BinaryComparison[T]): """Matches if the item is greater than the matchers reference object.""" comparator = operator.gt mismatch_string = "<=" -class _NotNearlyEqual(Mismatch): +class _NotNearlyEqual(Mismatch, Generic[T]): """Mismatch for Nearly matcher.""" - def __init__(self, actual, expected, delta): + def __init__(self, actual: T, expected: T, delta: Any) -> None: self.actual = actual self.expected = expected self.delta = delta - def describe(self): + def describe(self) -> str: try: - diff = abs(self.actual - self.expected) + diff = abs(self.actual - self.expected) # type: ignore[operator] return ( f"{self.actual!r} is not nearly equal to {self.expected!r}: " f"difference {diff!r} exceeds tolerance {self.delta!r}" @@ -200,7 +219,7 @@ def describe(self): ) -class Nearly(Matcher): +class Nearly(Matcher[T]): """Matches if a value is nearly equal to the expected value. This matcher is useful for comparing floating point values where exact @@ -213,7 +232,7 @@ class Nearly(Matcher): operations (e.g., integers, floats, Decimal, etc.). """ - def __init__(self, expected, delta=0.001): + def __init__(self, expected: T, delta: Any = 0.001) -> None: """Create a Nearly matcher. :param expected: The expected value to compare against. @@ -223,12 +242,12 @@ def __init__(self, expected, delta=0.001): self.expected = expected self.delta = delta - def __str__(self): + def __str__(self) -> str: return f"Nearly({self.expected!r}, delta={self.delta!r})" - def match(self, actual): + def match(self, actual: T) -> Mismatch | None: try: - diff = abs(actual - self.expected) + diff = abs(actual - self.expected) # type: ignore[operator] if diff <= self.delta: return None except (TypeError, AttributeError): @@ -237,25 +256,25 @@ def match(self, actual): return _NotNearlyEqual(actual, self.expected, self.delta) -class SameMembers(Matcher): +class SameMembers(Matcher[list[T]]): """Matches if two iterators have the same members. This is not the same as set equivalence. The two iterators must be of the same length and have the same repetitions. """ - def __init__(self, expected): + def __init__(self, expected: list[T]) -> None: super().__init__() self.expected = expected - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}({self.expected!r})" - def match(self, observed): + def match(self, observed: list[T]) -> Mismatch | None: expected_only = list_subtract(self.expected, observed) observed_only = list_subtract(observed, self.expected) if expected_only == observed_only == []: - return + return None return PostfixedMismatch( ( f"\nmissing: {_format(expected_only)}\n" @@ -266,7 +285,7 @@ def match(self, observed): class DoesNotStartWith(Mismatch): - def __init__(self, matchee, expected): + def __init__(self, matchee: str | bytes, expected: str | bytes) -> None: """Create a DoesNotStartWith Mismatch. :param matchee: the string that did not match. @@ -275,33 +294,33 @@ def __init__(self, matchee, expected): self.matchee = matchee self.expected = expected - def describe(self): + def describe(self) -> str: return ( f"{text_repr(self.matchee)} does not start with {text_repr(self.expected)}." ) -class StartsWith(Matcher): +class StartsWith(Matcher[str | bytes]): """Checks whether one string starts with another.""" - def __init__(self, expected): + def __init__(self, expected: str | bytes) -> None: """Create a StartsWith Matcher. :param expected: the string that matchees should start with. """ self.expected = expected - def __str__(self): + def __str__(self) -> str: return f"StartsWith({self.expected!r})" - def match(self, matchee): - if not matchee.startswith(self.expected): + def match(self, matchee: str | bytes) -> Mismatch | None: + if not matchee.startswith(self.expected): # type: ignore[arg-type] return DoesNotStartWith(matchee, self.expected) return None class DoesNotEndWith(Mismatch): - def __init__(self, matchee, expected): + def __init__(self, matchee: str | bytes, expected: str | bytes) -> None: """Create a DoesNotEndWith Mismatch. :param matchee: the string that did not match. @@ -310,50 +329,50 @@ def __init__(self, matchee, expected): self.matchee = matchee self.expected = expected - def describe(self): + def describe(self) -> str: return ( f"{text_repr(self.matchee)} does not end with {text_repr(self.expected)}." ) -class EndsWith(Matcher): +class EndsWith(Matcher[str | bytes]): """Checks whether one string ends with another.""" - def __init__(self, expected): + def __init__(self, expected: str | bytes) -> None: """Create a EndsWith Matcher. :param expected: the string that matchees should end with. """ self.expected = expected - def __str__(self): + def __str__(self) -> str: return f"EndsWith({self.expected!r})" - def match(self, matchee): - if not matchee.endswith(self.expected): + def match(self, matchee: str | bytes) -> Mismatch | None: + if not matchee.endswith(self.expected): # type: ignore[arg-type] return DoesNotEndWith(matchee, self.expected) return None -class IsInstance: +class IsInstance(Matcher[T]): """Matcher that wraps isinstance.""" - def __init__(self, *types): + def __init__(self, *types: type[T]) -> None: self.types = tuple(types) - def __str__(self): + def __str__(self) -> str: return "{}({})".format( self.__class__.__name__, ", ".join(type.__name__ for type in self.types) ) - def match(self, other): + def match(self, other: T) -> Mismatch | None: if isinstance(other, self.types): return None return NotAnInstance(other, self.types) -class NotAnInstance(Mismatch): - def __init__(self, matchee, types): +class NotAnInstance(Mismatch, Generic[T]): + def __init__(self, matchee: T, types: tuple[type[T], ...]) -> None: """Create a NotAnInstance Mismatch. :param matchee: the thing which is not an instance of any of types. @@ -362,7 +381,7 @@ def __init__(self, matchee, types): self.matchee = matchee self.types = types - def describe(self): + def describe(self) -> str: if len(self.types) == 1: typestr = self.types[0].__name__ else: @@ -372,8 +391,8 @@ def describe(self): return f"'{self.matchee}' is not an instance of {typestr}" -class DoesNotContain(Mismatch): - def __init__(self, matchee, needle): +class DoesNotContain(Mismatch, Generic[T, U]): + def __init__(self, matchee: T, needle: U) -> None: """Create a DoesNotContain Mismatch. :param matchee: the object that did not contain needle. @@ -382,26 +401,26 @@ def __init__(self, matchee, needle): self.matchee = matchee self.needle = needle - def describe(self): + def describe(self) -> str: return f"{self.needle!r} not in {self.matchee!r}" -class Contains(Matcher): +class Contains(Matcher[T], Generic[T, U]): """Checks whether something is contained in another thing.""" - def __init__(self, needle): + def __init__(self, needle: U) -> None: """Create a Contains Matcher. :param needle: the thing that needs to be contained by matchees. """ self.needle = needle - def __str__(self): + def __str__(self) -> str: return f"Contains({self.needle!r})" - def match(self, matchee): + def match(self, matchee: T) -> Mismatch | None: try: - if self.needle not in matchee: + if self.needle not in matchee: # type: ignore[operator] return DoesNotContain(matchee, self.needle) except TypeError: # e.g. 1 in 2 will raise TypeError @@ -409,14 +428,14 @@ def match(self, matchee): return None -class MatchesRegex: +class MatchesRegex(Matcher[str]): """Matches if the matchee is matched by a regular expression.""" - def __init__(self, pattern, flags=0): + def __init__(self, pattern: str | bytes, flags: int = 0) -> None: self.pattern = pattern self.flags = flags - def __str__(self): + def __str__(self) -> str: args = [f"{self.pattern!r}"] flag_arg = [] # dir() sorts the attributes for us, so we don't need to do it again. @@ -428,8 +447,8 @@ def __str__(self): args.append("|".join(flag_arg)) return "{}({})".format(self.__class__.__name__, ", ".join(args)) - def match(self, value): - if not re.match(self.pattern, value, self.flags): + def match(self, value: str) -> Mismatch | None: + if not re.match(self.pattern, value, self.flags): # type: ignore[arg-type] pattern = self.pattern if not isinstance(pattern, str): pattern = pattern.decode("latin1") @@ -437,9 +456,10 @@ def match(self, value): return Mismatch( "{!r} does not match /{}/".format(value, pattern.replace("\\\\", "\\")) ) + return None -def has_len(x, y): +def has_len(x: Any, y: int) -> bool: return len(x) == y diff --git a/testtools/matchers/_const.py b/testtools/matchers/_const.py index df405347..7ae976eb 100644 --- a/testtools/matchers/_const.py +++ b/testtools/matchers/_const.py @@ -5,20 +5,21 @@ "Never", ] -from ._impl import Mismatch +from ._impl import Matcher, Mismatch -class _Always: + +class _Always(Matcher[object]): """Always matches.""" - def __str__(self): + def __str__(self) -> str: return "Always()" - def match(self, value): + def match(self, value: object) -> None: return None -def Always(): +def Always() -> _Always: """Always match. That is:: @@ -32,17 +33,17 @@ def Always(): return _Always() -class _Never: +class _Never(Matcher[object]): """Never matches.""" - def __str__(self): + def __str__(self) -> str: return "Never()" - def match(self, value): + def match(self, value: object) -> Mismatch: return Mismatch(f"Inevitable mismatch on {value!r}") -def Never(): +def Never() -> _Never: """Never match. That is:: diff --git a/testtools/matchers/_datastructures.py b/testtools/matchers/_datastructures.py index fc94568e..6a0737bd 100644 --- a/testtools/matchers/_datastructures.py +++ b/testtools/matchers/_datastructures.py @@ -2,13 +2,18 @@ """Matchers that operate with knowledge of Python data structures.""" +from collections.abc import Iterable, Sequence +from typing import Any, Generic, TypeVar + from ..helpers import map_values from ._higherorder import ( Annotate, MatchesAll, MismatchesAll, ) -from ._impl import Mismatch +from ._impl import Matcher, Mismatch + +T = TypeVar("T") __all__ = [ "ContainsAll", @@ -18,7 +23,7 @@ ] -def ContainsAll(items): +def ContainsAll(items: Iterable[T]) -> "MatchesAll[Iterable[T]]": """Make a matcher that checks whether a list of things is contained in another thing. @@ -30,7 +35,7 @@ def ContainsAll(items): return MatchesAll(*map(Contains, items), first_only=False) -class MatchesListwise: +class MatchesListwise(Matcher["Sequence[T]"], Generic[T]): """Matches if each matcher matches the corresponding value. More easily explained by example than in words: @@ -48,7 +53,9 @@ class MatchesListwise: 3 != 1 """ - def __init__(self, matchers, first_only=False): + def __init__( + self, matchers: "Sequence[Matcher[T]]", first_only: bool = False + ) -> None: """Construct a MatchesListwise matcher. :param matchers: A list of matcher that the matched values must match. @@ -58,7 +65,7 @@ def __init__(self, matchers, first_only=False): self.matchers = matchers self.first_only = first_only - def match(self, values): + def match(self, values: "Sequence[T]") -> Mismatch | None: from ._basic import HasLength mismatches = [] @@ -75,9 +82,10 @@ def match(self, values): mismatches.append(mismatch) if mismatches: return MismatchesAll(mismatches) + return None -class MatchesStructure: +class MatchesStructure(Matcher[T], Generic[T]): """Matcher that matches an object structurally. 'Structurally' here means that attributes of the object being matched are @@ -94,7 +102,7 @@ class MatchesStructure: the matcher, rather than just using `Equals`. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: "Matcher[Any]") -> None: """Construct a `MatchesStructure`. :param kwargs: A mapping of attributes to matchers. @@ -102,7 +110,7 @@ def __init__(self, **kwargs): self.kws = kwargs @classmethod - def byEquality(cls, **kwargs): + def byEquality(cls, **kwargs: Any) -> "MatchesStructure[Any]": """Matches an object where the attributes equal the keyword values. Similar to the constructor, except that the matcher is assumed to be @@ -113,7 +121,9 @@ def byEquality(cls, **kwargs): return cls.byMatcher(Equals, **kwargs) @classmethod - def byMatcher(cls, matcher, **kwargs): + def byMatcher( + cls, matcher: type["Matcher[Any]"], **kwargs: Any + ) -> "MatchesStructure[Any]": """Matches an object where the attributes match the keyword values. Similar to the constructor, except that the provided matcher is used @@ -122,15 +132,15 @@ def byMatcher(cls, matcher, **kwargs): return cls(**map_values(matcher, kwargs)) @classmethod - def fromExample(cls, example, *attributes): + def fromExample(cls, example: T, *attributes: str) -> "MatchesStructure[T]": from ._basic import Equals - kwargs = {} + kwargs: dict[str, Matcher[Any]] = {} for attr in attributes: kwargs[attr] = Equals(getattr(example, attr)) return cls(**kwargs) - def update(self, **kws): + def update(self, **kws: "Matcher[Any] | None") -> "MatchesStructure[T]": new_kws = self.kws.copy() for attr, matcher in kws.items(): if matcher is None: @@ -139,22 +149,22 @@ def update(self, **kws): new_kws[attr] = matcher return type(self)(**new_kws) - def __str__(self): + def __str__(self) -> str: kws = [] for attr, matcher in sorted(self.kws.items()): kws.append(f"{attr}={matcher}") return "{}({})".format(self.__class__.__name__, ", ".join(kws)) - def match(self, value): - matchers = [] - values = [] + def match(self, value: T) -> Mismatch | None: + matchers: list[Matcher[Any]] = [] + values: list[Any] = [] for attr, matcher in sorted(self.kws.items()): matchers.append(Annotate(attr, matcher)) values.append(getattr(value, attr)) return MatchesListwise(matchers).match(values) -class MatchesSetwise: +class MatchesSetwise(Matcher[Iterable[T]], Generic[T]): """Matches if all the matchers match elements of the value being matched. That is, each element in the 'observed' set must match exactly one matcher @@ -164,10 +174,15 @@ class MatchesSetwise: matchings does not matter. """ - def __init__(self, *matchers): + def __init__(self, *matchers: "Matcher[T]") -> None: self.matchers = matchers - def match(self, observed): + def __str__(self) -> str: + return "{}({})".format( + self.__class__.__name__, ", ".join(map(str, self.matchers)) + ) + + def match(self, observed: Iterable[T]) -> Mismatch | None: remaining_matchers = set(self.matchers) not_matched = [] for value in observed: @@ -229,3 +244,4 @@ def match(self, observed): return Annotate( msg, MatchesListwise(remaining_matchers_list[:common_length]) ).match(not_matched[:common_length]) + return None diff --git a/testtools/matchers/_dict.py b/testtools/matchers/_dict.py index cbd0b79e..eb89144d 100644 --- a/testtools/matchers/_dict.py +++ b/testtools/matchers/_dict.py @@ -4,7 +4,8 @@ "KeysEqual", ] -from typing import ClassVar +from collections.abc import Callable +from typing import Any, ClassVar, Generic, TypeVar from ..helpers import ( dict_subtract, @@ -18,30 +19,35 @@ ) from ._impl import Matcher, Mismatch +K = TypeVar("K") +V = TypeVar("V") -def LabelledMismatches(mismatches, details=None): + +def LabelledMismatches( + mismatches: dict[Any, Mismatch], details: Any = None +) -> MismatchesAll: """A collection of mismatches, each labelled.""" return MismatchesAll( (PrefixedMismatch(k, v) for (k, v) in sorted(mismatches.items())), wrap=False ) -class MatchesAllDict(Matcher): +class MatchesAllDict(Matcher[Any]): """Matches if all of the matchers it is created with match. A lot like ``MatchesAll``, but takes a dict of Matchers and labels any mismatches with the key of the dictionary. """ - def __init__(self, matchers): + def __init__(self, matchers: dict[Any, "Matcher[Any]"]) -> None: super().__init__() self.matchers = matchers - def __str__(self): + def __str__(self) -> str: return f"MatchesAllDict({_format_matcher_dict(self.matchers)})" - def match(self, observed): - mismatches = {} + def match(self, observed: Any) -> Mismatch | None: + mismatches: dict[Any, Mismatch | None] = {} for label in self.matchers: mismatches[label] = self.matchers[label].match(observed) return _dict_to_mismatch(mismatches, result_mismatch=LabelledMismatches) @@ -50,11 +56,11 @@ def match(self, observed): class DictMismatches(Mismatch): """A mismatch with a dict of child mismatches.""" - def __init__(self, mismatches, details=None): + def __init__(self, mismatches: dict[Any, Mismatch], details: Any = None) -> None: super().__init__(None, details=details) self.mismatches = mismatches - def describe(self): + def describe(self) -> str: lines = ["{"] lines.extend( [ @@ -66,15 +72,20 @@ def describe(self): return "\n".join(lines) -def _dict_to_mismatch(data, to_mismatch=None, result_mismatch=DictMismatches): +def _dict_to_mismatch( + data: dict[K, V | None], + to_mismatch: "Callable[[V], Mismatch] | None" = None, + result_mismatch: "Callable[[dict[K, Mismatch]], Mismatch]" = DictMismatches, +) -> Mismatch | None: if to_mismatch: - data = map_values(to_mismatch, data) - mismatches = filter_values(bool, data) + data = map_values(to_mismatch, data) # type: ignore[arg-type,assignment] + mismatches = filter_values(bool, data) # type: ignore[arg-type] if mismatches: - return result_mismatch(mismatches) + return result_mismatch(mismatches) # type: ignore[arg-type] + return None -class _MatchCommonKeys(Matcher): +class _MatchCommonKeys(Matcher[dict[Any, Any]]): """Match on keys in a dictionary. Given a dictionary where the values are matchers, this will look for @@ -88,57 +99,64 @@ class _MatchCommonKeys(Matcher): None """ - def __init__(self, dict_of_matchers): + def __init__(self, dict_of_matchers: dict[Any, "Matcher[Any]"]) -> None: super().__init__() self._matchers = dict_of_matchers - def _compare_dicts(self, expected, observed): + def _compare_dicts( + self, expected: dict[Any, "Matcher[Any]"], observed: dict[Any, Any] + ) -> dict[Any, Mismatch]: common_keys = set(expected.keys()) & set(observed.keys()) - mismatches = {} + mismatches: dict[Any, Mismatch] = {} for key in common_keys: mismatch = expected[key].match(observed[key]) if mismatch: mismatches[key] = mismatch return mismatches - def match(self, observed): + def match(self, observed: dict[Any, Any]) -> Mismatch | None: mismatches = self._compare_dicts(self._matchers, observed) if mismatches: return DictMismatches(mismatches) + return None -class _SubDictOf(Matcher): +class _SubDictOf(Matcher[dict[Any, Any]]): """Matches if the matched dict only has keys that are in given dict.""" - def __init__(self, super_dict, format_value=repr): + def __init__( + self, super_dict: dict[Any, Any], format_value: Callable[[Any], str] = repr + ) -> None: super().__init__() self.super_dict = super_dict self.format_value = format_value - def match(self, observed): + def match(self, observed: dict[Any, Any]) -> Mismatch | None: excess = dict_subtract(observed, self.super_dict) return _dict_to_mismatch(excess, lambda v: Mismatch(self.format_value(v))) -class _SuperDictOf(Matcher): +class _SuperDictOf(Matcher[dict[Any, Any]]): """Matches if all of the keys in the given dict are in the matched dict.""" - def __init__(self, sub_dict, format_value=repr): + def __init__( + self, sub_dict: dict[Any, Any], format_value: Callable[[Any], str] = repr + ) -> None: super().__init__() self.sub_dict = sub_dict self.format_value = format_value - def match(self, super_dict): + def match(self, super_dict: dict[Any, Any]) -> Mismatch | None: return _SubDictOf(super_dict, self.format_value).match(self.sub_dict) -def _format_matcher_dict(matchers: dict[str, Matcher]) -> str: +def _format_matcher_dict(matchers: dict[str, "Matcher[Any]"]) -> str: return "{{{}}}".format( ", ".join(sorted(f"{k!r}: {v}" for k, v in matchers.items())) ) -class _CombinedMatcher(Matcher): +class _CombinedMatcher(Matcher[dict[Any, Any]]): """Many matchers labelled and combined into one uber-matcher. Subclass this and then specify a dict of matcher factories that take a @@ -148,19 +166,19 @@ class _CombinedMatcher(Matcher): Not **entirely** dissimilar from ``MatchesAll``. """ - matcher_factories: ClassVar[dict] = {} + matcher_factories: ClassVar[dict[str, Any]] = {} - def __init__(self, expected): + def __init__(self, expected: dict[str, "Matcher[Any]"]) -> None: super().__init__() self._expected = expected - def format_expected(self, expected): + def format_expected(self, expected: dict[str, "Matcher[Any]"]) -> str: return repr(expected) - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}({self.format_expected(self._expected)})" - def match(self, observed): + def match(self, observed: dict[Any, Any]) -> Mismatch | None: matchers = {k: v(self._expected) for k, v in self.matcher_factories.items()} return MatchesAllDict(matchers).match(observed) @@ -174,13 +192,13 @@ class MatchesDict(_CombinedMatcher): expected dict. """ - matcher_factories: ClassVar[dict] = { + matcher_factories: ClassVar[dict[str, Any]] = { "Extra": _SubDictOf, "Missing": lambda m: _SuperDictOf(m, format_value=str), "Differences": _MatchCommonKeys, } - def format_expected(self, expected: dict[str, Matcher]) -> str: + def format_expected(self, expected: dict[str, "Matcher[Any]"]) -> str: return _format_matcher_dict(expected) @@ -199,12 +217,12 @@ class ContainsDict(_CombinedMatcher): match. """ - matcher_factories: ClassVar[dict] = { + matcher_factories: ClassVar[dict[str, Any]] = { "Missing": lambda m: _SuperDictOf(m, format_value=str), "Differences": _MatchCommonKeys, } - def format_expected(self, expected: dict[str, Matcher]) -> str: + def format_expected(self, expected: dict[str, "Matcher[Any]"]) -> str: return _format_matcher_dict(expected) @@ -223,19 +241,19 @@ class ContainedByDict(_CombinedMatcher): match. """ - matcher_factories: ClassVar[dict] = { + matcher_factories: ClassVar[dict[str, Any]] = { "Extra": _SubDictOf, "Differences": _MatchCommonKeys, } - def format_expected(self, expected: dict[str, Matcher]) -> str: + def format_expected(self, expected: dict[str, "Matcher[Any]"]) -> str: return _format_matcher_dict(expected) -class KeysEqual(Matcher): +class KeysEqual(Matcher[dict[K, Any]], Generic[K]): """Checks whether a dict has particular keys.""" - def __init__(self, *expected): + def __init__(self, *expected: K) -> None: """Create a `KeysEqual` Matcher. :param expected: The keys the matchee is expected to have. As a @@ -243,21 +261,22 @@ def __init__(self, *expected): mapping, then we use its keys as the expected set. """ super().__init__() + expected_keys: tuple[K, ...] | Any = expected if len(expected) == 1: try: - expected = expected[0].keys() + expected_keys = expected[0].keys() # type: ignore[attr-defined] except AttributeError: pass - self.expected = list(expected) + self.expected: list[K] = list(expected_keys) - def __str__(self): + def __str__(self) -> str: return "KeysEqual({})".format(", ".join(map(repr, self.expected))) - def match(self, matchee): + def match(self, matchee: dict[K, Any]) -> Mismatch | None: from ._basic import Equals, _BinaryMismatch - expected = sorted(self.expected) - matched = Equals(expected).match(sorted(matchee.keys())) + expected = sorted(self.expected) # type: ignore[type-var] + matched = Equals(expected).match(sorted(matchee.keys())) # type: ignore[type-var] if matched: return AnnotatedMismatch( "Keys not equal", _BinaryMismatch(expected, "does not match", matchee) diff --git a/testtools/matchers/_doctest.py b/testtools/matchers/_doctest.py index 6998332b..ac621d0b 100644 --- a/testtools/matchers/_doctest.py +++ b/testtools/matchers/_doctest.py @@ -6,10 +6,20 @@ import doctest import re +from typing import Any from ._impl import Mismatch +def _indent( + s: str, + indent: int = 4, + _pattern: re.Pattern[str] = re.compile("^(?!$)", re.MULTILINE), +) -> str: + """Prepend non-empty lines in ``s`` with ``indent`` number of spaces""" + return _pattern.sub(indent * " ", s) + + class _NonManglingOutputChecker(doctest.OutputChecker): """Doctest checker that works with unicode rather than mangling strings @@ -31,7 +41,7 @@ class _NonManglingOutputChecker(doctest.OutputChecker): is sufficient to revert this. """ - def _toAscii(self, s): + def _toAscii(self, s: str) -> str: """Return ``s`` unchanged rather than mangling it to ascii""" return s @@ -41,20 +51,15 @@ def _toAscii(self, s): __f = doctest.OutputChecker.output_difference.__func__ # type: ignore[attr-defined] __g = dict(__f.__globals__) - - def _indent(s, indent=4, _pattern=re.compile("^(?!$)", re.MULTILINE)): - """Prepend non-empty lines in ``s`` with ``indent`` number of spaces""" - return _pattern.sub(indent * " ", s) - __g["_indent"] = _indent - output_difference = __F(__f.func_code, __g, "output_difference") - del __F, __f, __g, _indent + output_difference: Any = __F(__f.func_code, __g, "output_difference") + del __F, __f, __g class DocTestMatches: """See if a string matches a doctest example.""" - def __init__(self, example, flags=0): + def __init__(self, example: str, flags: int = 0) -> None: """Create a DocTestMatches to match example. :param example: The example to match e.g. 'foo bar baz' @@ -67,35 +72,36 @@ def __init__(self, example, flags=0): self.flags = flags self._checker = _NonManglingOutputChecker() - def __str__(self): + def __str__(self) -> str: if self.flags: flagstr = f", flags={self.flags}" else: flagstr = "" return f"DocTestMatches({self.want!r}{flagstr})" - def _with_nl(self, actual): + def _with_nl(self, actual: str) -> str: result = self.want.__class__(actual) if not result.endswith("\n"): result += "\n" return result - def match(self, actual): + def match(self, actual: str) -> "DocTestMismatch | None": with_nl = self._with_nl(actual) if self._checker.check_output(self.want, with_nl, self.flags): return None return DocTestMismatch(self, with_nl) - def _describe_difference(self, with_nl): - return self._checker.output_difference(self, with_nl, self.flags) + def _describe_difference(self, with_nl: str) -> str: + result: str = self._checker.output_difference(self, with_nl, self.flags) + return result class DocTestMismatch(Mismatch): """Mismatch object for DocTestMatches.""" - def __init__(self, matcher, with_nl): + def __init__(self, matcher: DocTestMatches, with_nl: str) -> None: self.matcher = matcher self.with_nl = with_nl - def describe(self): + def describe(self) -> str: return self.matcher._describe_difference(self.with_nl) diff --git a/testtools/matchers/_exception.py b/testtools/matchers/_exception.py index e1e822de..5e20d6e2 100644 --- a/testtools/matchers/_exception.py +++ b/testtools/matchers/_exception.py @@ -7,29 +7,42 @@ ] import sys +import types +from collections.abc import Callable +from typing import TypeAlias, TypeVar from ._basic import MatchesRegex from ._higherorder import AfterPreprocessing -from ._impl import ( - Matcher, - Mismatch, -) +from ._impl import Matcher, Mismatch + +# Type for exc_info tuples +ExcInfo: TypeAlias = tuple[ + type[BaseException], BaseException, types.TracebackType | None +] + +T = TypeVar("T", bound=BaseException) _error_repr = BaseException.__repr__ -def _is_exception(exc): +def _is_exception(exc: object) -> bool: return isinstance(exc, BaseException) -def _is_user_exception(exc): +def _is_user_exception(exc: object) -> bool: return isinstance(exc, Exception) -class MatchesException(Matcher): +class MatchesException(Matcher[ExcInfo]): """Match an exc_info tuple against an exception instance or type.""" - def __init__(self, exception, value_re=None): + def __init__( + self, + exception: BaseException + | type[BaseException] + | tuple[type[BaseException], ...], + value_re: "str | Matcher[BaseException] | None" = None, + ) -> None: """Create a MatchesException that will match exc_info's for exception. :param exception: Either an exception instance or type. @@ -45,23 +58,43 @@ def __init__(self, exception, value_re=None): """ Matcher.__init__(self) self.expected = exception + value_matcher: Matcher[BaseException] | None if isinstance(value_re, str): - value_re = AfterPreprocessing(str, MatchesRegex(value_re), False) - self.value_re = value_re + value_matcher = AfterPreprocessing(str, MatchesRegex(value_re), False) + else: + value_matcher = value_re + self.value_re = value_matcher expected_type = type(self.expected) self._is_instance = not any( issubclass(expected_type, class_type) for class_type in (type, tuple) ) - def match(self, other): + def match(self, other: ExcInfo) -> Mismatch | None: if not isinstance(other, tuple): return Mismatch(f"{other!r} is not an exc_info tuple") - expected_class = self.expected + expected_class: type[BaseException] if self._is_instance: - expected_class = expected_class.__class__ - if not issubclass(other[0], expected_class): - return Mismatch(f"{other[0]!r} is not a {expected_class!r}") + assert isinstance(self.expected, BaseException) + expected_class = self.expected.__class__ + else: + if isinstance(self.expected, tuple): + # For tuple of types, just use the first one for error message + expected_class = self.expected[0] + else: + assert isinstance(self.expected, type) + expected_class = self.expected + + # Check if other[0] is a subclass of expected_class + exc_type = other[0] + if isinstance(self.expected, tuple): + if not any(issubclass(exc_type, cls) for cls in self.expected): + return Mismatch(f"{exc_type!r} is not a {self.expected!r}") + else: + if not issubclass(exc_type, expected_class): + return Mismatch(f"{exc_type!r} is not a {expected_class!r}") + if self._is_instance: + assert isinstance(self.expected, BaseException) if other[1].args != self.expected.args: return Mismatch( f"{_error_repr(other[1])} has different arguments to " @@ -69,21 +102,23 @@ def match(self, other): ) elif self.value_re is not None: return self.value_re.match(other[1]) + return None - def __str__(self): + def __str__(self) -> str: if self._is_instance: + assert isinstance(self.expected, BaseException) return f"MatchesException({_error_repr(self.expected)})" return f"MatchesException({self.expected!r})" -class Raises(Matcher): +class Raises(Matcher[Callable[[], object]]): """Match if the matchee raises an exception when called. Exceptions which are not subclasses of Exception propagate out of the Raises.match call unless they are explicitly matched. """ - def __init__(self, exception_matcher=None): + def __init__(self, exception_matcher: "Matcher[ExcInfo] | None" = None) -> None: """Create a Raises matcher. :param exception_matcher: Optional validator for the exception raised @@ -94,38 +129,48 @@ def __init__(self, exception_matcher=None): """ self.exception_matcher = exception_matcher - def match(self, matchee): + def match(self, matchee: "Callable[[], object]") -> Mismatch | None: try: # Handle staticmethod objects by extracting the underlying function + actual_callable: Callable[[], object] if isinstance(matchee, staticmethod): - matchee = matchee.__func__ - result = matchee() + actual_callable = matchee.__func__ # type: ignore[assignment] + else: + actual_callable = matchee + result = actual_callable() return Mismatch(f"{matchee!r} returned {result!r}") # Catch all exceptions: Raises() should be able to match a # KeyboardInterrupt or SystemExit. except BaseException: exc_info = sys.exc_info() + # Type narrow to actual ExcInfo + assert exc_info[0] is not None + assert exc_info[1] is not None + typed_exc_info: ExcInfo = (exc_info[0], exc_info[1], exc_info[2]) # type: ignore[assignment] + if self.exception_matcher: - mismatch = self.exception_matcher.match(exc_info) + mismatch = self.exception_matcher.match(typed_exc_info) if not mismatch: del exc_info - return + return None else: mismatch = None # The exception did not match, or no explicit matching logic was # performed. If the exception is a non-user exception then # propagate it. - exception = exc_info[1] + exception = typed_exc_info[1] if _is_exception(exception) and not _is_user_exception(exception): del exc_info raise return mismatch - def __str__(self): + def __str__(self) -> str: return "Raises()" -def raises(exception): +def raises( + exception: BaseException | type[BaseException] | tuple[type[BaseException], ...], +) -> Raises: """Make a matcher that checks that a callable raises an exception. This is a convenience function, exactly equivalent to:: diff --git a/testtools/matchers/_filesystem.py b/testtools/matchers/_filesystem.py index 4c101dfc..efc349e7 100644 --- a/testtools/matchers/_filesystem.py +++ b/testtools/matchers/_filesystem.py @@ -14,6 +14,7 @@ import os import tarfile +from typing import Any from ._basic import Equals from ._higherorder import ( @@ -22,10 +23,11 @@ ) from ._impl import ( Matcher, + Mismatch, ) -def PathExists(): +def PathExists() -> "MatchesPredicate[str]": """Matches if the given path exists. Use like this:: @@ -35,7 +37,7 @@ def PathExists(): return MatchesPredicate(os.path.exists, "%s does not exist.") -def DirExists(): +def DirExists() -> "MatchesAll[str]": """Matches if the path exists and is a directory.""" return MatchesAll( PathExists(), @@ -44,7 +46,7 @@ def DirExists(): ) -def FileExists(): +def FileExists() -> "MatchesAll[str]": """Matches if the given path exists and is a file.""" return MatchesAll( PathExists(), @@ -53,13 +55,17 @@ def FileExists(): ) -class DirContains(Matcher): +class DirContains(Matcher[str]): """Matches if the given directory contains files with the given names. That is, is the directory listing exactly equal to the given files? """ - def __init__(self, filenames=None, matcher=None): + def __init__( + self, + filenames: "list[str] | None" = None, + matcher: "Matcher[list[str]] | None" = None, + ) -> None: """Construct a ``DirContains`` matcher. Can be used in a basic mode where the whole directory listing is @@ -73,28 +79,34 @@ def __init__(self, filenames=None, matcher=None): :param matcher: If specified, match the sorted directory listing against this matcher. """ - if filenames == matcher is None: + if filenames is None and matcher is None: raise AssertionError("Must provide one of `filenames` or `matcher`.") - if None not in (filenames, matcher): + if filenames is not None and matcher is not None: raise AssertionError( "Must provide either `filenames` or `matcher`, not both." ) if filenames is None: - self.matcher = matcher + assert matcher is not None + self.matcher: Matcher[list[str]] = matcher else: self.matcher = Equals(sorted(filenames)) - def match(self, path): + def match(self, path: str) -> Mismatch | None: mismatch = DirExists().match(path) if mismatch is not None: return mismatch return self.matcher.match(sorted(os.listdir(path))) -class FileContains(Matcher): +class FileContains(Matcher[str]): """Matches if the given file has the specified contents.""" - def __init__(self, contents=None, matcher=None, encoding=None): + def __init__( + self, + contents: "bytes | str | None" = None, + matcher: "Matcher[Any] | None" = None, + encoding: str | None = None, + ) -> None: """Construct a ``FileContains`` matcher. Can be used in a basic mode where the file contents are compared for @@ -112,21 +124,21 @@ def __init__(self, contents=None, matcher=None, encoding=None): in text mode. Only used when contents is a str (or when using a matcher for text content). Defaults to the system default encoding. """ - if contents == matcher is None: + if contents is None and matcher is None: raise AssertionError("Must provide one of `contents` or `matcher`.") - if None not in (contents, matcher): + if contents is not None and matcher is not None: raise AssertionError( "Must provide either `contents` or `matcher`, not both." ) if matcher is None: - self.matcher = Equals(contents) - self._binary_mode = isinstance(contents, bytes) + self.matcher: Matcher[Any] = Equals(contents) + self._binary_mode: bool = isinstance(contents, bytes) else: self.matcher = matcher self._binary_mode = False self.encoding = encoding - def match(self, path): + def match(self, path: str) -> Mismatch | None: mismatch = PathExists().match(path) if mismatch is not None: return mismatch @@ -138,17 +150,17 @@ def match(self, path): actual_contents = f.read() return self.matcher.match(actual_contents) - def __str__(self): + def __str__(self) -> str: return f"File at path exists and contains {self.matcher}" -class HasPermissions(Matcher): +class HasPermissions(Matcher[str]): """Matches if a file has the given permissions. Permissions are specified and matched as a four-digit octal string. """ - def __init__(self, octal_permissions): + def __init__(self, octal_permissions: str) -> None: """Construct a HasPermissions matcher. :param octal_permissions: A four digit octal string, representing the @@ -157,41 +169,41 @@ def __init__(self, octal_permissions): super().__init__() self.octal_permissions = octal_permissions - def match(self, filename): + def match(self, filename: str) -> Mismatch | None: permissions = oct(os.stat(filename).st_mode)[-4:] return Equals(self.octal_permissions).match(permissions) -class SamePath(Matcher): +class SamePath(Matcher[str]): """Matches if two paths are the same. That is, the paths are equal, or they point to the same file but in different ways. The paths do not have to exist. """ - def __init__(self, path): + def __init__(self, path: str) -> None: super().__init__() self.path = path - def match(self, other_path): - def f(x): + def match(self, other_path: str) -> Mismatch | None: + def f(x: str) -> str: return os.path.abspath(os.path.realpath(x)) return Equals(f(self.path)).match(f(other_path)) -class TarballContains(Matcher): +class TarballContains(Matcher[str]): """Matches if the given tarball contains the given paths. Uses TarFile.getnames() to get the paths out of the tarball. """ - def __init__(self, paths): + def __init__(self, paths: list[str]) -> None: super().__init__() self.paths = paths - self.path_matcher = Equals(sorted(self.paths)) + self.path_matcher: Matcher[list[str]] = Equals(sorted(self.paths)) - def match(self, tarball_path): + def match(self, tarball_path: str) -> Mismatch | None: # Open underlying file first to ensure it's always closed: # f = open(tarball_path, "rb") diff --git a/testtools/matchers/_higherorder.py b/testtools/matchers/_higherorder.py index 55cf191c..113b40d9 100644 --- a/testtools/matchers/_higherorder.py +++ b/testtools/matchers/_higherorder.py @@ -10,7 +10,15 @@ "Not", ] +import sys import types +from collections.abc import Callable, Iterable +from typing import Any, Generic, TypedDict, TypeVar + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack from ._impl import ( Matcher, @@ -18,14 +26,20 @@ MismatchDecorator, ) +T = TypeVar("T") + + +class MatchesAllOptions(TypedDict, total=False): + first_only: bool -class MatchesAny: + +class MatchesAny(Matcher[T], Generic[T]): """Matches if any of the matchers it is created with match.""" - def __init__(self, *matchers): + def __init__(self, *matchers: Matcher[T]) -> None: self.matchers = matchers - def match(self, matchee): + def match(self, matchee: T) -> Mismatch | None: results = [] for matcher in self.matchers: mismatch = matcher.match(matchee) @@ -34,16 +48,18 @@ def match(self, matchee): results.append(mismatch) return MismatchesAll(results) - def __str__(self): + def __str__(self) -> str: return "MatchesAny({})".format( ", ".join([str(matcher) for matcher in self.matchers]) ) -class MatchesAll: +class MatchesAll(Matcher[T], Generic[T]): """Matches if all of the matchers it is created with match.""" - def __init__(self, *matchers, **options): + def __init__( + self, *matchers: Matcher[T], **options: "Unpack[MatchesAllOptions]" + ) -> None: """Construct a MatchesAll matcher. Just list the component matchers as arguments in the ``*args`` @@ -54,10 +70,10 @@ def __init__(self, *matchers, **options): self.matchers = matchers self.first_only = options.get("first_only", False) - def __str__(self): + def __str__(self) -> str: return "MatchesAll({})".format(", ".join(map(str, self.matchers))) - def match(self, matchee): + def match(self, matchee: T) -> Mismatch | None: results = [] for matcher in self.matchers: mismatch = matcher.match(matchee) @@ -74,11 +90,11 @@ def match(self, matchee): class MismatchesAll(Mismatch): """A mismatch with many child mismatches.""" - def __init__(self, mismatches, wrap=True): - self.mismatches = mismatches + def __init__(self, mismatches: "Iterable[Mismatch]", wrap: bool = True) -> None: + self.mismatches = list(mismatches) self._wrap = wrap - def describe(self): + def describe(self) -> str: descriptions = [] if self._wrap: descriptions = ["Differences: ["] @@ -89,16 +105,16 @@ def describe(self): return "\n".join(descriptions) -class Not: +class Not(Matcher[T], Generic[T]): """Inverts a matcher.""" - def __init__(self, matcher): + def __init__(self, matcher: Matcher[T]) -> None: self.matcher = matcher - def __str__(self): + def __str__(self) -> str: return f"Not({self.matcher})" - def match(self, other): + def match(self, other: T) -> Mismatch | None: mismatch = self.matcher.match(other) if mismatch is None: return MatchedUnexpectedly(self.matcher, other) @@ -106,52 +122,53 @@ def match(self, other): return None -class MatchedUnexpectedly(Mismatch): +class MatchedUnexpectedly(Mismatch, Generic[T]): """A thing matched when it wasn't supposed to.""" - def __init__(self, matcher, other): + def __init__(self, matcher: Matcher[T], other: T) -> None: self.matcher = matcher self.other = other - def describe(self): + def describe(self) -> str: return f"{self.other!r} matches {self.matcher}" -class Annotate: +class Annotate(Matcher[T], Generic[T]): """Annotates a matcher with a descriptive string. Mismatches are then described as ': '. """ - def __init__(self, annotation, matcher): + def __init__(self, annotation: str, matcher: Matcher[T]) -> None: self.annotation = annotation self.matcher = matcher @classmethod - def if_message(cls, annotation, matcher): + def if_message(cls, annotation: str, matcher: Matcher[T]) -> Matcher[T]: """Annotate ``matcher`` only if ``annotation`` is non-empty.""" if not annotation: return matcher return cls(annotation, matcher) - def __str__(self): + def __str__(self) -> str: return f"Annotate({self.annotation!r}, {self.matcher})" - def match(self, other): + def match(self, other: T) -> Mismatch | None: mismatch = self.matcher.match(other) if mismatch is not None: return AnnotatedMismatch(self.annotation, mismatch) + return None class PostfixedMismatch(MismatchDecorator): """A mismatch annotated with a descriptive string.""" - def __init__(self, annotation, mismatch): + def __init__(self, annotation: str, mismatch: Mismatch) -> None: super().__init__(mismatch) self.annotation = annotation self.mismatch = mismatch - def describe(self): + def describe(self) -> str: return f"{self.original.describe()}: {self.annotation}" @@ -159,15 +176,18 @@ def describe(self): class PrefixedMismatch(MismatchDecorator): - def __init__(self, prefix, mismatch): + def __init__(self, prefix: str, mismatch: Mismatch) -> None: super().__init__(mismatch) self.prefix = prefix - def describe(self): + def describe(self) -> str: return f"{self.prefix}: {self.original.describe()}" -class AfterPreprocessing: +U = TypeVar("U") + + +class AfterPreprocessing(Matcher[T], Generic[T, U]): """Matches if the value matches after passing through a function. This can be used to aid in creating trivial matchers as functions, for @@ -179,7 +199,12 @@ def _read(path): return AfterPreprocessing(_read, Equals(content)) """ - def __init__(self, preprocessor, matcher, annotate=True): + def __init__( + self, + preprocessor: "Callable[[T], U]", + matcher: Matcher[U], + annotate: bool = True, + ) -> None: """Create an AfterPreprocessing matcher. :param preprocessor: A function called with the matchee before @@ -193,18 +218,18 @@ def __init__(self, preprocessor, matcher, annotate=True): self.matcher = matcher self.annotate = annotate - def _str_preprocessor(self): + def _str_preprocessor(self) -> str: if isinstance(self.preprocessor, types.FunctionType): return f"" return str(self.preprocessor) - def __str__(self): + def __str__(self) -> str: return f"AfterPreprocessing({self._str_preprocessor()}, {self.matcher})" - def match(self, value): + def match(self, value: T) -> Mismatch | None: after = self.preprocessor(value) if self.annotate: - matcher = Annotate( + matcher: Matcher[U] = Annotate( f"after {self._str_preprocessor()} on {value!r}", self.matcher ) else: @@ -212,16 +237,19 @@ def match(self, value): return matcher.match(after) -class AllMatch: +V = TypeVar("V") + + +class AllMatch(Matcher["Iterable[V]"], Generic[V]): """Matches if all provided values match the given matcher.""" - def __init__(self, matcher): + def __init__(self, matcher: Matcher[V]) -> None: self.matcher = matcher - def __str__(self): + def __str__(self) -> str: return f"AllMatch({self.matcher})" - def match(self, values): + def match(self, values: "Iterable[V]") -> Mismatch | None: mismatches = [] for value in values: mismatch = self.matcher.match(value) @@ -229,18 +257,19 @@ def match(self, values): mismatches.append(mismatch) if mismatches: return MismatchesAll(mismatches) + return None -class AnyMatch: +class AnyMatch(Matcher["Iterable[V]"], Generic[V]): """Matches if any of the provided values match the given matcher.""" - def __init__(self, matcher): + def __init__(self, matcher: Matcher[V]) -> None: self.matcher = matcher - def __str__(self): + def __str__(self) -> str: return f"AnyMatch({self.matcher})" - def match(self, values): + def match(self, values: "Iterable[V]") -> Mismatch | None: mismatches = [] for value in values: mismatch = self.matcher.match(value) @@ -251,7 +280,7 @@ def match(self, values): return MismatchesAll(mismatches) -class MatchesPredicate(Matcher): +class MatchesPredicate(Matcher[T], Generic[T]): """Match if a given function returns True. It is reasonably common to want to make a very simple matcher based on a @@ -263,7 +292,7 @@ class MatchesPredicate(Matcher): self.assertThat(4, IsEven) """ - def __init__(self, predicate, message): + def __init__(self, predicate: "Callable[[T], bool]", message: str) -> None: """Create a ``MatchesPredicate`` matcher. :param predicate: A function that takes a single argument and returns @@ -275,15 +304,18 @@ def __init__(self, predicate, message): self.predicate = predicate self.message = message - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}({self.predicate!r}, {self.message!r})" - def match(self, x): + def match(self, x: T) -> Mismatch | None: if not self.predicate(x): return Mismatch(self.message % x) + return None -def MatchesPredicateWithParams(predicate, message, name=None): +def MatchesPredicateWithParams( + predicate: "Callable[..., bool]", message: str, name: str | None = None +) -> "Callable[..., _MatchesPredicateWithParams[object]]": """Match if a given parameterised function returns True. It is reasonably common to want to make a very simple matcher based on a @@ -309,14 +341,26 @@ def MatchesPredicateWithParams(predicate, message, name=None): :param name: Optional replacement name for the matcher. """ - def construct_matcher(*args, **kwargs): + def construct_matcher( + *args: Any, **kwargs: Any + ) -> _MatchesPredicateWithParams[object]: return _MatchesPredicateWithParams(predicate, message, name, *args, **kwargs) return construct_matcher -class _MatchesPredicateWithParams(Matcher): - def __init__(self, predicate, message, name, *args, **kwargs): +class _MatchesPredicateWithParams(Matcher[T], Generic[T]): + args: "tuple[Any, ...]" + kwargs: "dict[str, Any]" + + def __init__( + self, + predicate: "Callable[..., bool]", + message: str, + name: str | None, + *args: Any, + **kwargs: Any, + ) -> None: """Create a ``MatchesPredicateWithParams`` matcher. :param predicate: A function that takes an object to match and @@ -344,16 +388,17 @@ def __init__(self, predicate, message, name, *args, **kwargs): self.args = args self.kwargs = kwargs - def __str__(self): + def __str__(self) -> str: args_list = [str(arg) for arg in self.args] - kwargs = ["{}={}".format(*item) for item in self.kwargs.items()] - args = ", ".join(args_list + kwargs) + kwargs_list = ["{}={}".format(*item) for item in self.kwargs.items()] + args_str = ", ".join(args_list + kwargs_list) if self.name is None: name = f"MatchesPredicateWithParams({self.predicate!r}, {self.message!r})" else: name = self.name - return f"{name}({args})" + return f"{name}({args_str})" - def match(self, x): + def match(self, x: T) -> Mismatch | None: if not self.predicate(x, *self.args, **self.kwargs): return Mismatch(self.message.format(*((x, *self.args)), **self.kwargs)) + return None diff --git a/testtools/matchers/_impl.py b/testtools/matchers/_impl.py index 6253eaab..e28bee22 100644 --- a/testtools/matchers/_impl.py +++ b/testtools/matchers/_impl.py @@ -18,9 +18,15 @@ ] import unicodedata +from typing import TYPE_CHECKING, Generic, TypeVar +if TYPE_CHECKING: + from testtools.testresult import DetailsDict -def _slow_escape(text): +T = TypeVar("T") + + +def _slow_escape(text: str) -> str: """Escape unicode ``text`` leaving printable characters unmodified The behaviour emulates the Python 3 implementation of repr, see @@ -30,7 +36,7 @@ def _slow_escape(text): does not handle astral characters correctly on Python builds with 16 bit rather than 32 bit unicode type. """ - output = [] + output: list[str | bytes] = [] for c in text: o = ord(c) if o < 256: @@ -47,14 +53,14 @@ def _slow_escape(text): output.append(c.encode("unicode-escape")) else: output.append(c) - return "".join(output) + return "".join(output) # type: ignore[arg-type] -def text_repr(text, multiline=None): +def text_repr(text: str | bytes, multiline: bool | None = None) -> str: """Rich repr for ``text`` returning unicode, triple quoted if ``multiline``.""" nl = (isinstance(text, bytes) and bytes((0xA,))) or "\n" if multiline is None: - multiline = nl in text + multiline = nl in text # type: ignore[operator] if not multiline: # Use normal repr for single line of unicode return repr(text) @@ -64,7 +70,7 @@ def text_repr(text, multiline=None): # making sure that quotes are not escaped. offset = len(prefix) + 1 lines = [] - for line in text.split(nl): + for line in text.split(nl): # type: ignore[arg-type] r = repr(line) q = r[-1] lines.append(r[offset:-1].replace("\\" + q, q)) @@ -91,7 +97,7 @@ def text_repr(text, multiline=None): return "".join([prefix, quote, escaped_text, quote]) -class Matcher: +class Matcher(Generic[T]): """A pattern matcher. A Matcher must implement match and __str__ to be used by @@ -105,11 +111,11 @@ class Matcher: a Java transcription. """ - def match(self, something): + def match(self, something: T) -> "Mismatch | None": """Return None if this matcher matches something, a Mismatch otherwise.""" raise NotImplementedError(self.match) - def __str__(self): + def __str__(self) -> str: """Get a sensible human representation of the matcher. This should include the parameters given to the matcher and any @@ -121,7 +127,9 @@ def __str__(self): class Mismatch: """An object describing a mismatch detected by a Matcher.""" - def __init__(self, description=None, details=None): + def __init__( + self, description: str | None = None, details: "DetailsDict | None" = None + ) -> None: """Construct a `Mismatch`. :param description: A description to use. If not provided, @@ -135,7 +143,7 @@ def __init__(self, description=None, details=None): details = {} self._details = details - def describe(self): + def describe(self) -> str: """Describe the mismatch. This should be either a human-readable string or castable to a string. @@ -147,7 +155,7 @@ def describe(self): except AttributeError: raise NotImplementedError(self.describe) - def get_details(self): + def get_details(self) -> "DetailsDict": """Get extra details about the mismatch. This allows the mismatch to provide extra information beyond the basic @@ -165,14 +173,14 @@ def get_details(self): """ return getattr(self, "_details", {}) - def __repr__(self): + def __repr__(self) -> str: return ( f"" ) -class MismatchError(AssertionError): +class MismatchError(AssertionError, Generic[T]): """Raised when a mismatch occurs.""" # This class exists to work around @@ -180,14 +188,16 @@ class MismatchError(AssertionError): # guaranteed way of getting a readable exception, no matter what crazy # characters are in the matchee, matcher or mismatch. - def __init__(self, matchee, matcher, mismatch, verbose=False): + def __init__( + self, matchee: T, matcher: Matcher[T], mismatch: Mismatch, verbose: bool = False + ) -> None: super().__init__() self.matchee = matchee self.matcher = matcher self.mismatch = mismatch self.verbose = verbose - def __str__(self): + def __str__(self) -> str: difference = self.mismatch.describe() if self.verbose: # GZ 2011-08-24: Smelly API? Better to take any object and special @@ -204,7 +214,7 @@ def __str__(self): return difference -class MismatchDecorator: +class MismatchDecorator(Mismatch): """Decorate a ``Mismatch``. Forwards all messages to the original mismatch object. Probably the best @@ -212,20 +222,20 @@ class MismatchDecorator: custom decoration logic. """ - def __init__(self, original): + def __init__(self, original: Mismatch) -> None: """Construct a `MismatchDecorator`. :param original: A `Mismatch` object to decorate. """ self.original = original - def __repr__(self): + def __repr__(self) -> str: return f"" - def describe(self): + def describe(self) -> str: return self.original.describe() - def get_details(self): + def get_details(self) -> "DetailsDict": return self.original.get_details() diff --git a/testtools/matchers/_warnings.py b/testtools/matchers/_warnings.py index 02526417..5b72dc95 100644 --- a/testtools/matchers/_warnings.py +++ b/testtools/matchers/_warnings.py @@ -3,6 +3,8 @@ __all__ = ["IsDeprecated", "WarningMessage", "Warnings"] import warnings +from collections.abc import Callable +from typing import Any from ._basic import Is from ._const import Always @@ -11,10 +13,16 @@ AfterPreprocessing, Annotate, ) -from ._impl import Mismatch +from ._impl import Matcher, Mismatch -def WarningMessage(category_type, message=None, filename=None, lineno=None, line=None): +def WarningMessage( + category_type: type[Warning], + message: "Matcher[Any] | None" = None, + filename: "Matcher[Any] | None" = None, + lineno: "Matcher[Any] | None" = None, + line: "Matcher[Any] | None" = None, +) -> "MatchesStructure[warnings.WarningMessage]": r"""Create a matcher that will match `warnings.WarningMessage`\s. For example, to match captured `DeprecationWarning`s with a message about @@ -38,10 +46,10 @@ def WarningMessage(category_type, message=None, filename=None, lineno=None, line warning's line of source code. """ category_matcher = Is(category_type) - message_matcher = message or Always() - filename_matcher = filename or Always() - lineno_matcher = lineno or Always() - line_matcher = line or Always() + message_matcher: Matcher[Any] = message or Always() + filename_matcher: Matcher[Any] = filename or Always() + lineno_matcher: Matcher[Any] = lineno or Always() + line_matcher: Matcher[Any] = line or Always() return MatchesStructure( category=Annotate("Warning's category type does not match", category_matcher), message=Annotate( @@ -56,7 +64,7 @@ def WarningMessage(category_type, message=None, filename=None, lineno=None, line class Warnings: """Match if the matchee produces warnings.""" - def __init__(self, warnings_matcher=None): + def __init__(self, warnings_matcher: "Matcher[Any] | None" = None) -> None: """Create a Warnings matcher. :param warnings_matcher: Optional validator for the warnings emitted by @@ -65,7 +73,7 @@ def __init__(self, warnings_matcher=None): """ self.warnings_matcher = warnings_matcher - def match(self, matchee): + def match(self, matchee: Callable[[], Any]) -> Mismatch | None: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") # Handle staticmethod objects by extracting the underlying function @@ -76,12 +84,13 @@ def match(self, matchee): return self.warnings_matcher.match(w) elif not w: return Mismatch("Expected at least one warning, got none") + return None - def __str__(self): + def __str__(self) -> str: return f"Warnings({self.warnings_matcher!s})" -def IsDeprecated(message): +def IsDeprecated(message: "Matcher[Any]") -> Warnings: """Make a matcher that checks that a callable produces exactly one `DeprecationWarning`. diff --git a/testtools/monkey.py b/testtools/monkey.py index 56768e5e..51669e58 100644 --- a/testtools/monkey.py +++ b/testtools/monkey.py @@ -2,6 +2,9 @@ """Helpers for monkey-patching Python code.""" +from collections.abc import Callable +from typing import Any + __all__ = [ "MonkeyPatcher", "patch", @@ -19,7 +22,7 @@ class MonkeyPatcher: # object before we patched it. _NO_SUCH_ATTRIBUTE = object() - def __init__(self, *patches): + def __init__(self, *patches: tuple[object, str, object]) -> None: """Construct a `MonkeyPatcher`. :param patches: The patches to apply, each should be (obj, name, @@ -27,14 +30,14 @@ def __init__(self, *patches): `add_patch`. """ # List of patches to apply in (obj, name, value). - self._patches_to_apply = [] + self._patches_to_apply: list[tuple[object, str, object]] = [] # List of the original values for things that have been patched. # (obj, name, value) format. - self._originals = [] + self._originals: list[tuple[object, str, object]] = [] for patch in patches: self.add_patch(*patch) - def add_patch(self, obj, name, value): + def add_patch(self, obj: object, name: str, value: object) -> None: """Add a patch to overwrite 'name' on 'obj' with 'value'. The attribute C{name} on C{obj} will be assigned to C{value} when @@ -44,7 +47,7 @@ def add_patch(self, obj, name, value): """ self._patches_to_apply.append((obj, name, value)) - def patch(self): + def patch(self) -> None: """Apply all of the patches that have been specified with `add_patch`. Reverse this operation using L{restore}. @@ -54,7 +57,7 @@ def patch(self): self._originals.append((obj, name, original_value)) setattr(obj, name, value) - def restore(self): + def restore(self) -> None: """Restore all original values to any patched objects. If the patched attribute did not exist on an object before it was @@ -68,7 +71,7 @@ def restore(self): else: setattr(obj, name, value) - def run_with_patches(self, f, *args, **kw): + def run_with_patches(self, f: Callable[..., Any], *args: Any, **kw: Any) -> Any: """Run 'f' with the given args and kwargs with all patches applied. Restores all objects to their original state when finished. @@ -80,7 +83,7 @@ def run_with_patches(self, f, *args, **kw): self.restore() -def patch(obj, attribute, value): +def patch(obj: object, attribute: str, value: object) -> Callable[[], None]: """Set 'obj.attribute' to 'value' and return a callable to restore 'obj'. If 'attribute' is not set on 'obj' already, then the returned callable diff --git a/testtools/run.py b/testtools/run.py index 97ae6fd8..2cef8ca3 100755 --- a/testtools/run.py +++ b/testtools/run.py @@ -11,8 +11,11 @@ import os.path import sys import unittest +from argparse import ArgumentParser +from collections.abc import Callable from functools import partial -from typing import Any +from types import ModuleType +from typing import IO, Any, TextIO from testtools import TextTestResult from testtools.compat import unicode_output_stream @@ -33,7 +36,9 @@ USAGE_AS_MAIN = "" -def list_test(test): +def list_test( + test: unittest.TestSuite | unittest.TestCase, +) -> tuple[list[str], list[str]]: """Return the test ids that would be run if test() was run. When things fail to import they can be represented as well, though @@ -50,31 +55,36 @@ def list_test(test): "unittest.loader.ModuleImportFailure.", "discover.ModuleImportFailure.", } - test_ids = [] - errors = [] - for test in iterate_tests(test): + test_ids: list[str] = [] + errors: list[str] = [] + for single_test in iterate_tests(test): # Much ugly. for prefix in unittest_import_strs: - if test.id().startswith(prefix): - errors.append(test.id()[len(prefix) :]) + if single_test.id().startswith(prefix): + errors.append(single_test.id()[len(prefix) :]) break else: - test_ids.append(test.id()) + test_ids.append(single_test.id()) return test_ids, errors class TestToolsTestRunner: """A thunk object to support unittest.TestProgram.""" + verbosity: int + failfast: bool | None + stdout: IO[str] + tb_locals: bool + def __init__( self, - verbosity=None, - failfast=None, - buffer=None, - stdout=None, - tb_locals=False, - **kwargs, - ): + verbosity: int | None = None, + failfast: bool | None = None, + buffer: bool | None = None, + stdout: IO[str] | None = None, + tb_locals: bool = False, + **kwargs: Any, + ) -> None: """Create a TestToolsTestRunner. :param verbosity: Verbosity level. 0 for quiet, 1 for normal (dots, default), @@ -91,22 +101,30 @@ def __init__( self.stdout = stdout self.tb_locals = tb_locals - def list(self, test, loader): + def list( + self, + test: unittest.TestSuite | unittest.TestCase, + loader: unittest.TestLoader | None = None, + ) -> None: """List the tests that would be run if test() was run.""" test_ids, _ = list_test(test) for test_id in test_ids: self.stdout.write(f"{test_id}\n") - errors = loader.errors - if errors: - for test_id in errors: - self.stdout.write(f"{test_id}\n") - sys.exit(2) - - def run(self, test): + if loader is not None: + errors = loader.errors + if errors: + for error in errors: + self.stdout.write(f"{error}\n") + sys.exit(2) + + def run( + self, test: unittest.TestSuite | unittest.TestCase + ) -> unittest.TestResult | None: """Run the given test case or test suite.""" + stream: TextIO = unicode_output_stream(self.stdout) # type: ignore[assignment] result = TextTestResult( - unicode_output_stream(self.stdout), - failfast=self.failfast, + stream, + failfast=self.failfast or False, tb_locals=self.tb_locals, verbosity=self.verbosity, ) @@ -134,27 +152,41 @@ class TestProgram(unittest.TestProgram): """ # defaults for testing - module = None - verbosity = 1 - failfast = catchbreak = buffer = progName = None - _discovery_parser = None + module: ModuleType | None = None + verbosity: int = 1 + failfast: bool | None = None + catchbreak: bool | None = None + buffer: bool | None = None + progName: str | None = None + _discovery_parser: ArgumentParser | None = None test: Any # Set by parent class + stdout: IO[str] + exit: bool + tb_locals: bool + defaultTest: str | None + listtests: bool + load_list: str | None + testRunner: Callable[..., TestToolsTestRunner] | TestToolsTestRunner | None + testLoader: unittest.TestLoader + result: unittest.TestResult def __init__( self, - module=__name__, - defaultTest=None, - argv=None, - testRunner=None, - testLoader=defaultTestLoader, - exit=True, - verbosity=1, - failfast=None, - catchbreak=None, - buffer=None, - stdout=None, - tb_locals=False, - ): + module: str | ModuleType | None = __name__, + defaultTest: str | None = None, + argv: list[str] | None = None, + testRunner: Callable[..., TestToolsTestRunner] + | TestToolsTestRunner + | None = None, + testLoader: unittest.TestLoader = defaultTestLoader, + exit: bool = True, + verbosity: int = 1, + failfast: bool | None = None, + catchbreak: bool | None = None, + buffer: bool | None = None, + stdout: IO[str] | None = None, + tb_locals: bool = False, + ) -> None: if module == __name__: self.module = None elif isinstance(module, str): @@ -208,17 +240,14 @@ def __init__( else: runner = self._get_runner() if hasattr(runner, "list"): - try: - runner.list(self.test, loader=self.testLoader) - except TypeError: - runner.list(self.test) + runner.list(self.test, loader=self.testLoader) else: for test in iterate_tests(self.test): self.stdout.write(f"{test.id()}\n") del self.testLoader.errors[:] - def _getParentArgParser(self): - parser = super()._getParentArgParser() # type: ignore[misc] + def _getParentArgParser(self) -> ArgumentParser: + parser: ArgumentParser = super()._getParentArgParser() # type: ignore[misc] # XXX: Local edit (see http://bugs.python.org/issue22860) parser.add_argument( "-l", @@ -237,27 +266,39 @@ def _getParentArgParser(self): ) return parser - def _do_discovery(self, argv, Loader=None): + def _do_discovery( + self, argv: list[str], Loader: type[unittest.TestLoader] | None = None + ) -> None: super()._do_discovery(argv, Loader=Loader) # type: ignore[misc] # XXX: Local edit (see http://bugs.python.org/issue22860) self.test = sorted_tests(self.test) - def runTests(self): + def runTests(self) -> None: # XXX: Local edit (see http://bugs.python.org/issue22860) if self.catchbreak and getattr(unittest, "installHandler", None) is not None: unittest.installHandler() testRunner = self._get_runner() - self.result = testRunner.run(self.test) + result = testRunner.run(self.test) + if result is not None: + self.result = result if self.exit: sys.exit(not self.result.wasSuccessful()) - def _get_runner(self): + def _get_runner(self) -> TestToolsTestRunner: # XXX: Local edit (see http://bugs.python.org/issue22860) - if self.testRunner is None: - self.testRunner = TestToolsTestRunner + runner_or_factory = self.testRunner + if runner_or_factory is None: + runner_or_factory = TestToolsTestRunner + + # If it's already an instance, return it directly + if isinstance(runner_or_factory, TestToolsTestRunner): + return runner_or_factory + + # It's a callable (class or factory function) + runner_factory: Callable[..., TestToolsTestRunner] = runner_or_factory try: try: - testRunner = self.testRunner( + testRunner = runner_factory( verbosity=self.verbosity, failfast=self.failfast, buffer=self.buffer, @@ -266,7 +307,7 @@ def _get_runner(self): ) except TypeError: # didn't accept the tb_locals parameter - testRunner = self.testRunner( + testRunner = runner_factory( verbosity=self.verbosity, failfast=self.failfast, buffer=self.buffer, @@ -276,23 +317,19 @@ def _get_runner(self): # didn't accept the verbosity, buffer, failfast or stdout arguments # Try with the prior contract try: - testRunner = self.testRunner( + testRunner = runner_factory( verbosity=self.verbosity, failfast=self.failfast, buffer=self.buffer ) except TypeError: # Now try calling it with defaults - try: - testRunner = self.testRunner() - except TypeError: - # it is assumed to be a TestRunner instance - testRunner = self.testRunner + testRunner = runner_factory() return testRunner ################ -def main(argv, stdout): +def main(argv: list[str], stdout: IO[str]) -> None: TestProgram( argv=argv, testRunner=partial(TestToolsTestRunner, stdout=stdout), stdout=stdout ) diff --git a/testtools/runtest.py b/testtools/runtest.py index f3ecaf9e..06743ccc 100644 --- a/testtools/runtest.py +++ b/testtools/runtest.py @@ -8,8 +8,17 @@ ] import sys +from collections.abc import Callable +from typing import TYPE_CHECKING, Any -from testtools.testresult import ExtendedToOriginalDecorator +from testtools.testresult import ( + ExcInfo, + ExtendedToOriginalDecorator, + TestResult, +) + +if TYPE_CHECKING: + from testtools.testcase import TestCase class MultipleExceptions(Exception): @@ -47,7 +56,17 @@ class RunTest: reporting of error/failure/skip etc. """ - def __init__(self, case, handlers=None, last_resort=None): + def __init__( + self, + case: "TestCase", + handlers: ( + "list[tuple[type[BaseException], " + "Callable[[TestCase, TestResult, BaseException], None]]] | None" + ) = None, + last_resort: ( + "Callable[[TestCase, TestResult, BaseException], None] | None" + ) = None, + ) -> None: """Create a RunTest to run a case. :param case: A testtools.TestCase test case object. @@ -61,11 +80,11 @@ def __init__(self, case, handlers=None, last_resort=None): """ self.case = case self.handlers = handlers or [] - self.exception_caught = object() - self._exceptions = [] + self.exception_caught: object = object() + self._exceptions: list[BaseException] = [] self.last_resort = last_resort or (lambda case, result, exc: None) - def run(self, result=None): + def run(self, result: "TestResult | None" = None) -> TestResult: """Run self.case reporting activity to result. :param result: Optional testtools.TestResult to report activity to. @@ -82,7 +101,7 @@ def run(self, result=None): if result is None: actual_result.stopTestRun() - def _run_one(self, result): + def _run_one(self, result: "TestResult") -> "TestResult": """Run one test reporting to result. :param result: A testtools.TestResult to report activity to. @@ -91,9 +110,9 @@ def _run_one(self, result): confidence by client code. :return: The result object the test was run against. """ - return self._run_prepared_result(ExtendedToOriginalDecorator(result)) + return self._run_prepared_result(ExtendedToOriginalDecorator(result)) # type: ignore[arg-type] - def _run_prepared_result(self, result): + def _run_prepared_result(self, result: "TestResult") -> "TestResult": """Run one test reporting to result. :param result: A testtools.TestResult to report activity to. @@ -103,7 +122,9 @@ def _run_prepared_result(self, result): self.result = result try: self._exceptions = [] - self.case.__testtools_tb_locals__ = getattr(result, "tb_locals", False) + self.case.__testtools_tb_locals__ = getattr( # type: ignore[attr-defined] + result, "tb_locals", False + ) self._run_core() if self._exceptions: # One or more caught exceptions, now trigger the test's @@ -120,7 +141,7 @@ def _run_prepared_result(self, result): result.stopTest(self.case) return result - def _run_core(self): + def _run_core(self) -> None: """Run the user supplied test code.""" test_method = self.case._get_test_method() skip_case = getattr(self.case, "__unittest_skip__", False) @@ -168,7 +189,7 @@ def _run_core(self): self.case, details=self.case.getDetails() ) - def _run_cleanups(self, result): + def _run_cleanups(self, result: "TestResult") -> object | None: """Run the cleanups that have been added with addCleanup. See the docstring for addCleanup for more information. @@ -184,8 +205,9 @@ def _run_cleanups(self, result): failing = True if failing: return self.exception_caught + return None - def _run_user(self, fn, *args, **kwargs): + def _run_user(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """Run a user supplied function. Exceptions are processed by `_got_user_exception`. @@ -196,9 +218,13 @@ def _run_user(self, fn, *args, **kwargs): try: return fn(*args, **kwargs) except BaseException: - return self._got_user_exception(sys.exc_info()) + # Inside except block, exc_info() is guaranteed to have non-None values + exc_info = sys.exc_info() + return self._got_user_exception(exc_info) # type: ignore[arg-type] - def _got_user_exception(self, exc_info, tb_label="traceback"): + def _got_user_exception( + self, exc_info: ExcInfo, tb_label: str = "traceback" + ) -> object: """Called when user code raises an exception. If 'exc_info' is a `MultipleExceptions`, then we recurse into it @@ -226,7 +252,7 @@ def _got_user_exception(self, exc_info, tb_label="traceback"): return self.exception_caught -def _raise_force_fail_error(): +def _raise_force_fail_error() -> None: raise AssertionError("Forced Test Failure") diff --git a/testtools/tags.py b/testtools/tags.py index 341b3f8e..c9b2df0c 100644 --- a/testtools/tags.py +++ b/testtools/tags.py @@ -2,6 +2,8 @@ """Tag support.""" +from collections.abc import Iterable + class TagContext: """A tag context.""" @@ -22,7 +24,9 @@ def get_current_tags(self) -> set[str]: """Return any current tags.""" return set(self._tags) - def change_tags(self, new_tags: set[str], gone_tags: set[str]) -> set[str]: + def change_tags( + self, new_tags: Iterable[str], gone_tags: Iterable[str] + ) -> set[str]: """Change the tags on this context. :param new_tags: A set of tags to add to this context. diff --git a/testtools/testcase.py b/testtools/testcase.py index a5b2729e..f5910a2c 100644 --- a/testtools/testcase.py +++ b/testtools/testcase.py @@ -16,13 +16,20 @@ ] import copy +import datetime import functools import itertools import sys +import types import unittest +from collections.abc import Callable, Iterator from typing import Any, Protocol, TypeVar, cast from unittest.case import SkipTest +T = TypeVar("T") +U = TypeVar("U") + +# ruff: noqa: E402 - TypeVars must be defined before importing testtools modules from testtools import content from testtools.matchers import ( Annotate, @@ -43,6 +50,8 @@ RunTest, ) from testtools.testresult import ( + DetailsDict, + ExcInfo, ExtendedToOriginalDecorator, TestResult, ) @@ -64,22 +73,26 @@ class _ExpectedFailure(Exception): """ +# TypeVar for decorators +_F = TypeVar("_F", bound=Callable[..., object]) + + # Copied from unittest before python 3.4 release. Used to maintain # compatibility with unittest sub-test feature. Users should not use this # directly. -def _expectedFailure(func): +def _expectedFailure(func: _F) -> _F: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: object, **kwargs: object) -> None: try: func(*args, **kwargs) except Exception: raise _ExpectedFailure(sys.exc_info()) raise _UnexpectedSuccess - return wrapper + return cast(_F, wrapper) -def run_test_with(test_runner, **kwargs): +def run_test_with(test_runner: type[RunTest], **kwargs: object) -> Callable[[_F], _F]: """Decorate a test as using a specific ``RunTest``. e.g.:: @@ -105,10 +118,19 @@ def test_foo(self): runner. """ - def decorator(function): + def decorator(function: _F) -> _F: # Set an attribute on 'function' which will inform TestCase how to # make the runner. - def _run_test_with(case, handlers=None, last_resort=None): + def _run_test_with( + case: "TestCase", + handlers: ( + "list[tuple[type[BaseException], " + "Callable[[TestCase, TestResult, BaseException], None]]] | None" + ) = None, + last_resort: ( + "Callable[[TestCase, TestResult, BaseException], None] | None" + ) = None, + ) -> RunTest: try: return test_runner( case, handlers=handlers, last_resort=last_resort, **kwargs @@ -116,15 +138,15 @@ def _run_test_with(case, handlers=None, last_resort=None): except TypeError: # Backwards compat: if we can't call the constructor # with last_resort, try without that. - return test_runner(case, handlers=handlers, **kwargs) + return test_runner(case, handlers=handlers, **kwargs) # type: ignore[arg-type] - function._run_test_with = _run_test_with + function._run_test_with = _run_test_with # type: ignore[attr-defined] return function return decorator -def _copy_content(content_object): +def _copy_content(content_object: content.Content) -> content.Content: """Make a copy of the given content object. The content within ``content_object`` is iterated and saved. This is @@ -137,13 +159,13 @@ def _copy_content(content_object): """ content_bytes = list(content_object.iter_bytes()) - def content_callback(): + def content_callback() -> list[bytes]: return content_bytes return content.Content(content_object.content_type, content_callback) -def gather_details(source_dict, target_dict): +def gather_details(source_dict: DetailsDict, target_dict: DetailsDict) -> None: """Merge the details from ``source_dict`` into ``target_dict``. ``gather_details`` evaluates all details in ``source_dict``. Do not use it @@ -171,15 +193,15 @@ def gather_details(source_dict, target_dict): class UseFixtureProtocol(Protocol): - def setUp(self) -> Any: ... - def cleanUp(self) -> Any: ... - def getDetails(self) -> dict: ... + def setUp(self) -> None: ... + def cleanUp(self) -> None: ... + def getDetails(self) -> DetailsDict: ... UseFixtureT = TypeVar("UseFixtureT", bound=UseFixtureProtocol) -def _mods(i, mod): +def _mods(i: int, mod: int) -> Iterator[int]: (q, r) = divmod(i, mod) while True: yield r @@ -188,14 +210,14 @@ def _mods(i, mod): (q, r) = divmod(q, mod) -def _unique_text(base_cp, cp_range, index): +def _unique_text(base_cp: int, cp_range: int, index: int) -> str: s = "" for m in _mods(index, cp_range): s += chr(base_cp + m) return s -def unique_text_generator(prefix): +def unique_text_generator(prefix: str) -> Iterator[str]: """Generates unique text values. Generates text values that are unique. Use this when you need arbitrary @@ -234,7 +256,7 @@ class TestCase(unittest.TestCase): run_tests_with = RunTest - def __init__(self, *args, **kwargs): + def __init__(self, *args: object, **kwargs: object) -> None: """Construct a TestCase. :param testMethod: The name of the method to run. @@ -244,19 +266,26 @@ def __init__(self, *args, **kwargs): ``TestCase.run_tests_with`` if given. """ runTest = kwargs.pop("runTest", None) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # type: ignore[arg-type] self._reset() test_method = self._get_test_method() if runTest is None: runTest = getattr(test_method, "_run_test_with", self.run_tests_with) - self.__RunTest = runTest + self.__RunTest: type[RunTest] | Callable[..., RunTest] = cast( + "type[RunTest] | Callable[..., RunTest]", runTest + ) if getattr(test_method, "__unittest_expecting_failure__", False): setattr(self, self._testMethodName, _expectedFailure(test_method)) # Used internally for onException processing - used to gather extra # data from exceptions. - self.__exception_handlers = [] + self.__exception_handlers: list[Callable[[ExcInfo], None]] = [] # Passed to RunTest to map exceptions to result actions - self.exception_handlers = [ + self.exception_handlers: list[ + tuple[ + type[BaseException], + Callable[[TestCase, TestResult, BaseException], None], + ] + ] = [ (self.skipException, self._report_skip), (self.failureException, self._report_failure), (_ExpectedFailure, self._report_expected_failure), @@ -264,20 +293,22 @@ def __init__(self, *args, **kwargs): (Exception, self._report_error), ] - def _reset(self): + def _reset(self) -> None: """Reset the test case as if it had never been run.""" - self._cleanups = [] - self._unique_id_gen = itertools.count(1) + self._cleanups: list[ + tuple[Callable[..., object], tuple[object, ...], dict[str, object]] + ] = [] + self._unique_id_gen: Iterator[int] = itertools.count(1) # Generators to ensure unique traceback ids. Maps traceback label to # iterators. - self._traceback_id_gens = {} - self.__setup_called = False - self.__teardown_called = False + self._traceback_id_gens: dict[str, Iterator[int]] = {} + self.__setup_called: bool = False + self.__teardown_called: bool = False # __details is lazy-initialized so that a constructed-but-not-run # TestCase is safe to use with clone_test_with_new_id. - self.__details = None + self.__details: DetailsDict | None = None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: eq = getattr(unittest.TestCase, "__eq__", None) if eq is not None: eq_ = unittest.TestCase.__eq__(self, other) @@ -289,11 +320,11 @@ def __eq__(self, other): # https://docs.python.org/3/reference/datamodel.html#object.__hash__ __hash__ = unittest.TestCase.__hash__ - def __repr__(self): + def __repr__(self) -> str: # We add id to the repr because it makes testing testtools easier. return f"<{self.id()} id=0x{id(self):0x}>" - def addDetail(self, name, content_object): + def addDetail(self, name: str, content_object: content.Content) -> None: """Add a detail to be reported with this test's outcome. For more details see pydoc testtools.TestResult. @@ -306,7 +337,7 @@ def addDetail(self, name, content_object): self.__details = {} self.__details[name] = content_object - def getDetails(self): + def getDetails(self) -> DetailsDict: """Get the details dict that will be reported with this test's outcome. For more details see pydoc testtools.TestResult. @@ -315,7 +346,7 @@ def getDetails(self): self.__details = {} return self.__details - def patch(self, obj, attribute, value): + def patch(self, obj: object, attribute: str, value: object) -> None: """Monkey-patch 'obj.attribute' to 'value' while the test is running. If 'obj' has no attribute, then the monkey-patch will still go ahead, @@ -328,10 +359,10 @@ def patch(self, obj, attribute, value): """ self.addCleanup(patch(obj, attribute, value)) - def shortDescription(self): + def shortDescription(self) -> str: return self.id() - def skipTest(self, reason): + def skipTest(self, reason: str) -> None: # type: ignore[override] """Cause this test to be skipped. This raises self.skipException(reason). skipException is raised @@ -344,14 +375,24 @@ def skipTest(self, reason): """ raise self.skipException(reason) - def _formatTypes(self, classOrIterable): + def _formatTypes( + self, + classOrIterable: type[object] | tuple[type[object], ...] | list[type[object]], + ) -> str: """Format a class or a bunch of classes for display in an error.""" - className = getattr(classOrIterable, "__name__", None) - if className is None: + if isinstance(classOrIterable, (tuple, list)): className = ", ".join(klass.__name__ for klass in classOrIterable) + else: + className = classOrIterable.__name__ return className - def addCleanup(self, function, /, *arguments, **keywordArguments): + def addCleanup( + self, + function: Callable[..., object], + /, + *arguments: object, + **keywordArguments: object, + ) -> None: """Add a cleanup function to be called after tearDown. Functions added with addCleanup will be called in reverse order of @@ -366,7 +407,7 @@ def addCleanup(self, function, /, *arguments, **keywordArguments): """ self._cleanups.append((function, arguments, keywordArguments)) - def addOnException(self, handler): + def addOnException(self, handler: "Callable[[ExcInfo], None]") -> None: """Add a handler to be called when an exception occurs in test code. This handler cannot affect what result methods are called, and is @@ -385,10 +426,12 @@ def addOnException(self, handler): """ self.__exception_handlers.append(handler) - def _add_reason(self, reason): + def _add_reason(self, reason: str) -> None: self.addDetail("reason", content.text_content(reason)) - def assertEqual(self, expected, observed, message=""): + def assertEqual( + self, expected: object, observed: object, message: str = "" + ) -> None: """Assert that 'expected' is equal to 'observed'. :param expected: The expected value. @@ -398,29 +441,29 @@ def assertEqual(self, expected, observed, message=""): matcher = _FlippedEquals(expected) self.assertThat(observed, matcher, message) - def assertIn(self, needle, haystack, message=""): + def assertIn(self, needle: object, haystack: object, message: str = "") -> None: """Assert that needle is in haystack.""" self.assertThat(haystack, Contains(needle), message) - def assertIsNone(self, observed, message=""): + def assertIsNone(self, observed: object, message: str = "") -> None: """Assert that 'observed' is equal to None. :param observed: The observed value. :param message: An optional message describing the error. """ - matcher = Is(None) + matcher: Matcher[object] = Is(None) self.assertThat(observed, matcher, message) - def assertIsNotNone(self, observed, message=""): + def assertIsNotNone(self, observed: object, message: str = "") -> None: """Assert that 'observed' is not equal to None. :param observed: The observed value. :param message: An optional message describing the error. """ - matcher = Not(Is(None)) + matcher: Matcher[object] = Not(Is(None)) self.assertThat(observed, matcher, message) - def assertIs(self, expected, observed, message=""): + def assertIs(self, expected: object, observed: object, message: str = "") -> None: """Assert that 'expected' is 'observed'. :param expected: The expected value. @@ -430,24 +473,37 @@ def assertIs(self, expected, observed, message=""): matcher = Is(expected) self.assertThat(observed, matcher, message) - def assertIsNot(self, expected, observed, message=""): + def assertIsNot( + self, expected: object, observed: object, message: str = "" + ) -> None: """Assert that 'expected' is not 'observed'.""" matcher = Not(Is(expected)) self.assertThat(observed, matcher, message) - def assertNotIn(self, needle, haystack, message=""): + def assertNotIn(self, needle: U, haystack: T, message: str = "") -> None: """Assert that needle is not in haystack.""" - matcher = Not(Contains(needle)) + matcher: Not[T] = Not(Contains(needle)) self.assertThat(haystack, matcher, message) - def assertIsInstance(self, obj, klass, msg=None): + def assertIsInstance( # type: ignore[override] + self, + obj: object, + klass: type[object] | tuple[type[object], ...], + msg: str | None = None, + ) -> None: if isinstance(klass, tuple): matcher = IsInstance(*klass) else: matcher = IsInstance(klass) - self.assertThat(obj, matcher, msg) + self.assertThat(obj, matcher, msg or "") - def assertRaises(self, expected_exception, callable=None, *args, **kwargs): + def assertRaises( # type: ignore[override] + self, + expected_exception: type[BaseException], + callable: Callable[..., object] | None = None, + *args: object, + **kwargs: object, + ) -> "_AssertRaisesContext | BaseException": """Fail unless an exception of class expected_exception is thrown by callable when invoked with arguments args and keyword arguments kwargs. If a different type of exception is @@ -472,16 +528,40 @@ def assertRaises(self, expected_exception, callable=None, *args, **kwargs): """ # If callable is None, we're being used as a context manager if callable is None: - return _AssertRaisesContext(expected_exception, self, msg=kwargs.get("msg")) - - class ReRaiseOtherTypes: - def match(self, matchee): + msg_value = kwargs.get("msg") + msg_str: str | None = msg_value if isinstance(msg_value, str) else None + return _AssertRaisesContext(expected_exception, self, msg=msg_str) + + class ReRaiseOtherTypes( + Matcher[ + tuple[type[BaseException], BaseException, types.TracebackType | None] + ] + ): + def match( + self, + matchee: tuple[ + type[BaseException], BaseException, types.TracebackType | None + ], + ) -> None: if not issubclass(matchee[0], expected_exception): raise matchee[1].with_traceback(matchee[2]) + return None - class CaptureMatchee: - def match(self, matchee): + class CaptureMatchee( + Matcher[ + tuple[type[BaseException], BaseException, types.TracebackType | None] + ] + ): + matchee: BaseException + + def match( + self, + matchee: tuple[ + type[BaseException], BaseException, types.TracebackType | None + ], + ) -> None: self.matchee = matchee[1] + return None capture = CaptureMatchee() matcher = Raises( @@ -489,11 +569,17 @@ def match(self, matchee): ReRaiseOtherTypes(), MatchesException(expected_exception), capture ) ) - our_callable = Nullary(callable, *args, **kwargs) + our_callable: Callable[[], object] = Nullary(callable, *args, **kwargs) self.assertThat(our_callable, matcher) return capture.matchee - def assertThat(self, matchee, matcher, message="", verbose=False): + def assertThat( + self, + matchee: T, + matcher: "Matcher[T]", + message: str = "", + verbose: bool = False, + ) -> None: """Assert that matchee is matched by matcher. :param matchee: An object to match with matcher. @@ -504,7 +590,7 @@ def assertThat(self, matchee, matcher, message="", verbose=False): if mismatch_error is not None: raise mismatch_error - def addDetailUniqueName(self, name, content_object): + def addDetailUniqueName(self, name: str, content_object: content.Content) -> None: """Add a detail to the test, but ensure it's name is unique. This method checks whether ``name`` conflicts with a detail that has @@ -525,7 +611,13 @@ def addDetailUniqueName(self, name, content_object): suffix += 1 self.addDetail(full_name, content_object) - def expectThat(self, matchee, matcher, message="", verbose=False): + def expectThat( + self, + matchee: object, + matcher: "Matcher[object]", + message: str = "", + verbose: bool = False, + ) -> None: """Check that matchee is matched by matcher, but delay the assertion failure. This method behaves similarly to ``assertThat``, except that a failed @@ -549,19 +641,27 @@ def expectThat(self, matchee, matcher, message="", verbose=False): ) self.force_failure = True - def _matchHelper(self, matchee, matcher, message, verbose): + def _matchHelper( + self, matchee: T, matcher: "Matcher[T]", message: str, verbose: bool + ) -> "MismatchError[T] | None": matcher = Annotate.if_message(message, matcher) mismatch = matcher.match(matchee) if not mismatch: - return + return None for name, value in mismatch.get_details().items(): self.addDetailUniqueName(name, value) return MismatchError(matchee, matcher, mismatch, verbose) - def defaultTestResult(self): + def defaultTestResult(self) -> TestResult: return TestResult() - def expectFailure(self, reason, predicate, *args, **kwargs): + def expectFailure( + self, + reason: str, + predicate: Callable[..., object], + *args: object, + **kwargs: object, + ) -> None: """Check that a test fails in a particular way. If the test fails in the expected way, a KnownFailure is caused. If it @@ -593,7 +693,7 @@ def expectFailure(self, reason, predicate, *args, **kwargs): else: raise _UnexpectedSuccess(reason) - def getUniqueInteger(self): + def getUniqueInteger(self) -> int: """Get an integer unique to this test. Returns an integer that is guaranteed to be unique to this instance. @@ -602,7 +702,7 @@ def getUniqueInteger(self): """ return next(self._unique_id_gen) - def getUniqueString(self, prefix=None): + def getUniqueString(self, prefix: str | None = None) -> str: """Get a string unique to this test. Returns a string that is guaranteed to be unique to this instance. Use @@ -617,7 +717,7 @@ def getUniqueString(self, prefix=None): prefix = self.id() return f"{prefix}-{self.getUniqueInteger()}" - def onException(self, exc_info, tb_label="traceback"): + def onException(self, exc_info: ExcInfo, tb_label: str = "traceback") -> None: """Called when an exception propagates from test code. :seealso addOnException: @@ -632,19 +732,23 @@ def onException(self, exc_info, tb_label="traceback"): handler(exc_info) @staticmethod - def _report_error(self, result, err): + def _report_error(self: "TestCase", result: TestResult, err: BaseException) -> None: result.addError(self, details=self.getDetails()) @staticmethod - def _report_expected_failure(self, result, err): + def _report_expected_failure( + self: "TestCase", result: TestResult, err: BaseException + ) -> None: result.addExpectedFailure(self, details=self.getDetails()) @staticmethod - def _report_failure(self, result, err): + def _report_failure( + self: "TestCase", result: TestResult, err: BaseException + ) -> None: result.addFailure(self, details=self.getDetails()) @staticmethod - def _report_skip(self, result, err): + def _report_skip(self: "TestCase", result: TestResult, err: BaseException) -> None: if err.args: reason = err.args[0] else: @@ -652,7 +756,9 @@ def _report_skip(self, result, err): self._add_reason(reason) result.addSkip(self, details=self.getDetails()) - def _report_traceback(self, exc_info, tb_label="traceback"): + def _report_traceback( + self, exc_info: "ExcInfo | tuple[None, None, None]", tb_label: str = "traceback" + ) -> None: id_gen = self._traceback_id_gens.setdefault(tb_label, itertools.count(0)) while True: tb_id = next(id_gen) @@ -670,10 +776,12 @@ def _report_traceback(self, exc_info, tb_label="traceback"): ) @staticmethod - def _report_unexpected_success(self, result, err): + def _report_unexpected_success( + self: "TestCase", result: TestResult, err: BaseException + ) -> None: result.addUnexpectedSuccess(self, details=self.getDetails()) - def run(self, result=None): + def run(self, result: TestResult | None = None) -> TestResult: # type: ignore[override] self._reset() try: run_test = self.__RunTest( @@ -685,19 +793,25 @@ def run(self, result=None): run_test = self.__RunTest(self, self.exception_handlers) return run_test.run(result) - def _run_setup(self, result): + def _run_setup(self, result: TestResult) -> object: """Run the setUp function for this test. :param result: A testtools.TestResult to report activity to. :raises ValueError: If the base class setUp is not called, a ValueError is raised. """ - ret = self.setUp() + # setUp() normally returns None, but async test frameworks may + # return Deferred-like objects + setup_result: object = self.setUp() # type: ignore[func-returns-value] # Check if the return value is a Deferred (duck-typing to avoid hard dependency) - if hasattr(ret, "addBoth") and callable(getattr(ret, "addBoth")): + if ( + setup_result is not None + and hasattr(setup_result, "addBoth") + and callable(getattr(setup_result, "addBoth")) + ): # Deferred-like object: validate asynchronously after it resolves - def _validate_setup_called(result): + def _validate_setup_called(result: object) -> object: if not self.__setup_called: raise ValueError( f"In File: {sys.modules[self.__class__.__module__].__file__}\n" @@ -708,7 +822,7 @@ def _validate_setup_called(result): ) return result - ret.addBoth(_validate_setup_called) + setup_result = setup_result.addBoth(_validate_setup_called) else: # Synchronous: validate immediately if not self.__setup_called: @@ -719,21 +833,27 @@ def _validate_setup_called(result): f"super({self.__class__.__name__}, self).setUp() " "from your setUp()." ) - return ret + return setup_result - def _run_teardown(self, result): + def _run_teardown(self, result: TestResult) -> object: """Run the tearDown function for this test. :param result: A testtools.TestResult to report activity to. :raises ValueError: If the base class tearDown is not called, a ValueError is raised. """ - ret = self.tearDown() + # tearDown() normally returns None, but async test frameworks + # may return Deferred-like objects + teardown_result: object = self.tearDown() # type: ignore[func-returns-value] # Check if the return value is a Deferred (duck-typing to avoid hard dependency) - if hasattr(ret, "addBoth") and callable(getattr(ret, "addBoth")): + if ( + teardown_result is not None + and hasattr(teardown_result, "addBoth") + and callable(getattr(teardown_result, "addBoth")) + ): # Deferred-like object: validate asynchronously after it resolves - def _validate_teardown_called(result): + def _validate_teardown_called(result: object) -> object: if not self.__teardown_called: raise ValueError( f"In File: {sys.modules[self.__class__.__module__].__file__}\n" @@ -744,7 +864,7 @@ def _validate_teardown_called(result): ) return result - ret.addBoth(_validate_teardown_called) + teardown_result = teardown_result.addBoth(_validate_teardown_called) else: # Synchronous: validate immediately if not self.__teardown_called: @@ -755,12 +875,12 @@ def _validate_teardown_called(result): f"super({self.__class__.__name__}, self).tearDown() " "from your tearDown()." ) - return ret + return teardown_result - def _get_test_method(self): + def _get_test_method(self) -> Callable[[], object]: method_name = getattr(self, "_testMethodName") try: - m = getattr(self, method_name) + m: Callable[[], object] = getattr(self, method_name) except AttributeError: if method_name != "runTest": # We allow instantiation with no explicit method name @@ -768,10 +888,12 @@ def _get_test_method(self): raise ValueError( f"no such test method in {self.__class__}: {method_name}" ) + # If runTest doesn't exist, return a no-op callable + return lambda: None else: return m - def _run_test_method(self, result): + def _run_test_method(self, result: TestResult) -> object: """Run the test method for this test. :param result: A testtools.TestResult to report activity to. @@ -820,7 +942,7 @@ def useFixture(self, fixture: UseFixtureT) -> UseFixtureT: self.addCleanup(gather_details, fixture.getDetails(), self.getDetails()) return fixture - def setUp(self): + def setUp(self) -> None: super().setUp() if self.__setup_called: raise ValueError( @@ -831,7 +953,7 @@ def setUp(self): ) self.__setup_called = True - def tearDown(self): + def tearDown(self) -> None: super().tearDown() if self.__teardown_called: raise ValueError( @@ -843,25 +965,28 @@ def tearDown(self): self.__teardown_called = True -class PlaceHolder: +class PlaceHolder(unittest.TestCase): """A placeholder test. `PlaceHolder` implements much of the same interface as TestCase and is particularly suitable for being added to TestResults. """ - failureException = None + failureException = None # type: ignore[assignment] def __init__( self, - test_id, - short_description=None, - details=None, - outcome="addSuccess", - error=None, - tags=None, - timestamps=(None, None), - ): + test_id: str, + short_description: str | None = None, + details: DetailsDict | None = None, + outcome: str = "addSuccess", + error: "ExcInfo | tuple[None, None, None] | None" = None, + tags: frozenset[str] | None = None, + timestamps: "tuple[datetime.datetime | None, datetime.datetime | None]" = ( + None, + None, + ), + ) -> None: """Construct a `PlaceHolder`. :param test_id: The id of the placeholder test. @@ -882,11 +1007,13 @@ def __init__( tags = tags or frozenset() self._tags = frozenset(tags) self._timestamps = timestamps + # Required for unittest.TestCase compatibility + self._testMethodName = "run" - def __call__(self, result=None): + def __call__(self, result: unittest.TestResult | None = None) -> None: return self.run(result=result) - def __repr__(self): + def __repr__(self) -> str: internal = [self._outcome, self._test_id, self._details] if self._short_description is not None: internal.append(self._short_description) @@ -896,16 +1023,16 @@ def __repr__(self): ", ".join(map(repr, internal)), ) - def __str__(self): + def __str__(self) -> str: return self.id() - def countTestCases(self): + def countTestCases(self) -> int: return 1 - def debug(self): + def debug(self) -> None: pass - def id(self): + def id(self) -> str: return self._test_id def _result( @@ -929,7 +1056,7 @@ def run(self, result: unittest.TestResult | None = None) -> None: result_obj.stopTest(self) result_obj.tags(set(), self._tags) - def shortDescription(self): + def shortDescription(self) -> str: if self._short_description is None: return self.id() else: @@ -938,9 +1065,9 @@ def shortDescription(self): def ErrorHolder( test_id: str, - error: tuple, + error: "ExcInfo | tuple[None, None, None]", short_description: str | None = None, - details: dict | None = None, + details: "DetailsDict | None" = None, ) -> PlaceHolder: """Construct an `ErrorHolder`. @@ -960,7 +1087,9 @@ def ErrorHolder( ) -def _clone_test_id_callback(test, callback): +def _clone_test_id_callback( + test: unittest.TestCase, callback: Callable[[], str] +) -> unittest.TestCase: """Copy a `TestCase`, and make it call callback for its id(). This is only expected to be used on tests that have been constructed but @@ -971,11 +1100,11 @@ def _clone_test_id_callback(test, callback): :return: A copy.copy of the test with id=callback. """ newTest = copy.copy(test) - newTest.id = callback + newTest.id = callback # type: ignore[method-assign] return newTest -def clone_test_with_new_id(test, new_id): +def clone_test_with_new_id(test: unittest.TestCase, new_id: str) -> unittest.TestCase: """Copy a `TestCase`, and give the copied test a new id. This is only expected to be used on tests that have been constructed but @@ -984,7 +1113,7 @@ def clone_test_with_new_id(test, new_id): return _clone_test_id_callback(test, lambda: new_id) -def attr(*args): +def attr(*args: str) -> Callable[[_F], _F]: """Decorator for adding attributes to WithAttributes. :param args: The name of attributes to add. @@ -992,10 +1121,10 @@ def attr(*args): alter its id to enumerate the added attributes. """ - def decorate(fn): + def decorate(fn: _F) -> _F: if not hasattr(fn, "__testtools_attrs"): - fn.__testtools_attrs = set() - fn.__testtools_attrs.update(args) + fn.__testtools_attrs = set() # type: ignore[attr-defined] + fn.__testtools_attrs.update(args) # type: ignore[attr-defined] return fn return decorate @@ -1029,7 +1158,7 @@ def id(self) -> str: class_types = (type,) -def skip(reason): +def skip(reason: str) -> Callable[[_F], _F]: """A decorator to skip unit tests. This is just syntactic sugar so users don't have to change any of their @@ -1037,43 +1166,43 @@ def skip(reason): @unittest.skip decorator. """ - def decorator(test_item): + def decorator(test_item: _F) -> _F: if not isinstance(test_item, class_types): @functools.wraps(test_item) - def skip_wrapper(*args, **kwargs): + def skip_wrapper(*args: object, **kwargs: object) -> None: raise TestCase.skipException(reason) - test_item = skip_wrapper + test_item = cast(_F, skip_wrapper) # This attribute signals to RunTest._run_core that the entire test # must be skipped - including setUp and tearDown. This makes us # compatible with testtools.skip* functions, which set the same # attributes. - test_item.__unittest_skip__ = True - test_item.__unittest_skip_why__ = reason + test_item.__unittest_skip__ = True # type: ignore[attr-defined] + test_item.__unittest_skip_why__ = reason # type: ignore[attr-defined] return test_item return decorator -def skipIf(condition, reason): +def skipIf(condition: bool, reason: str) -> Callable[[_F], _F]: """A decorator to skip a test if the condition is true.""" if condition: return skip(reason) - def _id(obj): + def _id(obj: _F) -> _F: return obj return _id -def skipUnless(condition, reason): +def skipUnless(condition: bool, reason: str) -> Callable[[_F], _F]: """A decorator to skip a test unless the condition is true.""" if not condition: return skip(reason) - def _id(obj): + def _id(obj: _F) -> _F: return obj return _id @@ -1085,7 +1214,9 @@ class _AssertRaisesContext: This provides compatibility with unittest's assertRaises context manager. """ - def __init__(self, expected, test_case, msg=None): + def __init__( + self, expected: type[BaseException], test_case: TestCase, msg: str | None = None + ) -> None: """Construct an `_AssertRaisesContext`. :param expected: The type of exception to expect. @@ -1095,12 +1226,17 @@ def __init__(self, expected, test_case, msg=None): self.expected = expected self.test_case = test_case self.msg = msg - self.exception = None + self.exception: BaseException | None = None - def __enter__(self): + def __enter__(self) -> "_AssertRaisesContext": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool: if exc_type is None: try: if isinstance(self.expected, tuple): @@ -1137,7 +1273,12 @@ def test_foo(self): exception is raised, an AssertionError will be raised. """ - def __init__(self, exc_type, value_re=None, msg=None): + def __init__( + self, + exc_type: type[BaseException], + value_re: str | None = None, + msg: str | None = None, + ) -> None: """Construct an `ExpectedException`. :param exc_type: The type of exception to expect. @@ -1149,10 +1290,15 @@ def __init__(self, exc_type, value_re=None, msg=None): self.value_re = value_re self.msg = msg - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool: if exc_type is None: error_msg = f"{self.exc_type.__name__} not raised." if self.msg: @@ -1162,12 +1308,16 @@ def __exit__(self, exc_type, exc_value, traceback): return False if self.value_re: exception_matcher = MatchesException(self.exc_type, self.value_re) - matcher: Matcher | Annotate + matcher: Matcher[ExcInfo] if self.msg: matcher = Annotate(self.msg, exception_matcher) else: matcher = exception_matcher - mismatch = matcher.match((exc_type, exc_value, traceback)) + # Type narrow: we know exc_type is not None from check above, + # and exc_value must not be None for real exceptions + assert exc_value is not None, "Exception value should not be None" + exc_info_tuple: ExcInfo = (exc_type, exc_value, traceback) + mismatch = matcher.match(exc_info_tuple) if mismatch: raise AssertionError(mismatch.describe()) return True @@ -1180,22 +1330,30 @@ class Nullary: preserves the ``repr()`` of ``f``. """ - def __init__(self, callable_object, *args, **kwargs): + def __init__( + self, callable_object: Callable[..., object], *args: object, **kwargs: object + ) -> None: self._callable_object = callable_object self._args = args self._kwargs = kwargs - def __call__(self): + def __call__(self) -> object: return self._callable_object(*self._args, **self._kwargs) - def __repr__(self): + def __repr__(self) -> str: return repr(self._callable_object) class DecorateTestCaseResult: """Decorate a TestCase and permit customisation of the result for runs.""" - def __init__(self, case, callout, before_run=None, after_run=None): + def __init__( + self, + case: unittest.TestCase, + callout: Callable[[TestResult | None], TestResult], + before_run: Callable[[TestResult], None] | None = None, + after_run: Callable[[TestResult], None] | None = None, + ) -> None: """Construct a DecorateTestCaseResult. :param case: The case to decorate. @@ -1212,7 +1370,9 @@ def __init__(self, case, callout, before_run=None, after_run=None): self.before_run = before_run self.after_run = after_run - def _run(self, result, run_method): + def _run( + self, result: TestResult | None, run_method: Callable[[TestResult], object] + ) -> object: result = self.callout(result) if self.before_run: self.before_run(result) @@ -1222,19 +1382,19 @@ def _run(self, result, run_method): if self.after_run: self.after_run(result) - def run(self, result=None): - self._run(result, self.decorated.run) + def run(self, result: TestResult | None = None) -> object: + return self._run(result, self.decorated.run) - def __call__(self, result=None): - self._run(result, self.decorated) + def __call__(self, result: TestResult | None = None) -> object: + return self._run(result, self.decorated) - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: return getattr(self.decorated, name) - def __delattr__(self, name): + def __delattr__(self, name: str) -> None: delattr(self.decorated, name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: object) -> None: if name in ("decorated", "callout", "before_run", "after_run"): self.__dict__[name] = value return diff --git a/testtools/testresult/__init__.py b/testtools/testresult/__init__.py index 104be36c..e8b0b7d5 100644 --- a/testtools/testresult/__init__.py +++ b/testtools/testresult/__init__.py @@ -4,6 +4,8 @@ __all__ = [ "CopyStreamResult", + "DetailsDict", + "ExcInfo", "ExtendedToOriginalDecorator", "ExtendedToStreamDecorator", "MultiTestResult", @@ -28,6 +30,8 @@ from testtools.testresult.real import ( CopyStreamResult, + DetailsDict, + ExcInfo, ExtendedToOriginalDecorator, ExtendedToStreamDecorator, MultiTestResult, diff --git a/testtools/testresult/doubles.py b/testtools/testresult/doubles.py index 762ab257..6e997dce 100644 --- a/testtools/testresult/doubles.py +++ b/testtools/testresult/doubles.py @@ -2,7 +2,11 @@ """Doubles of test result objects, useful for testing unittest code.""" +import datetime +import unittest from collections import namedtuple +from collections.abc import Iterable +from typing import Literal, TypeAlias from testtools.tags import TagContext @@ -13,11 +17,67 @@ "TwistedTestResult", ] +# Convenience namedtuple for status events - defined early for use in LogEvent +_StatusEvent = namedtuple( + "_StatusEvent", + [ + "name", + "test_id", + "test_status", + "test_tags", + "runnable", + "file_name", + "file_bytes", + "eof", + "mime_type", + "route_code", + "timestamp", + ], +) + +# Event type aliases using plain tuples with Literal for event names +# This provides type safety while working with the existing plain tuple code +LogEvent: TypeAlias = ( + tuple[Literal["startTestRun"]] + | tuple[Literal["stopTestRun"]] + | tuple[Literal["startTest"], unittest.TestCase] + | tuple[Literal["stopTest"], unittest.TestCase] + | tuple[Literal["addSuccess"], unittest.TestCase] + | tuple[Literal["addSuccess"], unittest.TestCase, dict[str, object]] + | tuple[ + Literal["addError"], + unittest.TestCase, + tuple[type, Exception, object] | dict[str, object] | None | object, + ] + | tuple[ + Literal["addFailure"], + unittest.TestCase, + tuple[type, Exception, object] | dict[str, object] | None | object, + ] + | tuple[ + Literal["addExpectedFailure"], + unittest.TestCase, + tuple[type, Exception, object] | dict[str, object] | None | object, + ] + | tuple[ + Literal["addSkip"], + unittest.TestCase, + str | dict[str, object] | None, + ] + | tuple[Literal["addUnexpectedSuccess"], unittest.TestCase] + | tuple[Literal["addUnexpectedSuccess"], unittest.TestCase, dict[str, object]] + | tuple[Literal["addDuration"], unittest.TestCase, float] + | tuple[Literal["progress"], int, int] + | tuple[Literal["tags"], Iterable[str], Iterable[str]] + | tuple[Literal["time"], datetime.datetime] + | _StatusEvent +) + class LoggingBase: """Basic support for logging of results.""" - def __init__(self, event_log=None): + def __init__(self, event_log: list[LogEvent] | None = None) -> None: if event_log is None: event_log = [] self._events = event_log @@ -26,131 +86,161 @@ def __init__(self, event_log=None): class Python3TestResult(LoggingBase): """A precisely python 3 like test result, that logs.""" - def __init__(self, event_log=None): + def __init__(self, event_log: list[LogEvent] | None = None) -> None: super().__init__(event_log=event_log) self.shouldStop = False self._was_successful = True self.testsRun = 0 self.failfast = False - self.collectedDurations = [] + self.collectedDurations: list[tuple[unittest.TestCase, float]] = [] - def addError(self, test, err): + def addError( + self, test: unittest.TestCase, err: tuple[type, Exception, object] + ) -> None: self._was_successful = False self._events.append(("addError", test, err)) if self.failfast: self.stop() - def addFailure(self, test, err): + def addFailure( + self, test: unittest.TestCase, err: tuple[type, Exception, object] + ) -> None: self._was_successful = False self._events.append(("addFailure", test, err)) if self.failfast: self.stop() - def addSuccess(self, test): + def addSuccess(self, test: unittest.TestCase) -> None: self._events.append(("addSuccess", test)) - def addExpectedFailure(self, test, err): + def addExpectedFailure( + self, test: unittest.TestCase, err: tuple[type, Exception, object] + ) -> None: self._events.append(("addExpectedFailure", test, err)) - def addSkip(self, test, reason): + def addSkip(self, test: unittest.TestCase, reason: str) -> None: self._events.append(("addSkip", test, reason)) - def addUnexpectedSuccess(self, test): + def addUnexpectedSuccess(self, test: unittest.TestCase) -> None: self._events.append(("addUnexpectedSuccess", test)) if self.failfast: self.stop() - def addDuration(self, test, duration): + def addDuration(self, test: unittest.TestCase, duration: float) -> None: self._events.append(("addDuration", test, duration)) self.collectedDurations.append((test, duration)) - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: self._events.append(("startTest", test)) self.testsRun += 1 - def startTestRun(self): + def startTestRun(self) -> None: self._events.append(("startTestRun",)) - def stop(self): + def stop(self) -> None: self.shouldStop = True - def stopTest(self, test): + def stopTest(self, test: unittest.TestCase) -> None: self._events.append(("stopTest", test)) - def stopTestRun(self): + def stopTestRun(self) -> None: self._events.append(("stopTestRun",)) - def wasSuccessful(self): + def wasSuccessful(self) -> bool: return self._was_successful class ExtendedTestResult(Python3TestResult): """A test result like the proposed extended unittest result API.""" - def __init__(self, event_log=None): + def __init__(self, event_log: list[LogEvent] | None = None) -> None: super().__init__(event_log) self._tags = TagContext() - def addError(self, test, err=None, details=None): + def addError( + self, + test: unittest.TestCase, + err: tuple[type, Exception, object] | None = None, + details: dict[str, object] | None = None, + ) -> None: self._was_successful = False self._events.append(("addError", test, err or details)) - def addFailure(self, test, err=None, details=None): + def addFailure( + self, + test: unittest.TestCase, + err: tuple[type, Exception, object] | None = None, + details: dict[str, object] | None = None, + ) -> None: self._was_successful = False self._events.append(("addFailure", test, err or details)) - def addExpectedFailure(self, test, err=None, details=None): + def addExpectedFailure( + self, + test: unittest.TestCase, + err: tuple[type, Exception, object] | None = None, + details: dict[str, object] | None = None, + ) -> None: self._events.append(("addExpectedFailure", test, err or details)) - def addSkip(self, test, reason=None, details=None): + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: dict[str, object] | None = None, + ) -> None: self._events.append(("addSkip", test, reason or details)) - def addSuccess(self, test, details=None): + def addSuccess( + self, test: unittest.TestCase, details: dict[str, object] | None = None + ) -> None: if details: self._events.append(("addSuccess", test, details)) else: self._events.append(("addSuccess", test)) - def addUnexpectedSuccess(self, test, details=None): + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: dict[str, object] | None = None + ) -> None: self._was_successful = False if details is not None: self._events.append(("addUnexpectedSuccess", test, details)) else: self._events.append(("addUnexpectedSuccess", test)) - def addDuration(self, test, duration): + def addDuration(self, test: unittest.TestCase, duration: float) -> None: self._events.append(("addDuration", test, duration)) - def progress(self, offset, whence): + def progress(self, offset: int, whence: int) -> None: self._events.append(("progress", offset, whence)) - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() self._was_successful = True self._tags = TagContext() - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: super().startTest(test) self._tags = TagContext(self._tags) - def stopTest(self, test): + def stopTest(self, test: unittest.TestCase) -> None: # NOTE: In Python 3.12.1 skipped tests may not call startTest() if self._tags is not None and self._tags.parent is not None: self._tags = self._tags.parent super().stopTest(test) @property - def current_tags(self): + def current_tags(self) -> set[str]: return self._tags.get_current_tags() - def tags(self, new_tags, gone_tags): + def tags(self, new_tags: Iterable[str], gone_tags: Iterable[str]) -> None: self._tags.change_tags(new_tags, gone_tags) self._events.append(("tags", new_tags, gone_tags)) - def time(self, time): + def time(self, time: datetime.datetime) -> None: self._events.append(("time", time)) - def wasSuccessful(self): + def wasSuccessful(self) -> bool: return self._was_successful @@ -160,42 +250,46 @@ class TwistedTestResult(LoggingBase): Used to ensure that we can use ``trial`` as a test runner. """ - def __init__(self, event_log=None): + def __init__(self, event_log: list[LogEvent] | None = None) -> None: super().__init__(event_log=event_log) self._was_successful = True self.testsRun = 0 - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: self.testsRun += 1 self._events.append(("startTest", test)) - def stopTest(self, test): + def stopTest(self, test: unittest.TestCase) -> None: self._events.append(("stopTest", test)) - def addSuccess(self, test): + def addSuccess(self, test: unittest.TestCase) -> None: self._events.append(("addSuccess", test)) - def addError(self, test, error): + def addError(self, test: unittest.TestCase, error: object) -> None: self._was_successful = False self._events.append(("addError", test, error)) - def addFailure(self, test, error): + def addFailure(self, test: unittest.TestCase, error: object) -> None: self._was_successful = False self._events.append(("addFailure", test, error)) - def addExpectedFailure(self, test, failure, todo=None): + def addExpectedFailure( + self, test: unittest.TestCase, failure: object, todo: object | None = None + ) -> None: self._events.append(("addExpectedFailure", test, failure)) - def addUnexpectedSuccess(self, test, todo=None): + def addUnexpectedSuccess( + self, test: unittest.TestCase, todo: object | None = None + ) -> None: self._events.append(("addUnexpectedSuccess", test)) - def addSkip(self, test, reason): + def addSkip(self, test: unittest.TestCase, reason: str) -> None: self._events.append(("addSkip", test, reason)) - def wasSuccessful(self): + def wasSuccessful(self) -> bool: return self._was_successful - def done(self): + def done(self) -> None: pass @@ -205,25 +299,25 @@ class StreamResult(LoggingBase): All events are logged to _events. """ - def startTestRun(self): + def startTestRun(self) -> None: self._events.append(("startTestRun",)) - def stopTestRun(self): + def stopTestRun(self) -> None: self._events.append(("stopTestRun",)) def status( self, - test_id=None, - test_status=None, - test_tags=None, - runnable=True, - file_name=None, - file_bytes=None, - eof=False, - mime_type=None, - route_code=None, - timestamp=None, - ): + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: self._events.append( _StatusEvent( "status", @@ -239,22 +333,3 @@ def status( timestamp, ) ) - - -# Convenience for easier access to status fields -_StatusEvent = namedtuple( - "_StatusEvent", - [ - "name", - "test_id", - "test_status", - "test_tags", - "runnable", - "file_name", - "file_bytes", - "eof", - "mime_type", - "route_code", - "timestamp", - ], -) diff --git a/testtools/testresult/real.py b/testtools/testresult/real.py index c80c0270..e0c072a4 100644 --- a/testtools/testresult/real.py +++ b/testtools/testresult/real.py @@ -27,10 +27,23 @@ import math import sys import threading +import types import unittest +from collections.abc import Callable, Iterable, Sequence from operator import methodcaller from queue import Queue -from typing import ClassVar, TypeAlias +from typing import ( + TYPE_CHECKING, + ClassVar, + Protocol, + TextIO, + TypeAlias, + TypedDict, + TypeVar, +) + +if TYPE_CHECKING: + from testtools.testcase import PlaceHolder from testtools.content import ( Content, @@ -40,12 +53,68 @@ from testtools.content_type import ContentType from testtools.tags import TagContext + +class _OnTestCallback(Protocol): + """Protocol for the on_test callback in TestByTestResult.""" + + def __call__( + self, + *, + test: unittest.TestCase, + status: str | None, + start_time: datetime.datetime, + stop_time: datetime.datetime | None, + tags: set[str], + details: "DetailsDict | None", + ) -> None: ... + + # Type for event dicts that go into the queue -EventDict: TypeAlias = "dict[str, str | bytes | bool | None | StreamResult]" +class _StartStopEventDict(TypedDict): + event: str # "startTestRun" or "stopTestRun" + result: "StreamResult" + + +class _StatusEventDict(TypedDict, total=False): + event: str # "status" + test_id: str | None + test_status: str | None + test_tags: set[str] | None + runnable: bool + file_name: str | None + file_bytes: bytes | None + eof: bool + mime_type: str | None + route_code: str | None + timestamp: datetime.datetime | None + + +EventDict: TypeAlias = _StartStopEventDict | _StatusEventDict + +# Type for exc_info tuples from sys.exc_info() +ExcInfo: TypeAlias = tuple[ + type[BaseException], BaseException, types.TracebackType | None +] + +# Type for details dict (mapping from names to Content objects) +DetailsDict: TypeAlias = dict[str, Content] + + +# Type for test dict with test information +class TestDict(TypedDict): + id: str + tags: set[str] + details: DetailsDict + status: str + timestamps: list[datetime.datetime | None] + + +# Protocol for test resources with an id() method +class TestResourceProtocol(Protocol): + def id(self) -> str: ... -# circular import -# from testtools.testcase import PlaceHolder -PlaceHolder = None + +# PlaceHolder is imported at runtime in to_test_case() to avoid circular import # From http://docs.python.org/library/datetime.html _ZERO = datetime.timedelta(0) @@ -54,13 +123,13 @@ class UTC(datetime.tzinfo): """UTC""" - def utcoffset(self, dt): + def utcoffset(self, dt: datetime.datetime | None) -> datetime.timedelta: return _ZERO - def tzname(self, dt): + def tzname(self, dt: datetime.datetime | None) -> str: return "UTC" - def dst(self, dt): + def dst(self, dt: datetime.datetime | None) -> datetime.timedelta: return _ZERO @@ -84,7 +153,7 @@ class TestResult(unittest.TestResult): :ivar skip_reasons: A dict of skip-reasons -> list of tests. See addSkip. """ - def __init__(self, failfast=False, tb_locals=False): + def __init__(self, failfast: bool = False, tb_locals: bool = False) -> None: # startTestRun resets all attributes, and older clients don't know to # call startTestRun, so it is called once here. # Because subclasses may reasonably not expect this, we call the @@ -93,7 +162,12 @@ def __init__(self, failfast=False, tb_locals=False): self.tb_locals = tb_locals TestResult.startTestRun(self) - def addExpectedFailure(self, test, err=None, details=None): + def addExpectedFailure( + self, + test: unittest.TestCase, + err: ExcInfo | tuple[None, None, None] = (None, None, None), + details: DetailsDict | None = None, + ) -> None: """Called when a test has failed in an expected manner. Like with addSuccess and addError, testStopped should still be called. @@ -107,29 +181,50 @@ def addExpectedFailure(self, test, err=None, details=None): (test, self._err_details_to_string(test, err, details)) ) - def addError(self, test, err=None, details=None): + def addError( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: """Called when an error has occurred. 'err' is a tuple of values as returned by sys.exc_info(). :param details: Alternative way to supply details about the outcome. see the class docstring for more information. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. """ - self.errors.append((test, self._err_details_to_string(test, err, details))) + self.errors.append((test, self._err_details_to_string(test, err, details))) # type: ignore[arg-type] if self.failfast: self.stop() - def addFailure(self, test, err=None, details=None): + def addFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: """Called when an error has occurred. 'err' is a tuple of values as returned by sys.exc_info(). :param details: Alternative way to supply details about the outcome. see the class docstring for more information. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. """ - self.failures.append((test, self._err_details_to_string(test, err, details))) + self.failures.append((test, self._err_details_to_string(test, err, details))) # type: ignore[arg-type] if self.failfast: self.stop() - def addSkip(self, test, reason=None, details=None): + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: DetailsDict | None = None, + ) -> None: """Called when a test has been skipped rather than running. Like with addSuccess and addError, testStopped should still be called. @@ -146,24 +241,29 @@ def addSkip(self, test, reason=None, details=None): :return: None """ if reason is None: - reason = details.get("reason") - if reason is None: + assert details is not None + reason_content = details.get("reason") + if reason_content is None: reason = "No reason given" else: - reason = reason.as_text() + reason = reason_content.as_text() skip_list = self.skip_reasons.setdefault(reason, []) skip_list.append(test) - def addSuccess(self, test, details=None): + def addSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: """Called when a test succeeded.""" - def addUnexpectedSuccess(self, test, details=None): + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: """Called when a test was expected to fail, but succeed.""" self.unexpectedSuccesses.append(test) if self.failfast: self.stop() - def addDuration(self, test, duration): + def addDuration(self, test: unittest.TestCase, duration: float) -> None: """Called to add a test duration. :param test: The test that completed. @@ -171,7 +271,7 @@ def addDuration(self, test, duration): """ self.collectedDurations.append((test, duration)) - def wasSuccessful(self): + def wasSuccessful(self) -> bool: """Has this result been successful so far? If there have been any errors, failures or unexpected successes, @@ -183,18 +283,24 @@ def wasSuccessful(self): """ return not (self.errors or self.failures or self.unexpectedSuccesses) - def _err_details_to_string(self, test, err=None, details=None): + def _err_details_to_string( + self, + test: unittest.TestCase, + err: ExcInfo | tuple[None, None, None] = (None, None, None), + details: DetailsDict | None = None, + ) -> str: """Convert an error in exc_info form or a contents dict to a string.""" - if err is not None: + if err != (None, None, None) and err is not None: return TracebackContent(err, test, capture_locals=self.tb_locals).as_text() + assert details is not None return _details_to_str(details, special="traceback") - def _exc_info_to_unicode(self, err, test): + def _exc_info_to_unicode(self, err: ExcInfo, test: unittest.TestCase) -> str: # Deprecated. Only present because subunit upcalls to it. See # . return TracebackContent(err, test).as_text() - def _now(self): + def _now(self) -> datetime.datetime: """Return the current 'test time'. If the time() method has not been called, this is equivalent to @@ -206,7 +312,7 @@ def _now(self): else: return self.__now - def startTestRun(self): + def startTestRun(self) -> None: """Called before a test run starts. New in Python 2.7. The testtools version resets the result to a @@ -217,40 +323,40 @@ def startTestRun(self): failfast = self.failfast tb_locals = self.tb_locals super().__init__() - self.skip_reasons = {} - self.__now = None + self.skip_reasons: dict[str, list[unittest.TestCase]] = {} + self.__now: datetime.datetime | None = None self._tags = TagContext() # -- Start: As per python 2.7 -- - self.expectedFailures = [] - self.unexpectedSuccesses = [] + self.expectedFailures: list[tuple[unittest.TestCase, str]] = [] + self.unexpectedSuccesses: list[unittest.TestCase] = [] self.failfast = failfast # -- End: As per python 2.7 -- self.tb_locals = tb_locals # -- Python 3.12 - self.collectedDurations = [] + self.collectedDurations: list[tuple[unittest.TestCase, float]] = [] - def stopTestRun(self): + def stopTestRun(self) -> None: """Called after a test run completes New in python 2.7 """ - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: super().startTest(test) self._tags = TagContext(self._tags) - def stopTest(self, test): + def stopTest(self, test: unittest.TestCase) -> None: # NOTE: In Python 3.12.1 skipped tests may not call startTest() if self._tags is not None and self._tags.parent is not None: self._tags = self._tags.parent super().stopTest(test) @property - def current_tags(self): + def current_tags(self) -> set[str]: """The currently set tags.""" return self._tags.get_current_tags() - def tags(self, new_tags, gone_tags): + def tags(self, new_tags: Iterable[str], gone_tags: Iterable[str]) -> None: """Add and remove tags from the test. :param new_tags: A set of tags to be added to the stream. @@ -258,7 +364,7 @@ def tags(self, new_tags, gone_tags): """ self._tags.change_tags(new_tags, gone_tags) - def time(self, a_datetime): + def time(self, a_datetime: datetime.datetime | None) -> None: """Provide a timestamp to represent the current time. This is useful when test activity is time delayed, or happening @@ -273,7 +379,7 @@ def time(self, a_datetime): """ self.__now = a_datetime - def done(self): + def done(self) -> None: """Called when the test runner is done. deprecated in favour of stopTestRun. @@ -362,14 +468,14 @@ class StreamResult: as a base class regardless of intent. """ - def startTestRun(self): + def startTestRun(self) -> None: """Start a test run. This will prepare the test result to process results (which might imply connecting to a database or remote machine). """ - def stopTestRun(self): + def stopTestRun(self) -> None: """Stop a test run. This informs the result that no more test updates will be received. At @@ -449,7 +555,10 @@ def status( """ -def _strict_map(function, *sequences): +_T = TypeVar("_T") + + +def _strict_map(function: Callable[..., _T], *sequences: Sequence[object]) -> list[_T]: return list(map(function, *sequences)) @@ -461,21 +570,59 @@ class CopyStreamResult(StreamResult): For TestResult the equivalent class was ``MultiTestResult``. """ - def __init__(self, targets): + def __init__(self, targets: list[StreamResult]) -> None: super().__init__() self.targets = targets - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() _strict_map(methodcaller("startTestRun"), self.targets) - def stopTestRun(self): + def stopTestRun(self) -> None: super().stopTestRun() _strict_map(methodcaller("stopTestRun"), self.targets) - def status(self, *args, **kwargs): - super().status(*args, **kwargs) - _strict_map(methodcaller("status", *args, **kwargs), self.targets) + def status( + self, + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: + super().status( + test_id=test_id, + test_status=test_status, + test_tags=test_tags, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ) + _strict_map( + methodcaller( + "status", + test_id=test_id, + test_status=test_status, + test_tags=test_tags, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ), + self.targets, + ) class StreamFailFast(StreamResult): @@ -487,22 +634,22 @@ def do_something(): pass """ - def __init__(self, on_error): + def __init__(self, on_error: Callable[[], None]) -> None: self.on_error = on_error def status( self, - test_id=None, - test_status=None, - test_tags=None, - runnable=True, - file_name=None, - file_bytes=None, - eof=False, - mime_type=None, - route_code=None, - timestamp=None, - ): + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: if test_status in ("uxsuccess", "fail"): self.on_error() @@ -538,9 +685,11 @@ class StreamResultRouter(StreamResult): the behaviour is undefined. Only a single route is chosen for any event. """ - _policies: ClassVar[dict] = {} + _policies: ClassVar[dict[str, Callable[..., None]]] = {} - def __init__(self, fallback=None, do_start_stop_run=True): + def __init__( + self, fallback: StreamResult | None = None, do_start_stop_run: bool = True + ) -> None: """Construct a StreamResultRouter with optional fallback. :param fallback: A StreamResult to forward events to when no route @@ -549,21 +698,21 @@ def __init__(self, fallback=None, do_start_stop_run=True): stopTestRun onto the fallback. """ self.fallback = fallback - self._route_code_prefixes = {} - self._test_ids = {} + self._route_code_prefixes: dict[str, tuple[StreamResult, bool]] = {} + self._test_ids: dict[str | None, StreamResult] = {} # Records sinks that should have do_start_stop_run called on them. - self._sinks = [] + self._sinks: list[StreamResult] = [] if do_start_stop_run and fallback: self._sinks.append(fallback) self._in_run = False - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() for sink in self._sinks: sink.startTestRun() self._in_run = True - def stopTestRun(self): + def stopTestRun(self) -> None: super().stopTestRun() for sink in self._sinks: sink.stopTestRun() @@ -571,22 +720,23 @@ def stopTestRun(self): def status( self, - test_id=None, - test_status=None, - test_tags=None, - runnable=True, - file_name=None, - file_bytes=None, - eof=False, - mime_type=None, - route_code=None, - timestamp=None, - ): + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: # route_code and test_id are already available as parameters + target: StreamResult | None if route_code is not None: prefix = route_code.split("/")[0] else: - prefix = route_code + prefix = None if prefix in self._route_code_prefixes: target, consume_route = self._route_code_prefixes[prefix] if route_code is not None and consume_route: @@ -598,6 +748,10 @@ def status( target = self._test_ids[test_id] else: target = self.fallback + if target is None: + raise Exception( + f"No route found for test_id={test_id!r}, route_code={route_code!r}" + ) target.status( test_id=test_id, test_status=test_status, @@ -611,7 +765,13 @@ def status( timestamp=timestamp, ) - def add_rule(self, sink, policy, do_start_stop_run=False, **policy_args): + def add_rule( + self, + sink: StreamResult, + policy: str, + do_start_stop_run: bool = False, + **policy_args: object, + ) -> None: """Add a rule to route events to sink when they match a given policy. :param sink: A StreamResult to receive events. @@ -640,14 +800,16 @@ def add_rule(self, sink, policy, do_start_stop_run=False, **policy_args): if self._in_run: sink.startTestRun() - def _map_route_code_prefix(self, sink, route_prefix, consume_route=False): + def _map_route_code_prefix( + self, sink: StreamResult, route_prefix: str, consume_route: bool = False + ) -> None: if "/" in route_prefix: raise TypeError(f"{route_prefix!r} is more than one route step long") self._route_code_prefixes[route_prefix] = (sink, consume_route) _policies["route_code_prefix"] = _map_route_code_prefix - def _map_test_id(self, sink, test_id): + def _map_test_id(self, sink: StreamResult, test_id: str | None) -> None: self._test_ids[test_id] = sink _policies["test_id"] = _map_test_id @@ -656,7 +818,12 @@ def _map_test_id(self, sink, test_id): class StreamTagger(CopyStreamResult): """Adds or discards tags from StreamResult events.""" - def __init__(self, targets, add=None, discard=None): + def __init__( + self, + targets: list[StreamResult], + add: set[str] | None = None, + discard: set[str] | None = None, + ) -> None: """Create a StreamTagger. :param targets: A list of targets to forward events onto. @@ -668,18 +835,47 @@ def __init__(self, targets, add=None, discard=None): self.add = frozenset(add or ()) self.discard = frozenset(discard or ()) - def status(self, *args, **kwargs): - test_tags = kwargs.get("test_tags") or set() - test_tags.update(self.add) - test_tags.difference_update(self.discard) - kwargs["test_tags"] = test_tags or None - super().status(*args, **kwargs) + def status( + self, + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: + tags = test_tags or set() + tags.update(self.add) + tags.difference_update(self.discard) + super().status( + test_id=test_id, + test_status=test_status, + test_tags=tags or None, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ) class _TestRecord: """Representation of a test.""" - def __init__(self, id, tags, details, status, timestamps): + def __init__( + self, + id: str, + tags: set[str], + details: dict[str, Content], + status: str, + timestamps: tuple[datetime.datetime | None, datetime.datetime | None], + ) -> None: # The test id. self.id = id @@ -698,7 +894,7 @@ def __init__(self, id, tags, details, status, timestamps): self.timestamps = timestamps @classmethod - def create(cls, test_id, timestamp): + def create(cls, test_id: str, timestamp: datetime.datetime | None) -> "_TestRecord": return cls( id=test_id, tags=set(), @@ -707,18 +903,40 @@ def create(cls, test_id, timestamp): timestamps=(timestamp, None), ) - def set(self, *args, **kwargs): - if args: - setattr(self, args[0], args[1]) - for key, value in kwargs.items(): - setattr(self, key, value) + def set( + self, + attr_name: str | None = None, + attr_value: ( + str + | set[str] + | dict[str, Content] + | tuple[datetime.datetime | None, datetime.datetime | None] + | None + ) = None, + *, + timestamps: tuple[datetime.datetime | None, datetime.datetime | None] + | None = None, + status: str | None = None, + tags: set[str] | None = None, + details: dict[str, Content] | None = None, + ) -> "_TestRecord": + if attr_name is not None: + setattr(self, attr_name, attr_value) + if timestamps is not None: + self.timestamps = timestamps + if status is not None: + self.status = status + if tags is not None: + self.tags = tags + if details is not None: + self.details = details return self - def transform(self, data, value): + def transform(self, data: list[str], value: Content) -> "_TestRecord": getattr(self, data[0])[data[1]] = value return self - def to_dict(self): + def to_dict(self) -> TestDict: """Convert record into a "test dict". A "test dict" is a concept used in other parts of the code-base. It @@ -743,7 +961,7 @@ def to_dict(self): "timestamps": list(self.timestamps), } - def got_timestamp(self, timestamp): + def got_timestamp(self, timestamp: datetime.datetime | None) -> "_TestRecord": """Called when we receive a timestamp. This will always update the second element of the 'timestamps' tuple. @@ -751,7 +969,9 @@ def got_timestamp(self, timestamp): """ return self.set(timestamps=(self.timestamps[0], timestamp)) - def got_file(self, file_name, file_bytes, mime_type=None): + def got_file( + self, file_name: str, file_bytes: bytes, mime_type: str | None = None + ) -> "_TestRecord": """Called when we receive file information. ``mime_type`` is only used when this is the first time we've seen data @@ -766,30 +986,31 @@ def got_file(self, file_name, file_bytes, mime_type=None): ["details", file_name], Content(content_type, lambda: content_bytes) ) - case.details[file_name]._get_bytes().append(file_bytes) + # _get_bytes() returns the list we created in the lambda above + bytes_list = case.details[file_name]._get_bytes() + assert isinstance(bytes_list, list) + bytes_list.append(file_bytes) return case - def to_test_case(self): + def to_test_case(self) -> "PlaceHolder": """Convert into a TestCase object. :return: A PlaceHolder test object. """ # Circular import. - global PlaceHolder - if PlaceHolder is None: - from testtools.testcase import PlaceHolder + from testtools.testcase import PlaceHolder + outcome = _status_map[self.status] - assert PlaceHolder is not None return PlaceHolder( self.id, outcome=outcome, details=self.details, - tags=self.tags, + tags=frozenset(self.tags), timestamps=self.timestamps, ) -def _make_content_type(mime_type=None): +def _make_content_type(mime_type: str | None = None) -> ContentType: """Return ContentType for a given mime type. testtools was emitting a bad encoding, and this works around it. @@ -852,7 +1073,7 @@ class _StreamToTestRecord(StreamResult): Only the most recent tags observed in the stream are reported. """ - def __init__(self, on_test): + def __init__(self, on_test: Callable[[_TestRecord], None]) -> None: """Create a _StreamToTestRecord calling on_test on test completions. :param on_test: A callback that accepts one parameter: @@ -861,23 +1082,23 @@ def __init__(self, on_test): super().__init__() self.on_test = on_test - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() - self._inprogress = {} + self._inprogress: dict[tuple[str, str | None], _TestRecord] = {} def status( self, - test_id=None, - test_status=None, - test_tags=None, - runnable=True, - file_name=None, - file_bytes=None, - eof=False, - mime_type=None, - route_code=None, - timestamp=None, - ): + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: super().status( test_id, test_status, @@ -912,14 +1133,14 @@ def status( def _update_case( self, - case, - test_status=None, - test_tags=None, - file_name=None, - file_bytes=None, - mime_type=None, - timestamp=None, - ): + case: _TestRecord, + test_status: str | None = None, + test_tags: set[str] | None = None, + file_name: str | None = None, + file_bytes: bytes | None = None, + mime_type: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> _TestRecord: if test_status is not None: case = case.set(status=test_status) @@ -933,15 +1154,20 @@ def _update_case( return case - def stopTestRun(self): + def stopTestRun(self) -> None: super().stopTestRun() while self._inprogress: case = self._inprogress.popitem()[1] self.on_test(case.got_timestamp(None)) - def _ensure_key(self, test_id, route_code, timestamp): + def _ensure_key( + self, + test_id: str | None, + route_code: str | None, + timestamp: datetime.datetime | None, + ) -> tuple[str, str | None] | None: if test_id is None: - return + return None key = (test_id, route_code) if key not in self._inprogress: self._inprogress[key] = _TestRecord.create(test_id, timestamp) @@ -976,7 +1202,7 @@ class StreamToDict(StreamResult): # XXX: Alternative simplification is to extract a StreamAdapter base # class, and have this inherit from that. - def __init__(self, on_test): + def __init__(self, on_test: Callable[[TestDict], None]) -> None: """Create a _StreamToTestRecord calling on_test on test completions. :param on_test: A callback that accepts one parameter: @@ -989,34 +1215,73 @@ def __init__(self, on_test): # the boilerplate by subclassing _StreamToTestRecord. self.on_test = on_test - def _handle_test(self, test_record): + def _handle_test(self, test_record: _TestRecord) -> None: self.on_test(test_record.to_dict()) - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() self._hook.startTestRun() - def status(self, *args, **kwargs): - super().status(*args, **kwargs) - self._hook.status(*args, **kwargs) + def status( + self, + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: + super().status( + test_id=test_id, + test_status=test_status, + test_tags=test_tags, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ) + self._hook.status( + test_id=test_id, + test_status=test_status, + test_tags=test_tags, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ) - def stopTestRun(self): + def stopTestRun(self) -> None: super().stopTestRun() self._hook.stopTestRun() -def test_dict_to_case(test_dict): +def test_dict_to_case(test_dict: TestDict) -> "PlaceHolder": """Convert a test dict into a TestCase object. :param test_dict: A test dict as generated by StreamToDict. :return: A PlaceHolder test object. """ + ts_list = test_dict["timestamps"] + timestamps: tuple[datetime.datetime | None, datetime.datetime | None] = ( + ts_list[0] if len(ts_list) > 0 else None, + ts_list[1] if len(ts_list) > 1 else None, + ) return _TestRecord( id=test_dict["id"], tags=test_dict["tags"], details=test_dict["details"], status=test_dict["status"], - timestamps=tuple(test_dict["timestamps"]), + timestamps=timestamps, ).to_test_case() @@ -1028,10 +1293,10 @@ class StreamSummary(StreamResult): runner. """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._hook = _StreamToTestRecord(self._gather_test) - self._handle_status = { + self._handle_status: dict[str, Callable[[_TestRecord], None]] = { "success": self._success, "skip": self._skip, "exists": self._exists, @@ -1042,25 +1307,59 @@ def __init__(self): "inprogress": self._incomplete, } - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() - self.failures = [] - self.errors = [] - self.testsRun = 0 - self.skipped = [] - self.expectedFailures = [] - self.unexpectedSuccesses = [] + self.failures: list[tuple[PlaceHolder, str]] = [] + self.errors: list[tuple[PlaceHolder, str]] = [] + self.testsRun: int = 0 + self.skipped: list[tuple[PlaceHolder, str]] = [] + self.expectedFailures: list[tuple[PlaceHolder, str]] = [] + self.unexpectedSuccesses: list[PlaceHolder] = [] self._hook.startTestRun() - def status(self, *args, **kwargs): - super().status(*args, **kwargs) - self._hook.status(*args, **kwargs) + def status( + self, + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: + super().status( + test_id=test_id, + test_status=test_status, + test_tags=test_tags, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ) + self._hook.status( + test_id=test_id, + test_status=test_status, + test_tags=test_tags, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ) - def stopTestRun(self): + def stopTestRun(self) -> None: super().stopTestRun() self._hook.stopTestRun() - def wasSuccessful(self): + def wasSuccessful(self) -> bool: """Return False if any failure has occurred. Note that incomplete tests can only be detected when stopTestRun is @@ -1068,38 +1367,41 @@ def wasSuccessful(self): """ return not self.failures and not self.errors - def _gather_test(self, test_record): + def _gather_test(self, test_record: _TestRecord) -> None: if test_record.status == "exists": return self.testsRun += 1 - case = test_record.to_test_case() - self._handle_status[test_record.status](case) + self._handle_status[test_record.status](test_record) - def _incomplete(self, case): - self.errors.append((case, "Test did not complete")) + def _incomplete(self, test_record: _TestRecord) -> None: + self.errors.append((test_record.to_test_case(), "Test did not complete")) - def _success(self, case): + def _success(self, test_record: _TestRecord) -> None: pass - def _skip(self, case): + def _skip(self, test_record: _TestRecord) -> None: + case = test_record.to_test_case() if "reason" not in case._details: reason = "Unknown" else: reason = case._details["reason"].as_text() self.skipped.append((case, reason)) - def _exists(self, case): + def _exists(self, test_record: _TestRecord) -> None: pass - def _fail(self, case): + def _fail(self, test_record: _TestRecord) -> None: + case = test_record.to_test_case() message = _details_to_str(case._details, special="traceback") self.errors.append((case, message)) - def _xfail(self, case): + def _xfail(self, test_record: _TestRecord) -> None: + case = test_record.to_test_case() message = _details_to_str(case._details, special="traceback") self.expectedFailures.append((case, message)) - def _uxsuccess(self, case): + def _uxsuccess(self, test_record: _TestRecord) -> None: + case = test_record.to_test_case() case._outcome = "addUnexpectedSuccess" self.unexpectedSuccesses.append(case) @@ -1112,9 +1414,9 @@ class TestControl: each test and if set stop dispatching any new tests and return. """ - def __init__(self): + def __init__(self) -> None: super().__init__() - self.shouldStop = False + self.shouldStop: bool = False def stop(self) -> None: """Indicate that tests should stop running.""" @@ -1124,85 +1426,111 @@ def stop(self) -> None: class MultiTestResult(TestResult): """A test result that dispatches to many test results.""" - def __init__(self, *results): + def __init__(self, *results: TestResult) -> None: # Setup _results first, as the base class __init__ assigns to failfast. self._results = list(map(ExtendedToOriginalDecorator, results)) super().__init__() - def __repr__(self): + def __repr__(self) -> str: return "<{} ({})>".format( self.__class__.__name__, ", ".join(map(repr, self._results)) ) - def _dispatch(self, message, *args, **kwargs): + def _dispatch( + self, message: str, *args: object, **kwargs: object + ) -> tuple[object, ...]: return tuple( getattr(result, message)(*args, **kwargs) for result in self._results ) - def _get_failfast(self): + def _get_failfast(self) -> bool: return getattr(self._results[0], "failfast", False) - def _set_failfast(self, value): + def _set_failfast(self, value: bool) -> None: self._dispatch("__setattr__", "failfast", value) failfast = property(_get_failfast, _set_failfast) - def _get_shouldStop(self): + def _get_shouldStop(self) -> bool: return any(self._dispatch("__getattr__", "shouldStop")) - def _set_shouldStop(self, value): + def _set_shouldStop(self, value: bool) -> None: # Called because we subclass TestResult. Probably should not do that. pass shouldStop = property(_get_shouldStop, _set_shouldStop) - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: super().startTest(test) - return self._dispatch("startTest", test) + self._dispatch("startTest", test) - def stop(self): - return self._dispatch("stop") + def stop(self) -> None: + self._dispatch("stop") - def stopTest(self, test): + def stopTest(self, test: unittest.TestCase) -> None: super().stopTest(test) - return self._dispatch("stopTest", test) + self._dispatch("stopTest", test) - def addError(self, test, error=None, details=None): - return self._dispatch("addError", test, error, details=details) + def addError( # type: ignore[override] + self, + test: unittest.TestCase, + error: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + self._dispatch("addError", test, error, details=details) - def addExpectedFailure(self, test, err=None, details=None): - return self._dispatch("addExpectedFailure", test, err, details=details) + def addExpectedFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + self._dispatch("addExpectedFailure", test, err, details=details) - def addFailure(self, test, err=None, details=None): - return self._dispatch("addFailure", test, err, details=details) + def addFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + self._dispatch("addFailure", test, err, details=details) - def addSkip(self, test, reason=None, details=None): - return self._dispatch("addSkip", test, reason, details=details) + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: DetailsDict | None = None, + ) -> None: + self._dispatch("addSkip", test, reason, details=details) - def addSuccess(self, test, details=None): - return self._dispatch("addSuccess", test, details=details) + def addSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: + self._dispatch("addSuccess", test, details=details) - def addUnexpectedSuccess(self, test, details=None): - return self._dispatch("addUnexpectedSuccess", test, details=details) + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: + self._dispatch("addUnexpectedSuccess", test, details=details) - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() - return self._dispatch("startTestRun") + self._dispatch("startTestRun") - def stopTestRun(self): + def stopTestRun(self) -> tuple[object, ...]: # type: ignore[override] return self._dispatch("stopTestRun") - def tags(self, new_tags, gone_tags): + def tags(self, new_tags: Iterable[str], gone_tags: Iterable[str]) -> None: super().tags(new_tags, gone_tags) - return self._dispatch("tags", new_tags, gone_tags) + self._dispatch("tags", new_tags, gone_tags) - def time(self, a_datetime): - return self._dispatch("time", a_datetime) + def time(self, a_datetime: datetime.datetime | None) -> None: + self._dispatch("time", a_datetime) - def done(self): - return self._dispatch("done") + def done(self) -> None: + self._dispatch("done") - def wasSuccessful(self): + def wasSuccessful(self) -> bool: """Was this result successful? Only returns True if every constituent result was successful. @@ -1213,7 +1541,13 @@ def wasSuccessful(self): class TextTestResult(TestResult): """A TestResult which outputs activity to a text stream.""" - def __init__(self, stream, failfast=False, tb_locals=False, verbosity=1): + def __init__( + self, + stream: TextIO | None, + failfast: bool = False, + tb_locals: bool = False, + verbosity: int = 1, + ) -> None: """Construct a TextTestResult writing to stream. :param stream: A file-like object to write results to. @@ -1229,25 +1563,23 @@ def __init__(self, stream, failfast=False, tb_locals=False, verbosity=1): self.verbosity = verbosity self._progress_printed = False - def _delta_to_float(self, a_timedelta, precision): + def _delta_to_float(self, a_timedelta: datetime.timedelta, precision: int) -> float: # This calls ceiling to ensure that the most pessimistic view of time # taken is shown (rather than leaving it to the Python %f operator # to decide whether to round/floor/ceiling. This was added when we # had pyp3 test failures that suggest a floor was happening. shift = 10**precision - return ( - math.ceil( - ( - a_timedelta.days * 86400.0 - + a_timedelta.seconds - + a_timedelta.microseconds / 1000000.0 - ) - * shift - ) - / shift + total_seconds = ( + a_timedelta.days * 86400.0 + + a_timedelta.seconds + + a_timedelta.microseconds / 1000000.0 ) + result: float = math.ceil(total_seconds * shift) / shift + return result - def _show_list(self, label, error_list): + def _show_list( + self, label: str, error_list: list[tuple[unittest.TestCase, str]] + ) -> None: if self.stream is None: return for test, output in error_list: @@ -1256,13 +1588,15 @@ def _show_list(self, label, error_list): self.stream.write(self.sep2) self.stream.write(output) - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: super().startTest(test) if self.stream is not None and self.verbosity >= 2: self.stream.write(f"{test.id()} ... ") self.stream.flush() - def addSuccess(self, test, details=None): + def addSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: super().addSuccess(test, details=details) if self.stream is not None: if self.verbosity == 1: @@ -1273,7 +1607,17 @@ def addSuccess(self, test, details=None): self.stream.write("ok\n") self.stream.flush() - def addError(self, test, err=None, details=None): + def addError( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when an error has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ super().addError(test, err=err, details=details) if self.stream is not None: if self.verbosity == 1: @@ -1283,7 +1627,17 @@ def addError(self, test, err=None, details=None): self.stream.write("ERROR\n") self.stream.flush() - def addFailure(self, test, err=None, details=None): + def addFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when a failure has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ super().addFailure(test, err=err, details=details) if self.stream is not None: if self.verbosity == 1: @@ -1293,7 +1647,12 @@ def addFailure(self, test, err=None, details=None): self.stream.write("FAIL\n") self.stream.flush() - def addSkip(self, test, reason=None, details=None): + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: DetailsDict | None = None, + ) -> None: super().addSkip(test, reason=reason, details=details) if self.stream is not None: if self.verbosity == 1: @@ -1303,8 +1662,18 @@ def addSkip(self, test, reason=None, details=None): self.stream.write(f"skipped {reason!r}\n") self.stream.flush() - def addExpectedFailure(self, test, err=None, details=None): - super().addExpectedFailure(test, err=err, details=details) + def addExpectedFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when an expected failure has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ + super().addExpectedFailure(test, err=err, details=details) # type: ignore[arg-type] if self.stream is not None: if self.verbosity == 1: self.stream.write("x") @@ -1313,7 +1682,9 @@ def addExpectedFailure(self, test, err=None, details=None): self.stream.write("expected failure\n") self.stream.flush() - def addUnexpectedSuccess(self, test, details=None): + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: super().addUnexpectedSuccess(test, details=details) if self.stream is not None: if self.verbosity == 1: @@ -1323,13 +1694,13 @@ def addUnexpectedSuccess(self, test, details=None): self.stream.write("unexpected success\n") self.stream.flush() - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() self.__start = self._now() if self.stream is not None: self.stream.write("Tests running...\n") - def stopTestRun(self): + def stopTestRun(self) -> None: if self.testsRun != 1: plural = "s" else: @@ -1404,17 +1775,23 @@ def __init__( TestResult.__init__(self) self.result = ExtendedToOriginalDecorator(target) self.semaphore = semaphore - self._test_start = None - self._global_tags: tuple[set, set] = set(), set() - self._test_tags: tuple[set, set] = set(), set() + self._test_start: datetime.datetime | None = None + self._global_tags: tuple[set[str], set[str]] = set(), set() + self._test_tags: tuple[set[str], set[str]] = set(), set() - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.result!r}>" - def _any_tags(self, tags): + def _any_tags(self, tags: tuple[set[str], set[str]]) -> bool: return bool(tags[0] or tags[1]) - def _add_result_with_semaphore(self, method, test, *args, **kwargs): + def _add_result_with_semaphore( + self, + method: Callable[..., None], + test: unittest.TestCase, + *args: object, + **kwargs: object, + ) -> None: now = self._now() self.semaphore.acquire() try: @@ -1434,38 +1811,77 @@ def _add_result_with_semaphore(self, method, test, *args, **kwargs): self.semaphore.release() self._test_start = None - def addError(self, test, err=None, details=None): + def addError( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when an error has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ self._add_result_with_semaphore( self.result.addError, test, err, details=details ) - def addExpectedFailure(self, test, err=None, details=None): + def addExpectedFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when an expected failure has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ self._add_result_with_semaphore( self.result.addExpectedFailure, test, err, details=details ) - def addFailure(self, test, err=None, details=None): + def addFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when a failure has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ self._add_result_with_semaphore( self.result.addFailure, test, err, details=details ) - def addSkip(self, test, reason=None, details=None): + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: DetailsDict | None = None, + ) -> None: self._add_result_with_semaphore( self.result.addSkip, test, reason, details=details ) - def addSuccess(self, test, details=None): + def addSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: self._add_result_with_semaphore(self.result.addSuccess, test, details=details) - def addUnexpectedSuccess(self, test, details=None): + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: self._add_result_with_semaphore( self.result.addUnexpectedSuccess, test, details=details ) - def progress(self, offset, whence): + def progress(self, offset: int, whence: int) -> None: pass - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() self.semaphore.acquire() try: @@ -1473,57 +1889,65 @@ def startTestRun(self): finally: self.semaphore.release() - def _get_shouldStop(self): + def _get_shouldStop(self) -> bool: self.semaphore.acquire() try: - return self.result.shouldStop + result = self.result.shouldStop + assert isinstance(result, bool) + return result finally: self.semaphore.release() - def _set_shouldStop(self, value): + def _set_shouldStop(self, value: bool) -> None: # Another case where we should not subclass TestResult pass shouldStop = property(_get_shouldStop, _set_shouldStop) - def stop(self): + def stop(self) -> None: self.semaphore.acquire() try: self.result.stop() finally: self.semaphore.release() - def stopTestRun(self): + def stopTestRun(self) -> None: self.semaphore.acquire() try: self.result.stopTestRun() finally: self.semaphore.release() - def done(self): + def done(self) -> None: self.semaphore.acquire() try: self.result.done() finally: self.semaphore.release() - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: self._test_start = self._now() super().startTest(test) - def wasSuccessful(self): + def wasSuccessful(self) -> bool: return self.result.wasSuccessful() - def tags(self, new_tags, gone_tags): + def tags(self, new_tags: Iterable[str], gone_tags: Iterable[str]) -> None: """See `TestResult`.""" super().tags(new_tags, gone_tags) if self._test_start is not None: - self._test_tags = _merge_tags(self._test_tags, (new_tags, gone_tags)) + self._test_tags = _merge_tags( + self._test_tags, (set(new_tags), set(gone_tags)) + ) else: - self._global_tags = _merge_tags(self._global_tags, (new_tags, gone_tags)) + self._global_tags = _merge_tags( + self._global_tags, (set(new_tags), set(gone_tags)) + ) -def _merge_tags(existing, changed): +def _merge_tags( + existing: tuple[set[str], set[str]], changed: tuple[set[str], set[str]] +) -> tuple[set[str], set[str]]: new_tags, gone_tags = changed result_new = set(existing[0]) result_gone = set(existing[1]) @@ -1544,105 +1968,144 @@ class ExtendedToOriginalDecorator: does not support the newer style of calling. """ - def __init__(self, decorated): - self.decorated = decorated + def __init__(self, decorated: unittest.TestResult) -> None: + self.decorated: unittest.TestResult | TestResult = decorated self._tags = TagContext() # Only used for old TestResults that do not have failfast. self._failfast = False # Used for old TestResults that do not have stop. self._shouldStop = False - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.decorated!r}>" - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: return getattr(self.decorated, name) - def addError(self, test, err=None, details=None): + def addError( + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: try: self._check_args(err, details) if details is not None: try: - return self.decorated.addError(test, details=details) + self.decorated.addError(test, details=details) # type: ignore[call-arg] except TypeError: # have to convert err = self._details_to_exc_info(details) - return self.decorated.addError(test, err) + else: + return + self.decorated.addError(test, err) # type: ignore[arg-type] finally: if self.failfast: self.stop() - def addExpectedFailure(self, test, err=None, details=None): + def addExpectedFailure( + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: self._check_args(err, details) addExpectedFailure = getattr(self.decorated, "addExpectedFailure", None) if addExpectedFailure is None: - return self.addSuccess(test) + self.addSuccess(test) + return if details is not None: try: - return addExpectedFailure(test, details=details) + addExpectedFailure(test, details=details) except TypeError: # have to convert err = self._details_to_exc_info(details) - return addExpectedFailure(test, err) + else: + return + addExpectedFailure(test, err) - def addFailure(self, test, err=None, details=None): + def addFailure( + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: try: self._check_args(err, details) if details is not None: try: - return self.decorated.addFailure(test, details=details) + self.decorated.addFailure(test, details=details) # type: ignore[call-arg] except TypeError: # have to convert err = self._details_to_exc_info(details) - return self.decorated.addFailure(test, err) + else: + return + self.decorated.addFailure(test, err) # type: ignore[arg-type] finally: if self.failfast: self.stop() - def addSkip(self, test, reason=None, details=None): + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: DetailsDict | None = None, + ) -> None: self._check_args(reason, details) addSkip = getattr(self.decorated, "addSkip", None) if addSkip is None: - return self.decorated.addSuccess(test) + self.decorated.addSuccess(test) + return if details is not None: try: - return addSkip(test, details=details) + addSkip(test, details=details) except TypeError: # extract the reason if it's available try: reason = details["reason"].as_text() except KeyError: reason = _details_to_str(details) - return addSkip(test, reason) + else: + return + addSkip(test, reason) - def addUnexpectedSuccess(self, test, details=None): + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: try: outcome = getattr(self.decorated, "addUnexpectedSuccess", None) if outcome is None: try: test.fail("") except test.failureException: - return self.addFailure(test, sys.exc_info()) + self.addFailure(test, sys.exc_info()) # type: ignore[arg-type] + return else: if details is not None: try: - return outcome(test, details=details) + outcome(test, details=details) except TypeError: pass - return outcome(test) + else: + return + outcome(test) finally: if self.failfast: self.stop() - def addSuccess(self, test, details=None): + def addSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: if details is not None: try: - return self.decorated.addSuccess(test, details=details) + self.decorated.addSuccess(test, details=details) # type: ignore[call-arg] except TypeError: pass - return self.decorated.addSuccess(test) + else: + return + self.decorated.addSuccess(test) - def _check_args(self, err, details): + def _check_args(self, err: object, details: object) -> None: param_count = 0 if err is not None: param_count += 1 @@ -1653,7 +2116,7 @@ def _check_args(self, err, details): f"Must pass only one of err '{err}' and details '{details}" ) - def _details_to_exc_info(self, details): + def _details_to_exc_info(self, details: DetailsDict) -> ExcInfo: """Convert a details dict to an exc_info tuple.""" return ( _StringException, @@ -1662,19 +2125,19 @@ def _details_to_exc_info(self, details): ) @property - def current_tags(self): + def current_tags(self) -> set[str]: return getattr(self.decorated, "current_tags", self._tags.get_current_tags()) - def done(self): + def done(self) -> None: try: - return self.decorated.done() + self.decorated.done() # type: ignore[union-attr] except AttributeError: - return + pass - def _get_failfast(self): + def _get_failfast(self) -> bool: return getattr(self.decorated, "failfast", self._failfast) - def _set_failfast(self, value): + def _set_failfast(self, value: bool) -> None: if hasattr(self.decorated, "failfast"): self.decorated.failfast = value else: @@ -1682,16 +2145,16 @@ def _set_failfast(self, value): failfast = property(_get_failfast, _set_failfast) - def progress(self, offset, whence): + def progress(self, offset: int, whence: int) -> None: method = getattr(self.decorated, "progress", None) if method is None: return - return method(offset, whence) + method(offset, whence) - def _get_shouldStop(self): + def _get_shouldStop(self) -> bool: return getattr(self.decorated, "shouldStop", self._shouldStop) - def _set_shouldStop(self, value): + def _set_shouldStop(self, value: bool) -> None: if hasattr(self.decorated, "shouldStop"): self.decorated.shouldStop = value else: @@ -1699,49 +2162,50 @@ def _set_shouldStop(self, value): shouldStop = property(_get_shouldStop, _set_shouldStop) - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: self._tags = TagContext(self._tags) - return self.decorated.startTest(test) + self.decorated.startTest(test) - def startTestRun(self): + def startTestRun(self) -> None: self._tags = TagContext() try: - return self.decorated.startTestRun() + self.decorated.startTestRun() except AttributeError: - return + pass - def stop(self): + def stop(self) -> None: method = getattr(self.decorated, "stop", None) if method: - return method() - self.shouldStop = True + method() + else: + self.shouldStop = True - def stopTest(self, test): + def stopTest(self, test: unittest.TestCase) -> None: # NOTE: In Python 3.12.1 skipped tests may not call startTest() if self._tags is not None and self._tags.parent is not None: self._tags = self._tags.parent - return self.decorated.stopTest(test) + self.decorated.stopTest(test) - def stopTestRun(self): + def stopTestRun(self) -> object: # type: ignore[override] try: return self.decorated.stopTestRun() except AttributeError: - return + return None - def tags(self, new_tags, gone_tags): + def tags(self, new_tags: Iterable[str], gone_tags: Iterable[str]) -> None: method = getattr(self.decorated, "tags", None) if method is not None: - return method(new_tags, gone_tags) + method(new_tags, gone_tags) else: self._tags.change_tags(new_tags, gone_tags) - def time(self, a_datetime): + def time(self, a_datetime: datetime.datetime | None) -> None: method = getattr(self.decorated, "time", None) if method is None: return - return method(a_datetime) + method(a_datetime) - def wasSuccessful(self): + def wasSuccessful(self) -> bool: return self.decorated.wasSuccessful() @@ -1761,11 +2225,12 @@ def __init__(self, decorated: StreamResult) -> None: TestControl.__init__(self) self._started = False self._tags: TagContext | None = None + self.__now: datetime.datetime | None = None - def _get_failfast(self): + def _get_failfast(self) -> bool: return len(self.targets) == 2 - def _set_failfast(self, value): + def _set_failfast(self, value: bool) -> None: if value: if len(self.targets) == 2: return @@ -1775,24 +2240,36 @@ def _set_failfast(self, value): failfast = property(_get_failfast, _set_failfast) - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: if not self._started: self.startTestRun() self.status(test_id=test.id(), test_status="inprogress", timestamp=self._now()) self._tags = TagContext(self._tags) - def stopTest(self, test): + def stopTest(self, test: unittest.TestCase) -> None: # NOTE: In Python 3.12.1 skipped tests may not call startTest() if self._tags is not None: self._tags = self._tags.parent - def addError(self, test, err=None, details=None): + def addError( + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: self._check_args(err, details) self._convert(test, err, details, "fail") addFailure = addError - def _convert(self, test, err, details, status, reason=None): + def _convert( + self, + test: unittest.TestCase, + err: ExcInfo | None, + details: DetailsDict | None, + status: str, + reason: str | None = None, + ) -> None: if not self._started: self.startTestRun() test_id = test.id() @@ -1841,20 +2318,34 @@ def _convert(self, test, err, details, status, reason=None): timestamp=now, ) - def addExpectedFailure(self, test, err=None, details=None): + def addExpectedFailure( + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: self._check_args(err, details) self._convert(test, err, details, "xfail") - def addSkip(self, test, reason=None, details=None): + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: DetailsDict | None = None, + ) -> None: self._convert(test, None, details, "skip", reason) - def addUnexpectedSuccess(self, test, details=None): + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: self._convert(test, None, details, "uxsuccess") - def addSuccess(self, test, details=None): + def addSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: self._convert(test, None, details, "success") - def _check_args(self, err, details): + def _check_args(self, err: ExcInfo | None, details: DetailsDict | None) -> None: param_count = 0 if err is not None: param_count += 1 @@ -1865,7 +2356,7 @@ def _check_args(self, err, details): f"Must pass only one of err '{err}' and details '{details}" ) - def startTestRun(self): + def startTestRun(self) -> None: super().startTestRun() self._tags = TagContext() self.shouldStop = False @@ -1873,13 +2364,13 @@ def startTestRun(self): self._started = True @property - def current_tags(self): + def current_tags(self) -> set[str]: """The currently set tags.""" if self._tags is None: return set() return self._tags.get_current_tags() - def tags(self, new_tags, gone_tags): + def tags(self, new_tags: Iterable[str], gone_tags: Iterable[str]) -> None: """Add and remove tags from the test. :param new_tags: A set of tags to be added to the stream. @@ -1888,7 +2379,7 @@ def tags(self, new_tags, gone_tags): if self._tags is not None: self._tags.change_tags(new_tags, gone_tags) - def _now(self): + def _now(self) -> datetime.datetime: """Return the current 'test time'. If the time() method has not been called, this is equivalent to @@ -1900,10 +2391,10 @@ def _now(self): else: return self.__now - def time(self, a_datetime): + def time(self, a_datetime: datetime.datetime) -> None: self.__now = a_datetime - def wasSuccessful(self): + def wasSuccessful(self) -> bool: if not self._started: self.startTestRun() return super().wasSuccessful() @@ -1926,19 +2417,21 @@ class ResourcedToStreamDecorator(ExtendedToStreamDecorator): The runnable flag will be set to False. """ - def startMakeResource(self, resource): + def startMakeResource(self, resource: TestResourceProtocol) -> None: self._convertResourceLifecycle(resource, "make", "start") - def stopMakeResource(self, resource): + def stopMakeResource(self, resource: TestResourceProtocol) -> None: self._convertResourceLifecycle(resource, "make", "stop") - def startCleanResource(self, resource): + def startCleanResource(self, resource: TestResourceProtocol) -> None: self._convertResourceLifecycle(resource, "clean", "start") - def stopCleanResource(self, resource): + def stopCleanResource(self, resource: TestResourceProtocol) -> None: self._convertResourceLifecycle(resource, "clean", "stop") - def _convertResourceLifecycle(self, resource, method, phase): + def _convertResourceLifecycle( + self, resource: TestResourceProtocol, method: str, phase: str + ) -> None: """Convert a resource lifecycle report to a stream event.""" # If the resource implements the TestResourceManager.id() API, let's # use it, otherwise fallback to the class name. @@ -1976,7 +2469,7 @@ class StreamToExtendedDecorator(StreamResult): 'testtools.extradata' flushed at the end of the run. """ - def __init__(self, decorated): + def __init__(self, decorated: unittest.TestResult) -> None: # ExtendedToOriginalDecorator takes care of thunking details back to # exceptions/reasons etc. self.decorated = ExtendedToOriginalDecorator(decorated) @@ -1985,17 +2478,17 @@ def __init__(self, decorated): def status( self, - test_id=None, - test_status=None, - test_tags=None, - runnable=True, - file_name=None, - file_bytes=None, - eof=False, - mime_type=None, - route_code=None, - timestamp=None, - ): + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: if test_status == "exists": return self.hook.status( @@ -2011,19 +2504,19 @@ def status( timestamp=timestamp, ) - def startTestRun(self): + def startTestRun(self) -> None: self.decorated.startTestRun() self.hook.startTestRun() - def stopTestRun(self): + def stopTestRun(self) -> None: self.hook.stopTestRun() self.decorated.stopTestRun() - def _handle_tests(self, test_record): + def _handle_tests(self, test_record: _TestRecord) -> None: case = test_record.to_test_case() - case.run(self.decorated) + case.run(self.decorated) # type: ignore[arg-type] - def wasSuccessful(self): + def wasSuccessful(self) -> bool: """Return whether this result was successful. Delegates to the decorated result object. @@ -2031,27 +2524,31 @@ def wasSuccessful(self): return self.decorated.wasSuccessful() @property - def shouldStop(self): + def shouldStop(self) -> bool: """Return whether the test run should stop. Delegates to the decorated result object. """ - return self.decorated.shouldStop + result = self.decorated.shouldStop + assert isinstance(result, bool) + return result - def stop(self): + def stop(self) -> None: """Indicate that the test run should stop. Delegates to the decorated result object. """ - return self.decorated.stop() + self.decorated.stop() @property - def testsRun(self): + def testsRun(self) -> int: """Return the number of tests run. Delegates to the decorated result object. """ - return self.decorated.testsRun + result = self.decorated.testsRun + assert isinstance(result, int) + return result class StreamToQueue(StreamResult): @@ -2094,22 +2591,22 @@ def __init__(self, queue: Queue[EventDict], routing_code: str | None) -> None: self.queue = queue self.routing_code = routing_code - def startTestRun(self): + def startTestRun(self) -> None: self.queue.put(dict(event="startTestRun", result=self)) def status( self, - test_id=None, - test_status=None, - test_tags=None, - runnable=True, - file_name=None, - file_bytes=None, - eof=False, - mime_type=None, - route_code=None, - timestamp=None, - ): + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: self.queue.put( dict( event="status", @@ -2126,7 +2623,7 @@ def status( ) ) - def stopTestRun(self): + def stopTestRun(self) -> None: self.queue.put(dict(event="stopTestRun", result=self)) def route_code(self, route_code: str | None) -> str | None: @@ -2145,85 +2642,116 @@ class TestResultDecorator: gain basic forwarding functionality. """ - def __init__(self, decorated): + def __init__(self, decorated: "TestResult") -> None: """Create a TestResultDecorator forwarding to decorated.""" self.decorated = decorated - def startTest(self, test): - return self.decorated.startTest(test) + def startTest(self, test: unittest.TestCase) -> None: + self.decorated.startTest(test) - def startTestRun(self): - return self.decorated.startTestRun() + def startTestRun(self) -> None: + self.decorated.startTestRun() - def stopTest(self, test): - return self.decorated.stopTest(test) + def stopTest(self, test: unittest.TestCase) -> None: + self.decorated.stopTest(test) - def stopTestRun(self): - return self.decorated.stopTestRun() + def stopTestRun(self) -> None: + self.decorated.stopTestRun() - def addError(self, test, err=None, details=None): - return self.decorated.addError(test, err, details=details) + def addError( + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + self.decorated.addError(test, err, details=details) - def addFailure(self, test, err=None, details=None): - return self.decorated.addFailure(test, err, details=details) + def addFailure( + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + self.decorated.addFailure(test, err, details=details) - def addSuccess(self, test, details=None): - return self.decorated.addSuccess(test, details=details) + def addSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: + self.decorated.addSuccess(test, details=details) - def addSkip(self, test, reason=None, details=None): - return self.decorated.addSkip(test, reason, details=details) + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: DetailsDict | None = None, + ) -> None: + self.decorated.addSkip(test, reason, details=details) - def addExpectedFailure(self, test, err=None, details=None): - return self.decorated.addExpectedFailure(test, err, details=details) + def addExpectedFailure( + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + self.decorated.addExpectedFailure(test, err, details=details) # type: ignore[arg-type] - def addUnexpectedSuccess(self, test, details=None): - return self.decorated.addUnexpectedSuccess(test, details=details) + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: + self.decorated.addUnexpectedSuccess(test, details=details) - def addDuration(self, test, duration): - return self.decorated.addDuration(test, duration) + def addDuration(self, test: unittest.TestCase, duration: float) -> None: + self.decorated.addDuration(test, duration) - def progress(self, offset, whence): - return self.decorated.progress(offset, whence) + def progress(self, offset: int, whence: int) -> None: + self.decorated.progress(offset, whence) # type: ignore[attr-defined] - def wasSuccessful(self): + def wasSuccessful(self) -> bool: return self.decorated.wasSuccessful() @property - def current_tags(self): - return self.decorated.current_tags + def current_tags(self) -> set[str]: + result = self.decorated.current_tags + assert isinstance(result, set) + return result @property - def shouldStop(self): + def shouldStop(self) -> bool: return self.decorated.shouldStop - def stop(self): - return self.decorated.stop() + def stop(self) -> None: + self.decorated.stop() @property - def testsRun(self): + def testsRun(self) -> int: return self.decorated.testsRun - def tags(self, new_tags, gone_tags): - return self.decorated.tags(new_tags, gone_tags) + def tags(self, new_tags: Iterable[str], gone_tags: Iterable[str]) -> None: + self.decorated.tags(new_tags, gone_tags) - def time(self, a_datetime): - return self.decorated.time(a_datetime) + def time(self, a_datetime: datetime.datetime) -> None: + self.decorated.time(a_datetime) class Tagger(TestResultDecorator): """Tag each test individually.""" - def __init__(self, decorated, new_tags, gone_tags): + def __init__( + self, + decorated: unittest.TestResult, + new_tags: set[str], + gone_tags: set[str], + ) -> None: """Wrap 'decorated' such that each test is tagged. :param new_tags: Tags to be added for each test. :param gone_tags: Tags to be removed for each test. """ - super().__init__(decorated) + super().__init__(decorated) # type: ignore[arg-type] self._new_tags = set(new_tags) self._gone_tags = set(gone_tags) - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: super().startTest(test) self.tags(self._new_tags, self._gone_tags) @@ -2231,7 +2759,10 @@ def startTest(self, test): class TestByTestResult(TestResult): """Call something every time a test completes.""" - def __init__(self, on_test): + def __init__( + self, + on_test: _OnTestCallback, + ) -> None: """Construct a ``TestByTestResult``. :param on_test: A callable that take a test case, a status (one of @@ -2243,16 +2774,16 @@ def __init__(self, on_test): super().__init__() self._on_test = on_test - def startTest(self, test): + def startTest(self, test: unittest.TestCase) -> None: super().startTest(test) self._start_time = self._now() # There's no supported (i.e. tested) behaviour that relies on these # being set, but it makes me more comfortable all the same. -- jml - self._status = None - self._details = None - self._stop_time = None + self._status: str | None = None + self._details: DetailsDict | None = None + self._stop_time: datetime.datetime | None = None - def stopTest(self, test): + def stopTest(self, test: unittest.TestCase) -> None: self._stop_time = self._now() tags = set(self.current_tags) super().stopTest(test) @@ -2265,42 +2796,88 @@ def stopTest(self, test): details=self._details, ) - def _err_to_details(self, test, err, details): + def _err_to_details( + self, + test: unittest.TestCase, + err: ExcInfo | None, + details: DetailsDict | None, + ) -> DetailsDict: if details: return details + assert err is not None, "Either err or details must be provided" return {"traceback": TracebackContent(err, test, capture_locals=self.tb_locals)} - def addSuccess(self, test, details=None): + def addSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: super().addSuccess(test) self._status = "success" self._details = details - def addFailure(self, test, err=None, details=None): + def addFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when a failure has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ super().addFailure(test, err, details) self._status = "failure" self._details = self._err_to_details(test, err, details) - def addError(self, test, err=None, details=None): + def addError( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when an error has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ super().addError(test, err, details) self._status = "error" self._details = self._err_to_details(test, err, details) - def addSkip(self, test, reason=None, details=None): + def addSkip( + self, + test: unittest.TestCase, + reason: str | None = None, + details: DetailsDict | None = None, + ) -> None: super().addSkip(test, reason, details) self._status = "skip" if details is None: + assert reason is not None, "Either reason or details must be provided" details = {"reason": text_content(reason)} elif reason: # XXX: What if details already has 'reason' key? details["reason"] = text_content(reason) self._details = details - def addExpectedFailure(self, test, err=None, details=None): - super().addExpectedFailure(test, err, details) + def addExpectedFailure( # type: ignore[override] + self, + test: unittest.TestCase, + err: ExcInfo | None = None, + details: DetailsDict | None = None, + ) -> None: + """Called when an expected failure has occurred. + + Note: This extends unittest.TestResult by making err optional and + adding details parameter. + """ + super().addExpectedFailure(test, err, details) # type: ignore[arg-type] self._status = "xfail" self._details = self._err_to_details(test, err, details) - def addUnexpectedSuccess(self, test, details=None): + def addUnexpectedSuccess( + self, test: unittest.TestCase, details: DetailsDict | None = None + ) -> None: super().addUnexpectedSuccess(test, details) self._status = "success" self._details = details @@ -2315,33 +2892,54 @@ class TimestampingStreamResult(CopyStreamResult): def __init__(self, target: StreamResult) -> None: super().__init__([target]) - def status(self, *args, **kwargs): - timestamp = kwargs.pop("timestamp", None) + def status( + self, + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: if timestamp is None: timestamp = datetime.datetime.now(utc) - super().status(*args, timestamp=timestamp, **kwargs) + super().status( + test_id=test_id, + test_status=test_status, + test_tags=test_tags, + runnable=runnable, + file_name=file_name, + file_bytes=file_bytes, + eof=eof, + mime_type=mime_type, + route_code=route_code, + timestamp=timestamp, + ) class _StringException(Exception): """An exception made from an arbitrary string.""" - def __hash__(self): + def __hash__(self) -> int: return id(self) - def __eq__(self, other): - try: - return self.args == other.args - except AttributeError: + def __eq__(self, other: object) -> bool: + if not isinstance(other, BaseException): return False + return self.args == other.args -def _format_text_attachment(name, text): +def _format_text_attachment(name: str, text: str) -> str: if "\n" in text: return f"{name}: {{{{{{\n{text}\n}}}}}}\n" return f"{name}: {{{{{{{text}}}}}}}" -def _details_to_str(details, special=None): +def _details_to_str(details: DetailsDict, special: str | None = None) -> str: """Convert a details dict to a string. :param details: A dictionary mapping short names to ``Content`` objects. diff --git a/testtools/twistedsupport/_runtest.py b/testtools/twistedsupport/_runtest.py index 035c27dc..e549f3ef 100644 --- a/testtools/twistedsupport/_runtest.py +++ b/testtools/twistedsupport/_runtest.py @@ -388,7 +388,9 @@ def _log_user_exception(self, e): try: raise e except e.__class__: - self._got_user_exception(sys.exc_info()) + exc_info = sys.exc_info() + assert exc_info[0] is not None and exc_info[1] is not None + self._got_user_exception(exc_info) def _blocking_run_deferred(self, spinner): try: @@ -396,7 +398,9 @@ def _blocking_run_deferred(self, spinner): except NoResultError: # We didn't get a result at all! This could be for any number of # reasons, but most likely someone hit Ctrl-C during the test. - self._got_user_exception(sys.exc_info()) + exc_info = sys.exc_info() + assert exc_info[0] is not None and exc_info[1] is not None + self._got_user_exception(exc_info) self.result.stop() return False, [] except TimeoutError: @@ -419,7 +423,7 @@ def _run_core(self): # XXX: Blatting over the namespace of the test case isn't a nice thing # to do. Find a better way of communicating between runtest and test # case. - self.case.reactor = self._reactor + setattr(self.case, "reactor", self._reactor) spinner = self._make_spinner() # We can't just install these as fixtures on self.case, because we