File: inheritance_diagrams.py

package info (click to toggle)
sphinx-autoapi 3.3.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 900 kB
  • sloc: python: 5,146; makefile: 7
file content (131 lines) | stat: -rw-r--r-- 4,299 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import astroid
import sphinx.ext.inheritance_diagram


def _do_import_class(name, currmodule=None):
    """Resolve a dotted *name* to an astroid node by walking it part by part.

    Args:
        name: Dotted path to look up.
        currmodule: Name of the module to start the lookup from. When falsy,
            the first component of *name* is used as the starting module and
            the remaining components are looked up inside it.

    Returns:
        The node reached after consuming the path, or ``None`` when a lookup
        yields nothing or astroid raises an ``AstroidError`` along the way.
    """
    # Reversed so that ``pop()`` hands back the components left to right.
    remaining = name.split(".")
    remaining.reverse()
    if not currmodule:
        currmodule = remaining.pop()

    try:
        node = astroid.MANAGER.ast_from_module_name(currmodule)
        while node and remaining:
            part = remaining.pop()
            found = node.getattr(part)
            node = found[0] if found else None
            # The attribute may be an import statement rather than the
            # definition itself; follow it to the module it refers to.
            while isinstance(node, (astroid.ImportFrom, astroid.Import)):
                try:
                    node = node.do_import_module(part)
                except astroid.AstroidImportError:
                    # ``part`` could not be imported as a module, so import
                    # the statement's own module and look ``part`` up in it.
                    node = node.do_import_module()
                    found = node.getattr(part)
                    node = found[0] if found else None
                    break
    except astroid.AstroidError:
        return None

    return node


def _import_class(name, currmodule):
    """Return the astroid class nodes that *name* refers to.

    The name is first resolved relative to *currmodule* (when one is given)
    and, if that produces ``None``, resolved again as an absolute path.

    Args:
        name: Dotted name of a class or module.
        currmodule: Module name to try resolving *name* against first; may be
            falsy to skip the relative attempt.

    Returns:
        A one-element list for a class, or the list of classes that are
        direct children of the module when *name* resolves to a module.

    Raises:
        sphinx.ext.inheritance_diagram.InheritanceException: If *name* cannot
            be resolved, or resolves to something other than a class or
            module.
    """
    resolved = _do_import_class(name, currmodule) if currmodule else None
    if resolved is None:
        resolved = _do_import_class(name)

    if not resolved:
        raise sphinx.ext.inheritance_diagram.InheritanceException(
            f"Could not import class or module {name} specified for inheritance diagram"
        )

    if isinstance(resolved, astroid.ClassDef):
        return [resolved]

    if not isinstance(resolved, astroid.Module):
        raise sphinx.ext.inheritance_diagram.InheritanceException(
            f"{name} specified for inheritance diagram is not a class or module"
        )

    return [
        child
        for child in resolved.get_children()
        if isinstance(child, astroid.ClassDef)
    ]


class _AutoapiInheritanceGraph(sphinx.ext.inheritance_diagram.InheritanceGraph):
    """Inheritance graph whose classes are astroid nodes.

    Class lookup, graph construction and naming are reimplemented on top of
    the ``astroid.ClassDef`` nodes returned by ``_import_class``.
    """

    @staticmethod
    def _import_classes(class_names, currmodule):
        """Resolve every entry of *class_names* and return all found classes.

        Args:
            class_names: Iterable of dotted class or module names.
            currmodule: Module name the names are first resolved against.

        Returns:
            A flat list with the class nodes found for each name, in order.
        """
        classes = []

        for name in class_names:
            classes.extend(_import_class(name, currmodule))

        return classes

    def _class_info(
        self, classes, show_builtins, private_bases, parts, aliases, top_classes
    ):
        """Collect the graph entries for *classes* and their ancestors.

        Args:
            classes: Class nodes to start from.
            show_builtins: When falsy, classes whose root module is named
                ``builtins`` are left out.
            private_bases: When falsy, classes whose name starts with ``_``
                are left out.
            parts: Number of trailing name components used for node names
                (``0`` keeps the full name).
            aliases: Optional mapping of computed names to replacements.
            top_classes: Full names at which the upward traversal stops.

        Returns:
            A list of ``(nodename, fullname, baselist, tooltip)`` tuples, one
            per class included in the graph.
        """
        all_classes = {}

        def is_hidden(node):
            """Return True if the visibility options exclude *node*."""
            if not show_builtins and node.root().name == "builtins":
                return True
            return not private_bases and node.name.startswith("_")

        def recurse(cls):
            if cls in all_classes or is_hidden(cls):
                return

            nodename = self.class_name(cls, parts, aliases)
            fullname = self.class_name(cls, 0, aliases)

            # The tooltip is the first line of the docstring, wrapped in
            # double quotes with embedded double quotes backslash-escaped.
            tooltip = None
            if cls.doc_node:
                doc = cls.doc_node.value.strip().split("\n")[0]
                if doc:
                    escaped = doc.replace('"', '\\"')
                    tooltip = f'"{escaped}"'

            # Register the entry before walking the bases; ``baselist`` is
            # filled in place below, after the tuple has been stored.
            baselist = []
            all_classes[cls] = (nodename, fullname, baselist, tooltip or "")

            if fullname in top_classes:
                return

            for base in cls.ancestors(recurs=False):
                if is_hidden(base):
                    continue
                baselist.append(self.class_name(base, parts, aliases))
                # ``recurse`` itself skips classes that were already visited.
                recurse(base)

        for cls in classes:
            recurse(cls)

        return list(all_classes.values())

    @staticmethod
    def class_name(node, parts=0, aliases=None):
        """Return the display name of *node*.

        Args:
            node: Object exposing ``qname()``, which supplies the full name.
            parts: Number of trailing dotted components to keep; ``0`` keeps
                the whole name.
            aliases: Optional mapping; when the computed name is a key, the
                mapped value is returned instead.

        Returns:
            The (possibly shortened or aliased) name.
        """
        fullname = node.qname()
        # Drop the module component of builtin classes. Both prefixes end
        # with a dot so that only the exact module names match; a module
        # such as ``builtins_compat`` must keep its module component.
        if fullname.startswith(("__builtin__.", "builtins.")):
            fullname = fullname.split(".", 1)[-1]
        if parts == 0:
            result = fullname
        else:
            name_parts = fullname.split(".")
            result = ".".join(name_parts[-parts:])
        if aliases is not None and result in aliases:
            return aliases[result]
        return result


class AutoapiInheritanceDiagram(sphinx.ext.inheritance_diagram.InheritanceDiagram):
    """Variant of Sphinx's ``InheritanceDiagram`` using the autoapi graph."""

    def run(self):
        """Run the parent implementation with ``InheritanceGraph`` swapped.

        ``sphinx.ext.inheritance_diagram.InheritanceGraph`` is replaced by
        ``_AutoapiInheritanceGraph`` for the duration of the call and put
        back afterwards, whether or not the parent ``run`` raises.

        Returns:
            Whatever the parent class's ``run`` returns.
        """
        diagram_module = sphinx.ext.inheritance_diagram
        # Yucky! Monkeypatch InheritanceGraph to use our own
        original_graph = diagram_module.InheritanceGraph
        diagram_module.InheritanceGraph = _AutoapiInheritanceGraph
        try:
            result = super().run()
        finally:
            diagram_module.InheritanceGraph = original_graph
        return result