File: missing_references.py

package info (click to toggle)
python-advanced-alchemy 1.4.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,708 kB
  • sloc: python: 25,811; makefile: 162; javascript: 123; sh: 4
file content (359 lines) | stat: -rw-r--r-- 12,152 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
"""Sphinx extension that resolves or silences missing cross-references.

It connects handlers to the ``missing-reference`` event (retrying unresolved targets against the Python
domain) and the ``warn-missing-reference`` event (suppressing warnings for targets recognised as Advanced
Alchemy, SQLAlchemy, Litestar, serialization or Click references). It also creates a temporary examples
directory before documents are read and declares the ``ignore_missing_refs`` config value.
"""

# ruff: noqa: PLR0911, ARG001
from __future__ import annotations

import ast
import importlib
import inspect
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Generator

    from docutils.nodes import Element, Node
    from sphinx.addnodes import pending_xref
    from sphinx.application import Sphinx
    from sphinx.environment import BuildEnvironment


def _get_module_ast(source_file: str) -> ast.AST | ast.Module:
    """Read ``source_file`` as UTF-8 text and parse it into a Python AST.

    Args:
        source_file: Filesystem path of the Python source file.

    Returns:
        The parsed module node produced by :func:`ast.parse`.
    """
    source_text = Path(source_file).read_text(encoding="utf-8")
    return ast.parse(source_text)


def _get_import_nodes(nodes: list[ast.stmt]) -> Generator[ast.Import | ast.ImportFrom, None, None]:
    """Yield the import statements among ``nodes``, descending into ``TYPE_CHECKING`` guards.

    Both the bare ``if TYPE_CHECKING:`` form and the attribute form (``if typing.TYPE_CHECKING:``,
    ``if t.TYPE_CHECKING:``) are recognised. Only the body of such an ``if`` is searched; its ``else``
    branch and other compound statements are not entered.

    Args:
        nodes: Statements to scan, e.g. the ``body`` of a parsed module.

    Yields:
        Each ``ast.Import`` / ``ast.ImportFrom`` node, in source order.
    """
    for node in nodes:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            yield node
        elif isinstance(node, ast.If):
            test = node.test
            # ``TYPE_CHECKING`` parses as ast.Name (``id``); ``typing.TYPE_CHECKING`` as ast.Attribute (``attr``).
            guard_name = test.id if isinstance(test, ast.Name) else getattr(test, "attr", None)
            if guard_name == "TYPE_CHECKING":
                yield from _get_import_nodes(node.body)


def get_module_global_imports(module_import_path: str, reference_target_source_obj: str) -> set[str]:
    """Collect the names introduced by module-level imports in the file that defines an object.

    ``reference_target_source_obj`` is looked up as an attribute of the module at ``module_import_path``.
    The source file of that object (as reported by :func:`inspect.getsourcefile`) is parsed. Every top-level
    import, including those inside ``if TYPE_CHECKING`` blocks, contributes its alias, or its imported name
    when there is no alias (for ``import a.b`` that is the dotted ``"a.b"``).

    Args:
        module_import_path: Dotted import path of the module exposing the object.
        reference_target_source_obj: Attribute name of the object within that module.

    Returns:
        The set of imported names.
    """
    target_obj = getattr(importlib.import_module(module_import_path), reference_target_source_obj)
    source_path = inspect.getsourcefile(target_obj)
    module_ast = _get_module_ast(source_path)  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]

    imported: set[str] = set()
    for import_stmt in _get_import_nodes(module_ast.body):  # type: ignore[attr-defined]
        imported.update(alias.asname or alias.name for alias in import_stmt.names)
    return imported


def _resolve_local_reference(module_path: str, target: str) -> bool:
    """Attempt to resolve a reference within the local codebase.

    ``target`` may be a plain name (``"BasicAttributes"``), a dotted path relative to the module
    (``"BasicAttributes.created_at"``), or a fully qualified path that repeats the module path
    (``"advanced_alchemy.base.BasicAttributes"``).

    Args:
        module_path: The module path (e.g., 'advanced_alchemy.base')
        target: The target class/attribute name

    Returns:
        bool: True if reference exists, False otherwise
    """
    try:
        module = importlib.import_module(module_path)
    except (ImportError, AttributeError, ValueError):
        # ValueError: ``importlib.import_module("")`` rejects an empty module name.
        return False

    def _has_path(dotted: str) -> bool:
        """Return True if every segment of ``dotted`` resolves via successive ``getattr`` from the module."""
        current: object = module
        try:
            for part in dotted.split("."):
                current = getattr(current, part)
        except AttributeError:
            return False
        return True

    if _has_path(target):
        return True

    # A fully qualified target repeats the module path. Walking it from the module itself would look for an
    # attribute named after the top-level package, so retry with the module prefix stripped.
    prefix = f"{module_path}."
    return target.startswith(prefix) and _has_path(target[len(prefix) :])


def _resolve_sqlalchemy_reference(target: str) -> bool:
    """Attempt to resolve SQLAlchemy references.

    ``target`` is resolved as an attribute path from the top-level ``sqlalchemy`` package. Each dotted
    segment is looked up with ``getattr`` on the previous one, so paths of any depth work (e.g.
    ``"Engine"`` or ``"Connection.in_transaction"``).

    Args:
        target: The target class/attribute name, without a ``sqlalchemy.`` prefix

    Returns:
        bool: True if reference exists, False otherwise (including when SQLAlchemy is not installed)
    """
    try:
        import sqlalchemy
    except ImportError:
        return False

    obj: object = sqlalchemy
    try:
        for part in target.split("."):
            obj = getattr(obj, part)
    except AttributeError:
        return False
    return True


def _resolve_litestar_reference(target: str) -> bool:
    """Attempt to resolve Litestar references.

    A fixed set of common Litestar names and any ``config.app.*`` target are accepted once Litestar is
    importable. ``datastructures.*`` targets are resolved as a dotted attribute path from
    ``litestar.datastructures``. Any other target must be a top-level attribute of ``litestar``.

    Args:
        target: The target class/attribute name

    Returns:
        bool: True if reference exists, False otherwise (including when Litestar is not installed)
    """
    try:
        import litestar
        from litestar import datastructures
    except ImportError:
        return False

    # Handle common Litestar classes
    if target in {"Litestar", "State", "Scope", "Message", "AppConfig", "BeforeMessageSendHookHandler"}:
        return True
    if target.startswith("datastructures."):
        # Walk every remaining segment. A fixed two-way unpack raised ValueError for nested paths
        # such as ``datastructures.State.dict``.
        obj: object = datastructures
        try:
            for part in target[len("datastructures.") :].split("."):
                obj = getattr(obj, part)
        except AttributeError:
            return False
        return True
    if target.startswith("config.app."):
        return True  # These are valid Litestar config references
    return hasattr(litestar, target)


def _resolve_sqlalchemy_type_reference(target: str) -> bool:
    """Attempt to resolve SQLAlchemy type references.

    A handful of core type-system names are accepted outright once SQLAlchemy is importable. Anything
    else is resolved as a dotted attribute path from ``sqlalchemy.types``, with a leading ``_types.``
    treated as an alias for that module. Method references such as
    ``_types.TypeDecorator.process_bind_param`` therefore resolve instead of raising ``ValueError``.

    Args:
        target: The target class/attribute name

    Returns:
        bool: True if reference exists, False otherwise (including when SQLAlchemy is not installed)
    """
    try:
        from sqlalchemy import types as sa_types
    except ImportError:
        return False

    type_classes = {
        "TypeEngine",
        "TypeDecorator",
        "UserDefinedType",
        "ExternalType",
        "Dialect",
        "_types.TypeDecorator",
    }
    if target in type_classes:
        return True

    attr_path = target[len("_types.") :] if target.startswith("_types.") else target
    obj: object = sa_types
    try:
        for part in attr_path.split("."):
            obj = getattr(obj, part)
    except AttributeError:
        return False
    return True


def _resolve_advanced_alchemy_reference(target: str, module: str | None) -> bool:
    """Attempt to resolve Advanced Alchemy references.

    The reference is accepted when:

    * ``target`` is one of the known base/config/type names listed below;
    * ``target`` is fully qualified (``advanced_alchemy.…``) and either its last segment is one of those
      names or its package-relative path is a known method reference; or
    * the reference originates from a module inside the ``advanced_alchemy`` package.

    Args:
        target: The target class/attribute name
        module: The current module context; ``None`` or ``""`` when there is none

    Returns:
        bool: True if the reference is accepted, False otherwise
    """
    # Handle base module references
    base_classes = {
        "BasicAttributes",
        "CommonTableAttributes",
        "AuditColumns",
        "BigIntPrimaryKey",
        "UUIDPrimaryKey",
        "UUIDv6PrimaryKey",
        "UUIDv7PrimaryKey",
        "NanoIDPrimaryKey",
        "Empty",
        "TableArgsType",
        "DeclarativeBase",
    }

    # Handle config module references
    config_classes = {
        "EngineT",
        "SessionT",
        "SessionMakerT",
        "ConnectionT",
        "GenericSessionConfig",
        "GenericAlembicConfig",
    }

    # Dotted paths relative to the ``advanced_alchemy`` package.
    func_references = {"repository.SQLAlchemyAsyncRepositoryProtocol.add_many"}

    # Handle type module references
    type_classes = {"DateTimeUTC", "ORA_JSONB", "GUID", "EncryptedString", "EncryptedText"}

    known_names = base_classes | config_classes | type_classes
    if target in known_names:
        return True

    # Handle fully qualified references
    package_prefix = "advanced_alchemy."
    if target.startswith(package_prefix):
        relative = target[len(package_prefix) :]
        # ``func_references`` entries are dotted, so compare the whole package-relative path; comparing only
        # the last segment could never match them.
        if relative.rsplit(".", 1)[-1] in known_names or relative in func_references:
            return True

    # Handle module-relative references; ``module`` may be None when there is no module context.
    return module is not None and module.startswith(package_prefix)


def _resolve_serialization_reference(target: str) -> bool:
    """Report whether ``target`` is one of the known JSON serialization helpers.

    Both the bare helper names and their ``serialization.``-qualified forms are accepted.

    Args:
        target: The target class/attribute name

    Returns:
        bool: True if reference exists, False otherwise
    """
    bare_names = frozenset({"decode_json", "encode_json"})
    qualified_names = frozenset(f"serialization.{name}" for name in bare_names)
    return target in bare_names | qualified_names


def _resolve_click_reference(target: str) -> bool:
    """Report whether ``target`` names a top-level attribute of the ``click`` package.

    ``"Group"`` is accepted without an attribute lookup, provided Click is importable.

    Args:
        target: The target class/attribute name

    Returns:
        bool: True if reference exists, False otherwise (including when Click is not installed)
    """
    try:
        import click

        found = target == "Group" or hasattr(click, target)
    except ImportError:
        found = False
    return found


def on_warn_missing_reference(app: Sphinx, domain: str, node: Node) -> bool | None:
    """Suppress "missing reference" warnings for targets that are known to exist.

    Handler for Sphinx's ``warn-missing-reference`` event. Only ``pending_xref`` nodes are inspected. The
    target is checked against the Advanced Alchemy, SQLAlchemy (type system, then core), Litestar,
    serialization and Click resolvers, in that order.

    Sphinx uses the first non-``None`` result among the handlers for this event. This handler therefore
    returns ``True`` when it recognises the target and ``None`` otherwise. It never returns ``False``, which
    would mask a ``True`` from a handler registered later.

    Args:
        app: The Sphinx application instance (unused).
        domain: The domain argument Sphinx passes for the reference (unused).
        node: The node whose reference could not be resolved.

    Returns:
        ``True`` to suppress the warning, ``None`` to let Sphinx handle it normally.
    """
    if node.tagname != "pending_xref":  # type: ignore[attr-defined]
        return None

    if not hasattr(node, "attributes"):
        return None

    attributes = node.attributes  # type: ignore[attr-defined,unused-ignore]
    target = attributes.get("reftarget")  # pyright: ignore
    ref_type = attributes.get("reftype")  # pyright: ignore
    # ``py:module`` may be present with a ``None`` value; normalise it so prefix checks cannot raise.
    module = attributes.get("py:module") or ""  # pyright: ignore

    # Handle TypeVar references
    if target.__class__.__name__ == "TypeVar":  # pyright: ignore
        return True

    # Every resolver below operates on string targets (a missing ``reftarget`` yields None here).
    if not isinstance(target, str):
        return None

    try:
        # Handle Advanced Alchemy references
        if _resolve_advanced_alchemy_reference(target, module):
            return True

        # Handle SQLAlchemy type system references; on failure fall through so later resolvers still run.
        if (
            ref_type in {"class", "meth", "attr"}
            and any(
                x in target for x in ("TypeDecorator", "TypeEngine", "Dialect", "ExternalType", "UserDefinedType")
            )
            and _resolve_sqlalchemy_type_reference(target)
        ):
            return True

        # Handle SQLAlchemy core references
        if target.startswith("sqlalchemy.") or (
            ref_type in {"class", "attr", "obj", "meth"}
            and target
            in {
                "Engine",
                "Session",
                "Connection",
                "MetaData",
                "AsyncSession",
                "AsyncEngine",
                "AsyncConnection",
                "sessionmaker",
                "async_sessionmaker",
            }
        ):
            clean_target = target.replace("sqlalchemy.", "")
            if clean_target and _resolve_sqlalchemy_reference(clean_target):
                return True

        # Handle Litestar references; on failure fall through so later resolvers still run.
        if (
            ref_type in {"class", "obj"}
            and (
                target.startswith(("datastructures.", "config.app."))
                or target
                in {
                    "Litestar",
                    "State",
                    "Scope",
                    "Message",
                    "AppConfig",
                    "BeforeMessageSendHookHandler",
                    "FieldDefinition",
                    "ImproperConfigurationError",
                }
            )
            and _resolve_litestar_reference(target)
        ):
            return True

        # Handle serialization references
        if ref_type in {"attr", "meth"} and _resolve_serialization_reference(target):
            return True

        # Handle Click references
        if ref_type == "class" and _resolve_click_reference(target):
            return True

    except (AttributeError, ValueError):
        # The resolvers are best-effort; an unexpected lookup/parsing failure must not abort the build,
        # so defer to Sphinx's normal warning.
        return None

    return None


def on_missing_reference(app: Sphinx, env: BuildEnvironment, node: pending_xref, contnode: Element) -> Element | None:
    """Retry an unresolved cross-reference against the Python domain.

    Resolution is first attempted with the ``data`` role, because autodoc sometimes assigns the wrong role
    to such targets. If that fails, the Python domain's ``resolve_any_xref`` is used and the node from its
    first truthy result is returned.

    Args:
        app: The Sphinx application instance
        env: The Sphinx build environment
        node: The pending cross-reference node
        contnode: The content node

    Returns:
        Element | None: The resolved reference node if found, None otherwise
    """
    if not hasattr(node, "attributes"):
        return None

    target = node.attributes["reftarget"]  # type: ignore[attr-defined,unused-ignore]
    # Only string targets can be looked up in the Python domain.
    if not isinstance(target, str):
        return None

    python_domain = env.domains["py"]
    refdoc = node["refdoc"]

    as_data = python_domain.resolve_xref(env, refdoc, app.builder, "data", target, node, contnode)
    if as_data is not None:
        return as_data

    for candidate in python_domain.resolve_any_xref(env, refdoc, app.builder, target, node, contnode):
        if candidate:
            return candidate[1]
    return None


def on_env_before_read_docs(app: Sphinx, env: BuildEnvironment, docnames: set[str]) -> None:
    """Create the scratch directory for rendered examples and record it on the build environment.

    The directory is ``docs/_build/_tmp_examples`` under the current working directory. It is created
    (with any missing parents) if absent and stored as ``env.tmp_examples_path``.
    """
    examples_dir = Path.cwd().joinpath("docs", "_build", "_tmp_examples")
    examples_dir.mkdir(parents=True, exist_ok=True)
    env.tmp_examples_path = examples_dir  # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]


def setup(app: Sphinx) -> dict[str, bool]:
    """Register this extension's event handlers and config value with Sphinx.

    Returns:
        Extension metadata declaring the extension safe for parallel reading and writing.
    """
    handlers = (
        ("env-before-read-docs", on_env_before_read_docs),
        ("missing-reference", on_missing_reference),
        ("warn-missing-reference", on_warn_missing_reference),
    )
    for event_name, callback in handlers:
        app.connect(event_name, callback)
    app.add_config_value("ignore_missing_refs", default={}, rebuild="")
    return {"parallel_read_safe": True, "parallel_write_safe": True}