1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
# -*- coding: utf-8 -*-
"""
Generic mock functions for unit-tests.
Use `check_call` as a context manager to mock calls and their return values as
well to check if the method was called (or not).
"""
from __future__ import annotations
import contextlib
from types import ModuleType
from typing import Any, Dict, Generator, List, Tuple, Type, overload
from _pytest.monkeypatch import MonkeyPatch
Call = Tuple[Tuple[Any, ...], Dict[str, Any]]
CallList = List[Call]
@overload
@contextlib.contextmanager
def _mp_call(
monkeypatch: MonkeyPatch,
mock_class: Type[Any] | ModuleType,
method: str,
return_value: Any,
as_property: bool,
) -> CallList: ...
@overload
@contextlib.contextmanager
def _mp_call(
monkeypatch: MonkeyPatch,
mock_class: str,
method: Any, # return value in this case
return_value: bool, # as_property in this case
) -> CallList: ...
def _mp_call(
monkeypatch: MonkeyPatch,
mock_class: Type[Any] | ModuleType | str,
method: str | Any,
return_value: Any,
as_property: bool = False,
) -> CallList:
"""Mock a method in a class and record the calls to it.
If the given return_value is an Exception, it will be raised.
If not, the value will be returned from the mocked function/method.
"""
calls: CallList = []
def func_call(*a: Any, **k: Any) -> Any:
"""Mock the function call."""
calls.append((a, k))
if isinstance(return_value, Exception):
# bug in pylint https://www.logilab.org/ticket/3207
raise return_value # pylint: disable-msg=raising-bad-type
if callable(return_value):
# Handle the case that a function was passed
return return_value(*a, **k)
return return_value
# first case handles class + method, second one mock as str
if as_property or (isinstance(mock_class, str) and return_value):
callback = property(func_call)
else:
callback = func_call
if isinstance(mock_class, str):
return_value = method
monkeypatch.setattr(mock_class, callback)
else:
monkeypatch.setattr(mock_class, method, callback)
return calls
@contextlib.contextmanager
def check_call( # pylint: disable=too-many-arguments
mock_class: Type[Any] | ModuleType,
method: str,
return_value: Any = None,
call_args_list: List[Tuple[Any, ...]] | None = None,
call_kwargs_list: List[Dict[str, Any]] | None = None,
call_count: int = 1,
as_property: bool = False,
) -> Generator[CallList, None, None]:
"""Context manager for mocking and checking a call to a method.
If called is greater 0, and call_args and call_kwargs are given, the context manager
will check that the mocked method was called with those arguments. Also, it will
assert that the mock was called exactly once.
If called is False, it will assert that the mock was not called.
If a return_value is given, the mock will return this value. One can pass an
exception that will be raised by the mocked method instead of returning a value.
If a Callable is passed, it will be called and its return value returned.
"""
assert (call_args_list is not None and call_kwargs_list is not None) or (
call_args_list is None and call_kwargs_list is None
), "call_args and call_kwargs must be None or have a value (list/dict if empty)"
monkeypatch = MonkeyPatch()
calls = _mp_call(monkeypatch, mock_class, method, return_value, as_property)
try:
yield calls
finally:
m_name = f"{mock_class.__name__}.{method}"
monkeypatch.undo()
assert_calls(call_count, call_args_list, call_kwargs_list, calls, m_name)
# Duplicate the code because overloading is a mess due to this bug:
# https://github.com/python/mypy/issues/11373
@contextlib.contextmanager
def check_call_str( # pylint: disable=too-many-arguments
mock_class: str,
return_value: Any = None,
call_args_list: List[Tuple[Any, ...]] | None = None,
call_kwargs_list: List[Dict[str, Any]] | None = None,
call_count: int = 1,
as_property: bool = False,
) -> Generator[CallList, None, None]:
"""
Context manager for mocking and checking a call to a method.
See `check_call` documentation.
"""
assert (call_args_list is not None and call_kwargs_list is not None) or (
call_args_list is None and call_kwargs_list is None
), "call_args and call_kwargs must be None or have a value (list/dict if empty)"
monkeypatch = MonkeyPatch()
calls = _mp_call(monkeypatch, mock_class, return_value, as_property)
yield calls
m_name = mock_class
monkeypatch.undo()
assert_calls(call_count, call_args_list, call_kwargs_list, calls, m_name)
def assert_calls(
call_count: int,
call_args_list: List[Tuple[Any, ...]] | None,
call_kwargs_list: List[Dict[str, Any]] | None,
calls: CallList,
m_name: str,
) -> None:
"""Check that the calls made to the mocked function are correct."""
if call_count != -1:
assert len(calls) == call_count, (
f"Expected {call_count} calls to {m_name} but got: {len(calls)}"
)
if call_args_list and call_kwargs_list:
for idx, call_args in enumerate(call_args_list):
assert call_args == calls[idx][0], (
f"Args to {m_name}: {call_args} expected: {call_args}"
)
for idx, call_kwargs in enumerate(call_kwargs_list):
assert call_kwargs == calls[idx][1], (
f"Kwargs to {m_name}: {call_kwargs} expected: {call_kwargs}"
)
|