File: gather.py

package info (click to toggle)
python-libcst 1.4.0-1.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,928 kB
  • sloc: python: 76,235; makefile: 10; sh: 2
file content (160 lines) | stat: -rw-r--r-- 4,714 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import inspect
from collections import defaultdict
from collections.abc import Sequence as ABCSequence
from dataclasses import dataclass, fields, replace
from typing import Dict, Iterator, List, Mapping, Sequence, Set, Type, Union

import libcst as cst


def _get_bases() -> Iterator[Type[cst.CSTNode]]:
    """
    Yield every attribute of ``libcst`` whose name starts with ``Base``.

    These are intended to be the abstract CSTNode subclasses that are not
    nodes themselves, which keeps generated types sane by referring to the
    base classes directly. Only the name prefix is checked here; nothing
    verifies that each yielded object actually subclasses CSTNode.
    """
    yield from (getattr(cst, attr) for attr in dir(cst) if attr.startswith("Base"))


# Every ``Base*``-prefixed attribute of libcst, sorted alphabetically by name.
typeclasses: Sequence[Type[cst.CSTNode]] = sorted(
    _get_bases(),
    key=lambda typeclass: typeclass.__name__,
)


def _get_nodes() -> Iterator[Type[cst.CSTNode]]:
    """
    Yield every class exported from ``libcst`` that subclasses CSTNode.

    CSTNode itself and dunder attributes are skipped. Nothing else is
    filtered, so the abstract ``Base*`` classes are yielded alongside the
    concrete nodes a person would use to build a tree.
    """
    for attr in dir(cst):
        is_dunder = attr.startswith("__") and attr.endswith("__")
        if is_dunder or attr == "CSTNode":
            continue

        candidate = getattr(cst, attr)
        # issubclass() raises TypeError for attributes that aren't classes
        # (functions, constants, submodules); those are simply skipped.
        try:
            is_node = issubclass(candidate, cst.CSTNode)
        except TypeError:
            continue
        if is_node:
            yield candidate


# Every exported CSTNode subclass (other than CSTNode itself), sorted by name.
all_libcst_nodes: Sequence[Type[cst.CSTNode]] = sorted(
    _get_nodes(), key=lambda libcst_node: libcst_node.__name__
)
# For each node, the CSTNode subclasses in its MRO, ordered from the most
# generic (CSTNode) down to the node class itself.
node_to_bases: Dict[Type[cst.CSTNode], List[Type[cst.CSTNode]]] = {
    node: [
        ancestor
        for ancestor in reversed(inspect.getmro(node))
        if issubclass(ancestor, cst.CSTNode)
    ]
    for node in all_libcst_nodes
}


def _get_most_generic_base_for_node(node: Type[cst.CSTNode]) -> Type[cst.CSTNode]:
    """
    Return the most generic exported ancestor of ``node``.

    ``node_to_bases[node]`` lists the node's CSTNode ancestry from most generic
    to most specific. Only classes that are themselves keys of
    ``node_to_bases`` (i.e. exported from libcst) qualify, because a user could
    not name any other base in a type hint. CSTNode is not a key, so it is
    never returned. Raises KeyError if ``node`` is not a key.
    """
    for candidate in node_to_bases[node]:
        if candidate in node_to_bases:
            return candidate
    # Unreachable for keys of node_to_bases (each ancestry ends with the node
    # itself); kept so the failure mode matches indexing an empty list.
    raise IndexError("list index out of range")


# The most generic exported base of each node (never CSTNode itself, which is
# excluded from all_libcst_nodes and therefore from node_to_bases).
nodebases: Dict[Type[cst.CSTNode], Type[cst.CSTNode]] = {
    node: _get_most_generic_base_for_node(node) for node in all_libcst_nodes
}


@dataclass(frozen=True)
class Usage:
    """
    How a node class is referenced from other nodes' field annotations.

    Instances are immutable; ``_calc_node_usage`` derives updated copies with
    ``dataclasses.replace``. Every flag defaults to False.
    """

    # The node appears in a Union together with MaybeSentinel.
    maybe: bool = False
    # The node appears in a Union together with None.
    optional: bool = False
    # The node appears as a direct type argument of a Sequence.
    sequence: bool = False


# Every node starts with an all-False Usage; _calc_node_usage replaces entries
# with widened copies as references are discovered.
nodeuses: Dict[Type[cst.CSTNode], Usage] = {
    libcst_node: Usage() for libcst_node in all_libcst_nodes
}


def _is_maybe(typeobj: object) -> bool:
    """
    Return True if ``typeobj`` is ``cst.MaybeSentinel`` or a subclass of it.

    Anything ``issubclass`` rejects with TypeError (e.g. non-class objects)
    is reported as False instead of raising.
    """
    try:
        # pyre-ignore We wrap this in a TypeError check so this is safe
        is_maybe = issubclass(typeobj, cst.MaybeSentinel)
    except TypeError:
        is_maybe = False
    return is_maybe


def _get_origin(typeobj: object) -> object:
    try:
        # pyre-ignore We wrap this in a AttributeError check so this is safe
        return typeobj.__origin__
    except AttributeError:
        # Don't care, not a union or sequence
        return None


def _get_args(typeobj: object) -> List[object]:
    try:
        # pyre-ignore We wrap this in a AttributeError check so this is safe
        return typeobj.__args__
    except AttributeError:
        # Don't care, not a union or sequence
        return []


def _is_sequence(typeobj: object) -> bool:
    origin = _get_origin(typeobj)
    return origin is Sequence or origin is ABCSequence


def _is_union(typeobj: object) -> bool:
    return _get_origin(typeobj) is Union


def _calc_node_usage(typeobj: object) -> None:
    """
    Record in ``nodeuses`` how node classes are referenced by ``typeobj``.

    ``typeobj`` is a field annotation taken from a node dataclass.

    - For a ``Union``, every member that is a node class gets ``maybe`` set
      when the union also contains ``MaybeSentinel``, and ``optional`` set
      when the union also contains ``None``.
    - For a ``Sequence``, every type argument that is a node class gets
      ``sequence`` set.

    Any other type argument is searched recursively. Flags are only ever
    switched on, never off.
    """
    if _is_union(typeobj):
        union_args = _get_args(typeobj)
        has_maybe = any(_is_maybe(n) for n in union_args)
        # typing normalises ``None`` inside a Union to the class ``NoneType``,
        # so ``Optional[X].__args__`` is ``(X, NoneType)``. Testing
        # ``isinstance(n, type(None))`` asks whether ``n`` is the ``None``
        # *instance* and can never match; compare against the type itself
        # instead (and accept a raw ``None`` defensively).
        has_none = any(n is type(None) or n is None for n in union_args)

        for node in union_args:
            if node in all_libcst_nodes:
                nodeuses[node] = replace(
                    nodeuses[node],
                    maybe=nodeuses[node].maybe or has_maybe,
                    optional=nodeuses[node].optional or has_none,
                )
            else:
                _calc_node_usage(node)

    if _is_sequence(typeobj):
        for node in _get_args(typeobj):
            if node in all_libcst_nodes:
                nodeuses[node] = replace(nodeuses[node], sequence=True)
            else:
                # NOTE(review): a Union nested in a Sequence (for example
                # ``Sequence[Union[A, B]]``) recurses into the union branch
                # only, so A and B are not marked ``sequence``; confirm
                # whether that is intended.
                _calc_node_usage(node)


# Feed every field annotation of every node (except the ``_metadata`` field)
# through _calc_node_usage to accumulate usage flags in ``nodeuses``.
for node in all_libcst_nodes:
    for field in fields(node):
        if field.name != "_metadata":
            _calc_node_usage(field.type)


# Group, by defining module, the names of every non-``Base*`` node together
# with the name of its most generic exported base.
imports: Mapping[str, Set[str]] = defaultdict(set)
for node, base in nodebases.items():
    if not node.__name__.startswith("Base"):
        for x in (node, base):
            imports[x.__module__].add(x.__name__)