File: utils.py

package info (click to toggle)
python-libcst 1.4.0-1.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,928 kB
  • sloc: python: 76,235; makefile: 10; sh: 2
file content (173 lines) | stat: -rw-r--r-- 6,106 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe

import inspect
import re
from functools import wraps
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)
from unittest import TestCase

# Attribute under which ``data_provider`` stores a test method's static data.
DATA_PROVIDER_DATA_ATTR_NAME = "__data_provider_data"
# Prefix put in front of dict keys when naming tests generated from
# dict-shaped provider data (other data is named by position instead).
DATA_PROVIDER_DESCRIPTION_PREFIX = "_data_provider_"
# Attribute under which the largest requested test limit for a method is kept.
PROVIDER_TEST_LIMIT_ATTR_NAME = "__provider_test_limit"
# Default ``test_limit`` used by ``data_provider``.
DEFAULT_TEST_LIMIT = 256


# Generic type variable used by ``none_throws``.
T = TypeVar("T")


def none_throws(value: Optional[T], message: str = "Unexpected None value") -> T:
    """Return ``value``, raising if it is ``None``.

    Narrows an ``Optional[T]`` to ``T`` for callers and type checkers.

    Args:
        value: The value to check.
        message: Error message used when ``value`` is ``None``.

    Returns:
        ``value`` unchanged.

    Raises:
        AssertionError: If ``value`` is ``None``. Raised explicitly rather
            than through an ``assert`` statement so the check is not stripped
            when Python runs with ``-O``.
    """
    if value is None:
        raise AssertionError(message)
    return value


def update_test_limit(test_method: Any, test_limit: int) -> None:
    """Record ``test_limit`` on ``test_method``, keeping the largest value seen.

    The limit is stored under ``PROVIDER_TEST_LIMIT_ATTR_NAME``. Providers
    (e.g. ``contextmanager_provider``) may be applied to the same method more
    than once, so a new limit can only raise a previously stored one, never
    lower it. With no stored limit yet, ``test_limit`` is stored as-is.
    """
    candidates = (
        getattr(test_method, PROVIDER_TEST_LIMIT_ATTR_NAME, test_limit),
        test_limit,
    )
    setattr(test_method, PROVIDER_TEST_LIMIT_ATTR_NAME, max(candidates))


def try_get_provider_attr(
    member_name: str, member: Any, attr_name: str
) -> Optional[Any]:
    """Read ``attr_name`` from ``member`` if it looks like a test function.

    Only plain Python functions whose name starts with ``"test"`` are
    considered; anything else (other names, non-functions such as
    ``staticmethod`` objects or class attributes) yields ``None``, as does a
    test function lacking the attribute.
    """
    if not inspect.isfunction(member):
        return None
    if not member_name.startswith("test"):
        return None
    return getattr(member, attr_name, None)


def populate_data_provider_tests(dct: Dict[str, Any]) -> None:
    """Expand each data-provider test in a class namespace into one test per case.

    A member is expanded when it is a plain function whose name starts with
    ``"test"`` and that carries data under ``DATA_PROVIDER_DATA_ATTR_NAME``
    (as set by ``data_provider``). For every case in that data a wrapper named
    ``<member_name>_<description>`` is added to ``dct``, and the original member
    is then removed. ``description`` is the case's position for non-dict data,
    or its dict key prefixed with ``DATA_PROVIDER_DESCRIPTION_PREFIX``. A case
    that is a mapping is passed to the test as keyword arguments; any other
    case is unpacked as positional arguments.

    Args:
        dct: Class-body namespace; modified in place.

    Raises:
        AssertionError: If a case description (after prefixing) contains
            anything other than ASCII letters, digits or underscores.
        ValueError: If a decorated member's data holds no cases, so no tests
            would be generated for it.
    """
    import collections.abc

    test_methods_to_add: Dict[str, Callable] = {}
    test_methods_to_remove: List[str] = []
    for member_name, member in dct.items():
        provider_data = try_get_provider_attr(
            member_name, member, DATA_PROVIDER_DATA_ATTR_NAME
        )
        if provider_data is None:
            continue

        keyed = isinstance(provider_data, dict)
        cases = provider_data.items() if keyed else enumerate(provider_data)
        # Counted per member: checking the shared ``test_methods_to_add`` dict
        # would miss an empty provider on any member after the first one.
        generated = 0
        for description, data in cases:
            if keyed:
                description = f"{DATA_PROVIDER_DESCRIPTION_PREFIX}{description}"

            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            if not re.fullmatch(r"[a-zA-Z0-9_]+", str(description)):
                raise AssertionError(
                    f"Testcase description must be a valid python identifier: '{description}'"
                )

            # ``data`` and ``member`` are bound as defaults so each wrapper
            # keeps its own case rather than the loop's final values.
            @wraps(member)
            def new_test(
                self: object,
                data: Iterable[object] = data,
                member: Callable[..., object] = member,
            ) -> object:
                # Any Mapping (not just dict) supplies keyword arguments, as
                # declared by ``TestCaseType``.
                if isinstance(data, collections.abc.Mapping):
                    return member(self, **data)
                return member(self, *data)

            name = f"{member_name}_{description}"
            new_test.__name__ = name
            test_methods_to_add[name] = new_test
            generated += 1

        if not generated:
            raise ValueError(
                f"No data_provider tests were created for {member_name}! Please double check your data."
            )
        test_methods_to_remove.append(member_name)

    # Namespace is not mutated while it is being iterated above.
    dct.update(test_methods_to_add)

    # Drop the undecorated originals now that their expansions are in place.
    for test_name in test_methods_to_remove:
        del dct[test_name]


def validate_provider_tests(dct: Dict[str, Any]) -> None:
    """Swap over-limit provider tests for stubs that fail with an explanation.

    Each ``test*`` function in ``dct`` with a recorded limit (under
    ``PROVIDER_TEST_LIMIT_ATTR_NAME``) is checked. The number of tests it
    would generate is the length of its provider data, or 1 when it has no
    data or the data is empty. If that count exceeds the limit, the member is
    replaced in ``dct`` by a function with the same ``__name__`` that raises
    ``AssertionError`` when run.

    Args:
        dct: Class-body namespace; modified in place.
    """
    replacements = {}

    for member_name, member in dct.items():
        limit = try_get_provider_attr(
            member_name, member, PROVIDER_TEST_LIMIT_ATTR_NAME
        )
        if limit is None:
            continue

        provided = try_get_provider_attr(
            member_name, member, DATA_PROVIDER_DATA_ATTR_NAME
        )
        count = len(provided) if provided else 1
        if count <= limit:
            continue

        # Deliberately not decorated with ``wraps``: that would copy the
        # original's attributes (including its provider data) onto the stub,
        # and the stub would then be expanded like the original. Values are
        # bound as defaults to freeze this iteration's values.
        def test_replacement(
            self: Any,
            member_name: Any = member_name,
            num_tests: Any = count,
            test_limit: Any = limit,
        ) -> None:
            raise AssertionError(
                f"{member_name} generated {num_tests} tests but the limit is "
                f"{test_limit}. You can increase the number of "
                "allowed tests by specifying test_limit, but please "
                "consider whether you really need to test all of "
                "these combinations."
            )

        test_replacement.__name__ = member_name
        replacements[member_name] = test_replacement

    # Applied after iteration so the namespace is not mutated mid-loop.
    dct.update(replacements)


# A single test case: a sequence of positional arguments or a mapping of
# keyword arguments for the decorated test method.
TestCaseType = Union[Sequence[object], Mapping[str, object]]
# Can't use Sequence[TestCaseType] here as some clients may pass in a Generator[TestCaseType]
# All cases for one decorated test. When given as a dict, its keys become part
# of the generated test names; otherwise cases are named by position.
StaticDataType = Union[Iterable[TestCaseType], Mapping[str, TestCaseType]]


def data_provider(
    static_data: StaticDataType, *, test_limit: int = DEFAULT_TEST_LIMIT
) -> Callable[[Callable], Callable]:
    """Mark a test method to be run once per case in ``static_data``.

    This decorator only records the data on the method (under
    ``DATA_PROVIDER_DATA_ATTR_NAME``) and records ``test_limit`` via
    ``update_test_limit``. The actual expansion into one test per case is done
    by ``populate_data_provider_tests``, which ``BaseTestMeta`` runs at class
    creation.

    Args:
        static_data: Either a mapping from name fragment to case, or an
            iterable of cases. Each case is a sequence (positional arguments)
            or a mapping (keyword arguments).
        test_limit: Maximum number of tests this method may expand into.

    Returns:
        A decorator that annotates the test method and returns it unchanged.
    """
    import collections.abc

    # The data is consumed more than once (``len`` for the limit check, then
    # iteration for expansion), so one-shot iterables such as generators are
    # materialized here. Non-dict mappings become dicts rather than lists, so
    # their keys stay case names instead of becoming the cases themselves.
    if not isinstance(static_data, (dict, list, tuple)):
        if isinstance(static_data, collections.abc.Mapping):
            static_data = dict(static_data)
        else:
            static_data = list(static_data)

    def test_decorator(test_method: Callable) -> Callable:
        update_test_limit(test_method, test_limit)

        setattr(test_method, DATA_PROVIDER_DATA_ATTR_NAME, static_data)
        return test_method

    return test_decorator


class BaseTestMeta(type):
    """Metaclass that checks and expands data-provider tests at class creation.

    The class-body namespace is passed, in order, to
    ``validate_provider_tests`` and ``populate_data_provider_tests`` (both
    modify it in place). The class is then built from a plain-dict copy of
    the result.
    """

    def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> object:
        # Validation must come first: it swaps an over-limit test for a stub
        # that carries no provider data, so expansion then leaves it alone.
        for prepare_namespace in (
            validate_provider_tests,
            populate_data_provider_tests,
        ):
            prepare_namespace(dct)
        namespace = dict(dct)
        return super().__new__(mcs, name, bases, namespace)


class UnitTest(TestCase, metaclass=BaseTestMeta):
    """``unittest.TestCase`` base class whose subclasses support ``data_provider``.

    All extra behaviour comes from ``BaseTestMeta``, which validates test
    limits and expands decorated test methods when each subclass is created.
    """