1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
|
from __future__ import annotations
import abc
import contextlib
import logging
from collections.abc import Iterable as IterableABC
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Iterable,
NoReturn,
Optional,
Pattern,
Sequence,
Type,
TypeVar,
Union,
cast,
)
import pytest
from pytest import ExceptionInfo
AnyT = TypeVar("AnyT")

# Primitives describing a normal (non-raising) outcome: a literal expected
# value, a callable that validates the actual value, or a ready-made checker.
# The checker is named with a string forward reference because these aliases
# are runtime expressions evaluated before OutcomeChecker is defined.
_ValueLikePrimitive = Union[
    AnyT,
    Callable[[AnyT], None],
    "OutcomeChecker[AnyT]",
]

# A single expected outcome: any value-like primitive, or an exception type /
# exception instance describing an expected raise.
OutcomePrimitive = Union[_ValueLikePrimitive, Type[Exception], Exception]

# One or more expected outcomes: an iterable of value-like primitives, or a
# single OutcomePrimitive.
OutcomePrimitives = Union[
    Iterable[_ValueLikePrimitive],
    OutcomePrimitive,
]
class OutcomeChecker(abc.ABC, Generic[AnyT]):
    """
    Validates expected outcomes for tests.

    Useful for parameterized tests that can result in values or
    exceptions. Run the code under test inside ``context()`` and, within
    that same block, pass its result to ``check()``.
    """

    @abc.abstractmethod
    def check(self, actual: AnyT) -> None:
        """
        Check the actual outcome against the expectation.

        This should run inside the checker's context.

        Raises:
            AssertionError: If the outcome does not match the
                expectation.
            RuntimeError: If this method is called when no outcome
                is expected.
        """
        ...

    @contextlib.contextmanager
    @abc.abstractmethod
    def context(self) -> Generator[Optional[ExceptionInfo[Exception]], None, None]:
        """
        The context in which the test code should run.

        This is necessary for checking exception outcomes.

        Returns:
            A context manager that yields the exception info for
            any exceptions that were raised in this context.

        Raises:
            AssertionError: If the test does not raise an exception
                when one is expected, or if the exception does not match the
                expectation.
        """
        ...

    @classmethod
    def from_primitive(
        cls,
        primitive: OutcomePrimitive[AnyT],
    ) -> OutcomeChecker[AnyT]:
        """
        Build a checker for a single expected outcome.

        Args:
            primitive: An existing checker (returned as-is), an exception
                class or instance (the test is expected to raise it), a
                callable (called with the actual outcome), or any other
                value (compared with ``==``).

        Returns:
            The checker corresponding to ``primitive``.
        """
        checker = cls._from_special(primitive)
        if checker is not None:
            return checker
        return ValueChecker(cast(AnyT, primitive))

    @classmethod
    def _from_special(
        cls,
        primitive: Union[
            AnyT,
            Callable[[AnyT], None],
            OutcomeChecker[AnyT],
            Type[Exception],
            Exception,
        ],
    ) -> Optional[OutcomeChecker[AnyT]]:
        """
        Return a checker for a primitive with special meaning.

        Returns None when ``primitive`` should be treated as a plain
        expected value. The cases are tested in this order: an existing
        checker, an exception class, an exception instance, and any other
        callable.
        """
        if isinstance(primitive, OutcomeChecker):
            return primitive
        if isinstance(primitive, type) and issubclass(primitive, Exception):
            return ExceptionChecker(primitive)
        if isinstance(primitive, Exception):
            return ExceptionChecker(
                type(primitive), match=cls._match_from_exception(primitive)
            )
        if callable(primitive):
            return CallableChecker(cast(Callable[[AnyT], None], primitive))
        return None

    @staticmethod
    def _match_from_exception(
        exception: Exception,
    ) -> Optional[Union[Pattern[str], str]]:
        """Derive a ``pytest.raises`` ``match`` value from an exception instance."""
        import re  # local: only needed for the compiled-pattern check

        if not exception.args:
            # e.g. ValueError(): there is no message to match, so accept any.
            return None
        message = exception.args[0]
        if message is None or isinstance(message, (str, re.Pattern)):
            # Used as a regular expression (not escaped), as pytest.raises does.
            return message
        # re.search (used by pytest.raises) rejects non-string patterns such
        # as KeyError(5) or OSError(2, "..."); search for their text instead.
        return str(message)

    @classmethod
    def from_primitives(
        cls,
        primitives: OutcomePrimitives[AnyT],
    ) -> OutcomeChecker[AnyT]:
        """
        Build a checker for one or more expected outcomes.

        A special primitive (checker, exception class/instance, or
        callable) is converted directly. Any other iterable except ``str``
        and ``bytes`` is treated as a collection of primitives that must all
        hold, combined into an AggregateChecker. Anything else is an
        expected value.

        NOTE: tuples, sets, and dicts (iterated by key) are also split per
        element. To compare such a value as a whole, pass
        ``ValueChecker(value)`` instead.
        """
        checker = cls._from_special(primitives)  # type: ignore[arg-type]
        if checker is not None:
            return checker
        if isinstance(primitives, IterableABC) and not isinstance(
            primitives, (str, bytes)
        ):
            return AggregateChecker([cls.from_primitive(p) for p in primitives])
        return ValueChecker(cast(AnyT, primitives))
@dataclass(frozen=True)
class NoExceptionChecker(OutcomeChecker[AnyT]):
    """
    Common base for checkers whose expected outcome is a normal return.

    Its context intercepts nothing: any exception raised by the code under
    test propagates out of the ``with`` block unchanged.
    """

    @contextlib.contextmanager
    def context(self) -> Generator[None, None, None]:
        # No exception is expected, so there is no ExceptionInfo to hand back.
        yield
@dataclass(frozen=True)
class AggregateChecker(NoExceptionChecker[AnyT]):
    """
    Validates that the outcome matches all of the given checkers.

    Args:
        checkers: The checkers to apply, in order. None of them may be an
            ExceptionChecker: an aggregate runs in a context that does not
            capture exceptions, so an expected exception cannot be checked.
    """

    checkers: Sequence[OutcomeChecker[AnyT]]

    def check(self, actual: AnyT) -> None:
        # Snapshot once so the validation pass and the checking pass see the
        # same items even if a one-shot iterable was supplied.
        checkers = tuple(self.checkers)
        # Validate the whole configuration before running anything, so a
        # misplaced ExceptionChecker is reported even when an earlier checker
        # would have failed first.
        if any(isinstance(checker, ExceptionChecker) for checker in checkers):
            raise ValueError(
                "AggregateChecker should never contain ExceptionChecker"
            )
        for checker in checkers:
            checker.check(actual)
@dataclass(frozen=True)
class ValueChecker(NoExceptionChecker[AnyT]):
    """
    Validates that the outcome equals a specific value.

    Args:
        expected: The expected value, compared with ``==``.
    """

    expected: AnyT

    def check(self, actual: AnyT) -> None:
        # Explicit message: pytest only rewrites asserts in test modules,
        # conftest files and registered plugins. Unless this module is
        # registered, a bare assert would report no values.
        assert self.expected == actual, (
            f"expected {self.expected!r}, got {actual!r}"
        )
@dataclass(frozen=True)
class CallableChecker(NoExceptionChecker[AnyT]):
    """
    Validates the outcome by handing it to a user-supplied function.

    Args:
        callable: Invoked with the actual outcome. It signals a mismatch by
            raising (presumably AssertionError, per OutcomeChecker.check's
            contract). Its return value is ignored.
    """

    callable: Callable[[AnyT], None]

    def check(self, actual: AnyT) -> None:
        validate = self.callable
        validate(actual)
@dataclass(frozen=True)
class ExceptionChecker(OutcomeChecker[AnyT]):
    """
    Validates that the outcome is a specific exception.

    Args:
        type: The expected exception type. It is passed to
            ``pytest.raises``, so subclasses also match.
        match: A regular expression or string that ``pytest.raises``
            searches for (``re.search``) in the string form of the raised
            exception.
        attributes: A dictionary of attributes that the exception
            must have and their expected values.
    """

    type: Type[Exception]
    match: Optional[Union[Pattern[str], str]] = None
    attributes: Optional[Dict[str, Any]] = None

    def check(self, actual: AnyT) -> NoReturn:
        # An exception outcome is validated entirely by context(). Code that
        # is expected to raise never produces a value to check.
        raise RuntimeError("ExceptionChecker.check should never be called")

    def _check_attributes(self, exception: Exception) -> None:
        """Assert that ``exception`` carries every expected attribute value."""
        if self.attributes is None:
            return
        for key, value in self.attributes.items():
            logging.debug("checking exception attribute %s=%r", key, value)
            assert hasattr(exception, key), (
                f"{exception.__class__.__name__} has no attribute {key!r}"
            )
            actual = getattr(exception, key)
            assert actual == value, (
                f"exception attribute {key!r}: expected {value!r}, got {actual!r}"
            )

    @contextlib.contextmanager
    def context(self) -> Generator[ExceptionInfo[Exception], None, None]:
        """
        Run the enclosed code, expecting ``type`` to be raised.

        Yields:
            The ``pytest.raises`` ExceptionInfo for the block.

        Raises:
            The ``pytest.raises`` failure if the exception is not raised or
            its message does not match ``match``. AssertionError if an
            expected attribute is missing or differs.
        """
        with pytest.raises(self.type, match=self.match) as catcher:
            yield catcher
        # Only reached once pytest.raises has accepted the exception.
        self._check_attributes(catcher.value)
|