File: loader.py

package info (click to toggle)
python-refurb 1.27.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,700 kB
  • sloc: python: 9,468; makefile: 40; sh: 6
file content (180 lines) | stat: -rw-r--r-- 5,322 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import importlib
import pkgutil
import sys
from collections import defaultdict
from collections.abc import Generator
from importlib.metadata import entry_points
from inspect import getsourcefile, getsourcelines, signature
from pathlib import Path
from types import GenericAlias, ModuleType, UnionType
from typing import Any, TypeGuard

from mypy.nodes import Node

from refurb.visitor.mapping import METHOD_NODE_MAPPINGS

from . import checks as checks_module
from .error import Error, ErrorCategory, ErrorCode
from .settings import Settings
from .types import Check


def get_modules(paths: list[str]) -> Generator[ModuleType, None, None]:
    """Yield every module that may contain checks, each at most once.

    Modules come from, in order: the built-in ``checks`` package, the module
    names in ``paths``, and the values registered under the ``refurb.plugins``
    entry-point group. A package (anything with ``__path__``) is walked
    recursively with ``pkgutil.walk_packages`` and only its non-package
    modules are yielded; a plain module is yielded directly.

    The current working directory is added to ``sys.path`` (only once, even
    across repeated calls) so that ``paths`` may name modules located there.
    """
    cwd = str(Path.cwd())

    # Avoid growing sys.path on every call: a duplicate entry never changes
    # import resolution, it only leaks memory and slows imports.
    if cwd not in sys.path:
        sys.path.append(cwd)

    plugins = [x.value for x in entry_points(group="refurb.plugins")]
    # Lazily imported so an import error surfaces while iterating.
    extra_modules = (importlib.import_module(x) for x in paths + plugins)

    loaded: set[ModuleType] = set()

    for pkg in (checks_module, *extra_modules):
        if pkg in loaded:
            continue

        if not hasattr(pkg, "__path__"):
            # Plain (non-package) module: yield it as-is.
            module = importlib.import_module(pkg.__name__)

            if module not in loaded:
                loaded.add(module)
                yield module

            continue

        for info in pkgutil.walk_packages(pkg.__path__, f"{pkg.__name__}."):
            # Sub-packages are skipped; walk_packages still visits their contents.
            if info.ispkg:
                continue

            module = importlib.import_module(info.name)

            if module not in loaded:
                loaded.add(module)
                yield module

        loaded.add(pkg)


def is_valid_error_class(obj: Any) -> TypeGuard[type[Error]]:  # type: ignore
    """Return True if ``obj`` is a concrete check error class.

    ``obj`` qualifies when it is a class whose name starts with ``Error``,
    is not one of the base/helper names ``Error``, ``ErrorCode`` or
    ``ErrorCategory``, and subclasses ``Error``.

    Non-class objects (functions, instances, ...) return False instead of
    letting ``issubclass`` raise ``TypeError``.
    """
    if not hasattr(obj, "__name__"):
        return False

    name = obj.__name__
    ignored_names = ("Error", "ErrorCode", "ErrorCategory")

    return (
        name.startswith("Error")
        and name not in ignored_names
        # issubclass() raises TypeError for non-classes such as a helper
        # function named e.g. "ErrorHelper", so check for a class first.
        and isinstance(obj, type)
        and issubclass(obj, Error)
    )


def get_error_class(module: ModuleType) -> type[Error] | None:
    """Find the error class declared by a check module.

    Attributes are scanned in ``dir(module)`` order. Names starting with
    ``Error`` (other than ``Error`` and ``ErrorCode``) are fetched lazily,
    and the first one accepted by ``is_valid_error_class`` is returned.
    Returns None when the module declares no such class.
    """
    candidates = (
        getattr(module, attr)
        for attr in dir(module)
        if attr.startswith("Error") and attr not in {"Error", "ErrorCode"}
    )

    return next(filter(is_valid_error_class, candidates), None)


def should_load_check(settings: Settings, error: type[Error]) -> bool:
    """Decide whether the check reporting ``error`` should be loaded.

    Rules are applied in this order; the first one that matches wins:

    1. the error's code is in ``settings.enable``            -> load
    2. the error's code is in ``settings.disable`` / ``ignore`` -> skip
    3. one of the error's categories is in ``settings.enable``  -> load
    4. one of its categories is in ``settings.disable``, or
       ``settings.disable_all`` is set                         -> skip
    5. otherwise fall back to ``error.enabled or settings.enable_all``.
    """
    code = ErrorCode.from_error(error)

    # Per-code settings take precedence over everything else.
    if code in settings.enable:
        return True
    if code in settings.disable or code in settings.ignore:
        return False

    # Next come category-level settings.
    error_categories = set(map(ErrorCategory, error.categories))

    if settings.enable & error_categories:
        return True

    disabled_by_category = settings.disable & error_categories
    if disabled_by_category or settings.disable_all:
        return False

    # Finally, the check's own default, overridable by enable_all.
    return error.enabled or settings.enable_all


# Node types a check function's first parameter may be annotated with
# (alone or as an ``X | Y`` union); taken from the visitor's method mapping.
VALID_NODE_TYPES = {*METHOD_NODE_MAPPINGS.values()}
# Accepted ``(parameter name, annotation)`` pairs for a check function's
# optional third parameter ("services" injected by the loader).
VALID_OPTIONAL_ARGS = (("settings", Settings),)


def type_error_with_line_info(func: Any, msg: str) -> TypeError:  # type: ignore
    """Build (but do not raise) a ``TypeError`` prefixed with ``func``'s location.

    The message normally has the form ``"<file>:<line>: <msg>"``. When the
    location cannot be determined, the prefix degrades instead of raising a
    different exception from inside error reporting:

    * no source file (builtins, ``functools.partial``, callable instances,
      code without a file on disk) -> ``msg`` alone;
    * file known but its source lines unavailable -> ``"<file>: <msg>"``.

    Callers are expected to write ``raise type_error_with_line_info(...)``.
    """
    try:
        filename = getsourcefile(func)
    except TypeError:
        # getsourcefile() only accepts modules, classes, functions, methods,
        # frames, tracebacks and code objects.
        return TypeError(msg)

    # Checked before getsourcelines(), which would otherwise raise OSError
    # for the very cases this fallback exists for.
    if not filename:
        return TypeError(msg)  # pragma: no cover

    try:
        line = getsourcelines(func)[1]
    except (OSError, TypeError):  # pragma: no cover
        return TypeError(f"{filename}: {msg}")

    return TypeError(f"{filename}:{line}: {msg}")


def _annotation_name(annotation: object) -> str:
    """Return a readable name for ``annotation``, for use in error messages.

    String annotations (e.g. under ``from __future__ import annotations``) are
    returned unchanged; objects without ``__name__`` fall back to ``repr``.
    """
    if isinstance(annotation, str):
        return annotation

    return getattr(annotation, "__name__", repr(annotation))


def extract_function_types(  # type: ignore
    func: Any,
) -> Generator[type[Node], None, None]:
    """Validate a check function's signature and yield the node types it handles.

    The check function must take 2 or 3 parameters:

    1. the node, annotated with a type from ``VALID_NODE_TYPES`` or an
       ``X | Y`` union of such types;
    2. the error list, annotated exactly as ``list[Error]``;
    3. optionally, a service listed in ``VALID_OPTIONAL_ARGS``.

    Every member of a union is validated before any is yielded. Because this
    is a generator, validation only runs once it is iterated.

    Raises:
        TypeError: if ``func`` is not callable or its signature is invalid.
    """
    if not callable(func):
        raise TypeError("Check function must be callable")

    params = list(signature(func).parameters.values())

    if len(params) not in {2, 3}:
        raise type_error_with_line_info(func, "Check function must take 2-3 parameters")

    node_param = params[0].annotation
    error_param = params[1].annotation
    optional_params = params[2:]

    if not (
        isinstance(error_param, GenericAlias)
        and error_param.__origin__ is list
        and error_param.__args__[0] is Error
    ):
        raise type_error_with_line_info(func, '"error" param must be of type list[Error]')

    for param in optional_params:
        if (param.name, param.annotation) not in VALID_OPTIONAL_ARGS:
            raise type_error_with_line_info(
                func,
                f'"{param.name}: {_annotation_name(param.annotation)}" is not a valid service',  # noqa: E501
            )

    # A bare type is treated as a one-member union.
    node_types = node_param.__args__ if isinstance(node_param, UnionType) else (node_param,)

    # Validate everything up front so a caller never receives a partial list of
    # types from a signature that is ultimately rejected.
    for ty in node_types:
        if ty not in VALID_NODE_TYPES:
            raise type_error_with_line_info(
                func,
                f'"{_annotation_name(ty)}" is not a valid Mypy node type',
            )

    yield from node_types


def load_checks(settings: Settings) -> defaultdict[type[Node], list[Check]]:
    """Collect enabled check functions, grouped by the node type they visit.

    For each module from ``get_modules(settings.load)`` that declares an
    error class accepted by ``should_load_check``, its ``check`` function (if
    any) is registered under every node type ``extract_function_types``
    yields for it. The error code counts as enabled even when the module has
    no ``check`` attribute.

    When ``settings.verbose`` is set, the sorted enabled codes (or a "No
    checks enabled" notice) are printed.
    """
    checks_by_node: defaultdict[type[Node], list[Check]] = defaultdict(list)
    enabled_codes: set[str] = set()

    for module in get_modules(settings.load):
        error = get_error_class(module)

        if not error or not should_load_check(settings, error):
            continue

        check = getattr(module, "check", None)

        if check:
            for node_type in extract_function_types(check):
                checks_by_node[node_type].append(check)

        enabled_codes.add(str(ErrorCode.from_error(error)))

    if settings.verbose:
        if enabled_codes:
            summary = ", ".join(sorted(enabled_codes))
        else:
            summary = "No checks enabled"

        print(f"Enabled checks: {summary}\n")

    return checks_by_node