1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
|
import warnings
from types import TracebackType
from typing import TypeAlias, Any
from testfixtures import Comparison, SequenceComparison, compare
WarningOrType: TypeAlias = Warning | type[Warning]
class ShouldWarn(warnings.catch_warnings):
"""
This context manager is used to assert that warnings are issued
within the context it is managing.
:param expected: This should be a sequence made up of one or more elements,
each of one of the following types:
* A warning class, indicating that the type
of the warnings is important but not the
parameters it is created with.
* A warning instance, indicating that a
warning exactly matching the one supplied
should have been issued.
If no expected warnings are passed, you will need to inspect
the contents of the list returned by the context manager.
:param order_matters:
A keyword-only parameter that controls whether the order of the
captured entries is required to match those of the expected entries.
Defaults to ``True``.
:param filters:
If passed, these are used to create a filter such that only warnings you
are interested in will be considered by this :class:`ShouldWarn`
instance. The names and meanings are the same as the parameters for
:func:`warnings.filterwarnings`.
"""
_empty_okay = False
recorded: list[warnings.WarningMessage]
def __init__(
self, *expected: WarningOrType, order_matters: bool = True, **filters: Any
) -> None:
super(ShouldWarn, self).__init__(record=True)
self.order_matters = order_matters
self.expected = [Comparison(e) for e in expected]
self.filters = filters
def __enter__(self) -> list[warnings.WarningMessage]:
# We pass `record=True` above, so the following will *always* return a list:
self.recorded = super(ShouldWarn, self).__enter__() # type: ignore[assignment]
warnings.filterwarnings("always", **self.filters)
return self.recorded
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
super(ShouldWarn, self).__exit__(exc_type, exc_val, exc_tb)
if not self.recorded and self._empty_okay:
return
if not self.expected and self.recorded and not self._empty_okay:
return
compare(
expected=SequenceComparison(*self.expected, ordered=self.order_matters),
actual=[wm.message for wm in self.recorded]
)
class ShouldNotWarn(ShouldWarn):
    """
    Context manager asserting that no warnings at all are issued within
    the context it manages.

    Nothing is expected, so any warning recorded while the context is
    active makes the final comparison performed on exit fail.
    """

    # For this class an empty recording is the success case.
    _empty_okay = True

    def __init__(self) -> None:
        # No expected entries and no filters: every warning issued inside
        # the context is recorded and then counts against it.
        super().__init__()
|