1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from collections import defaultdict
from collections.abc import Sequence as ABCSequence
from dataclasses import dataclass, fields, replace
from typing import Dict, Iterator, List, Mapping, Sequence, Set, Type, Union
import libcst as cst
def _get_bases() -> Iterator[Type[cst.CSTNode]]:
"""
Get all base classes that are subclasses of CSTNode but not an actual
node itself. This allows us to keep our types sane by refering to the
base classes themselves.
"""
for name in dir(cst):
if not name.startswith("Base"):
continue
yield getattr(cst, name)
typeclasses: Sequence[Type[cst.CSTNode]] = sorted(
_get_bases(), key=lambda base: base.__name__
)
def _get_nodes() -> Iterator[Type[cst.CSTNode]]:
"""
Grab all CSTNodes that are not a superclass. Basically, anything that a
person might use to generate a tree.
"""
for name in dir(cst):
if name.startswith("__") and name.endswith("__"):
continue
if name == "CSTNode":
continue
node = getattr(cst, name)
try:
if issubclass(node, cst.CSTNode):
yield node
except TypeError:
# This isn't a class, so we don't care about it.
pass
all_libcst_nodes: Sequence[Type[cst.CSTNode]] = sorted(
_get_nodes(), key=lambda node: node.__name__
)
node_to_bases: Dict[Type[cst.CSTNode], List[Type[cst.CSTNode]]] = {}
for node in all_libcst_nodes:
# Map the base classes for this node
node_to_bases[node] = list(
reversed([b for b in inspect.getmro(node) if issubclass(b, cst.CSTNode)])
)
def _get_most_generic_base_for_node(node: Type[cst.CSTNode]) -> Type[cst.CSTNode]:
# Ignore non-exported bases, a user couldn't specify these types
# in type hints.
exportable_bases = [b for b in node_to_bases[node] if b in node_to_bases]
return exportable_bases[0]
nodebases: Dict[Type[cst.CSTNode], Type[cst.CSTNode]] = {}
for node in all_libcst_nodes:
# Find the most generic version of this node that isn't CSTNode.
nodebases[node] = _get_most_generic_base_for_node(node)
@dataclass(frozen=True)
class Usage:
maybe: bool = False
optional: bool = False
sequence: bool = False
nodeuses: Dict[Type[cst.CSTNode], Usage] = {node: Usage() for node in all_libcst_nodes}
def _is_maybe(typeobj: object) -> bool:
try:
# pyre-ignore We wrap this in a TypeError check so this is safe
return issubclass(typeobj, cst.MaybeSentinel)
except TypeError:
return False
def _get_origin(typeobj: object) -> object:
try:
# pyre-ignore We wrap this in a AttributeError check so this is safe
return typeobj.__origin__
except AttributeError:
# Don't care, not a union or sequence
return None
def _get_args(typeobj: object) -> List[object]:
try:
# pyre-ignore We wrap this in a AttributeError check so this is safe
return typeobj.__args__
except AttributeError:
# Don't care, not a union or sequence
return []
def _is_sequence(typeobj: object) -> bool:
origin = _get_origin(typeobj)
return origin is Sequence or origin is ABCSequence
def _is_union(typeobj: object) -> bool:
return _get_origin(typeobj) is Union
def _calc_node_usage(typeobj: object) -> None:
if _is_union(typeobj):
has_maybe = any(_is_maybe(n) for n in _get_args(typeobj))
has_none = any(isinstance(n, type(None)) for n in _get_args(typeobj))
for node in _get_args(typeobj):
if node in all_libcst_nodes:
nodeuses[node] = replace(
nodeuses[node],
maybe=nodeuses[node].maybe or has_maybe,
optional=nodeuses[node].optional or has_none,
)
else:
_calc_node_usage(node)
if _is_sequence(typeobj):
for node in _get_args(typeobj):
if node in all_libcst_nodes:
nodeuses[node] = replace(nodeuses[node], sequence=True)
else:
_calc_node_usage(node)
for node in all_libcst_nodes:
for field in fields(node) or []:
if field.name == "_metadata":
continue
_calc_node_usage(field.type)
imports: Mapping[str, Set[str]] = defaultdict(set)
for node, base in nodebases.items():
if node.__name__.startswith("Base"):
continue
for x in (node, base):
imports[x.__module__].add(x.__name__)
|