File: test_generic_consistency.py

package info (click to toggle)
python-django-stubs 5.2.9-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,832 kB
  • sloc: python: 5,185; makefile: 15; sh: 8
file content (108 lines) | stat: -rw-r--r-- 4,471 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import ast
import glob
import importlib
import os
from typing import Any, final
from unittest import mock

import django

from django_stubs_ext.patch import MPGeneric

# Absolute path of the `django-stubs` package directory: it sits one level above the folder
# that contains this test module.
STUBS_ROOT = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "django-stubs")


@final
class GenericInheritanceVisitor(ast.NodeVisitor):
    """AST visitor to find classes inheriting from `typing.Generic` in stubs."""

    def __init__(self) -> None:
        self.generic_classes: set[str] = set()

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        for base in node.bases:
            if (
                isinstance(base, ast.Subscript)
                and isinstance(base.value, ast.Name)
                and base.value.id == "Generic"
                and not any(dec.id == "type_check_only" for dec in node.decorator_list if isinstance(dec, ast.Name))
            ):
                self.generic_classes.add(node.name)
                break
        self.generic_visit(node)


def test_find_classes_inheriting_from_generic() -> None:
    """
    Ensure that `ext/django_stubs_ext/patch.py` stays up-to-date with the stubs.

    It works as follows:
        1. Parse the AST of each .pyi file under `STUBS_ROOT` and collect classes inheriting from `Generic`.
        2. For each such class, import the associated runtime `django` module and capture the names of
           every class in its MRO (excluding the class itself and `object`).
        3. Ensure that the class itself, or at least one class in its MRO, is among the classes returned
           by `_get_need_generic` (i.e. patched in `ext/django_stubs_ext/patch.py` or manually).

    Matching against the patched classes is done by bare class name.
    """
    with mock.patch.dict(os.environ, {"DJANGO_SETTINGS_MODULE": "scripts.django_tests_settings"}):
        # We need this to be able to do django imports below.
        django.setup()

    # "<module>.<class>" -> (class name, [names of its runtime base classes taken from the MRO]).
    # Keyed by the qualified name so that same-named generic classes declared in different stub
    # modules are all checked instead of silently overwriting each other.
    all_generic_classes: dict[str, tuple[str, list[str]]] = {}

    print(f"Searching for classes inheriting from Generic in: {STUBS_ROOT}")
    pyi_files = glob.glob("**/*.pyi", root_dir=STUBS_ROOT, recursive=True)
    for file_path in pyi_files:
        # Read as UTF-8 (Python's default source encoding) rather than the locale's default encoding.
        with open(os.path.join(STUBS_ROOT, file_path), encoding="utf-8") as f:
            source = f.read()

        generic_visitor = GenericInheritanceVisitor()
        # Passing `filename` makes a SyntaxError point at the offending stub file.
        generic_visitor.visit(ast.parse(source, filename=file_path))
        if not generic_visitor.generic_classes:
            continue

        # For each Generic in the stubs, import the associated module and capture every class in the MRO.
        module_name = _get_module_from_pyi(file_path)
        django_module = importlib.import_module(module_name)
        for cls_name in generic_visitor.generic_classes:
            runtime_cls = getattr(django_module, cls_name)
            # mro()[0] is the class itself and mro()[-1] is `object`; keep only the bases in between.
            base_names = [base.__name__ for base in runtime_cls.mro()[1:-1]]
            all_generic_classes[f"{module_name}.{cls_name}"] = (cls_name, base_names)

    print(f"Processed {len(pyi_files)} .pyi files.")
    print(f"Found {len(all_generic_classes)} unique classes inheriting from Generic in stubs")

    patched_classes = {mp_generic.cls.__name__ for mp_generic in _get_need_generic()}

    # Pretty-print missing patch in `ext/django_stubs_ext/patch.py`
    errors = []
    for qualified_name, (cls_name, base_names) in all_generic_classes.items():
        if not any(name in patched_classes for name in [*base_names, cls_name]):
            bases = f"({', '.join(base_names)})" if base_names else ""
            errors.append(f"{qualified_name}{bases} is not patched in `ext/django_stubs_ext/patch.py`")

    assert not errors, "\n".join(errors)


def _get_module_from_pyi(pyi_path: str) -> str:
    py_module = "django." + pyi_path.replace(".pyi", "").replace("/", ".")
    return py_module.removesuffix(".__init__")


def _get_need_generic() -> list[MPGeneric[Any]]:
    """
    Return every `MPGeneric` the consistency test should treat as patched.

    The result is a few `django.contrib.auth.forms` classes (which ones depends on `django.VERSION`),
    followed by the entries of `django_stubs_ext.patch._need_generic`.

    Symbols in `django.contrib.auth.forms` are very hard to patch automatically
    because we end up importing the User model and it crashes if `django.setup()` was not called beforehand.
    It can also very easily introduce circular imports so we require the user to monkeypatch it manually.
    See README.md for more details
    """
    import django_stubs_ext

    if django.VERSION < (5, 1):
        from django.contrib.auth.forms import AdminPasswordChangeForm, SetPasswordForm

        manually_patched = [MPGeneric(SetPasswordForm), MPGeneric(AdminPasswordChangeForm)]
    else:
        # Django 5.1+ exposes the password mixins; those are what users patch by hand.
        from django.contrib.auth.forms import SetPasswordMixin, SetUnusablePasswordMixin

        manually_patched = [MPGeneric(SetPasswordMixin), MPGeneric(SetUnusablePasswordMixin)]

    return manually_patched + list(django_stubs_ext.patch._need_generic)