File: gen_visitor_functions.py

package info (click to toggle)
python-libcst 1.4.0-1.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,928 kB
  • sloc: python: 76,235; makefile: 10; sh: 2
file content (116 lines) | stat: -rw-r--r-- 4,459 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import fields
from typing import List

from libcst.codegen.gather import imports, nodebases, nodeuses

# Accumulates the generated module one source line per entry; later sections
# of this script append the class definitions, and the __main__ guard prints
# the whole list joined with newlines.
#
# The generated module opens with the same license header as this script,
# then the runtime imports that its method signatures and decorators use.
generated_code: List[str] = [
    "# Copyright (c) Meta Platforms, Inc. and affiliates.",
    "#",
    "# This source code is licensed under the MIT license found in the",
    "# LICENSE file in the root directory of this source tree.",
    "",
    "",
    # Provenance line: name this generator (gen_visitor_functions), not
    # gen_matcher_classes, so readers regenerate with the right script.
    # NOTE(review): any checked-in copy of this script's output must be
    # regenerated alongside this change, or a staleness check comparing it
    # to fresh output (if the repository has one) will flag this line.
    "# This file was generated by libcst.codegen.gen_visitor_functions",
    "from typing import Optional, Union, TYPE_CHECKING",
    "",
    "from libcst._flatten_sentinel import FlattenSentinel",
    "from libcst._maybe_sentinel import MaybeSentinel",
    "from libcst._removal_sentinel import RemovalSentinel",
    "from libcst._typed_visitor_base import mark_no_op",
]

# Emit imports for every node type the generated annotations reference. They
# sit under `if TYPE_CHECKING:` because importing them at runtime from the
# generated module would create an import cycle.
generated_code.extend(["", "", "if TYPE_CHECKING:"])
for module_name, imported_names in imports.items():
    generated_code.extend(
        [
            f"    from {module_name} import (  # noqa: F401",
            "        " + ", ".join(sorted(imported_names)),
            "    )",
        ]
    )


# Generate the base visit_ methods. Node classes from nodebases are taken in
# name order, skipping any whose name starts with "Base". Each remaining class
# gets a no-op visit_<Node> hook. It also gets a no-op visit_<Node>_<field> /
# leave_<Node>_<field> pair for every dataclass field except "_metadata".
generated_code.extend(["", "", "class CSTTypedBaseFunctions:"])
for node_cls in sorted(nodebases, key=lambda cls: cls.__name__):
    node_name = node_cls.__name__
    if node_name.startswith("Base"):
        continue

    generated_code.extend(
        [
            "",
            "    @mark_no_op",
            f'    def visit_{node_name}(self, node: "{node_name}") -> Optional[bool]:',
            "        pass",
        ]
    )
    # dataclasses.fields() returns a tuple, so no empty-result fallback is needed.
    for node_field in fields(node_cls):
        if node_field.name == "_metadata":
            continue
        # Per-field hooks are emitted visit-first, then leave.
        for hook_kind in ("visit", "leave"):
            generated_code.extend(
                [
                    "",
                    "    @mark_no_op",
                    f'    def {hook_kind}_{node_name}_{node_field.name}(self, node: "{node_name}") -> None:',
                    "        pass",
                ]
            )

# Generate the visitor leave_ methods. CSTTypedVisitorFunctions gets one no-op
# leave_<Node>(original_node) hook per node class in nodebases, in name order.
# Classes whose name starts with "Base" are skipped.
generated_code.extend(
    ["", "", "class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):"]
)
for visitor_node_name in sorted(
    cls.__name__ for cls in nodebases if not cls.__name__.startswith("Base")
):
    generated_code.extend(
        [
            "",
            "    @mark_no_op",
            f'    def leave_{visitor_node_name}(self, original_node: "{visitor_node_name}") -> None:',
            "        pass",
        ]
    )

# Generate the transformer leave_ methods. Each generated leave_<Node> returns
# updated_node unchanged. Its return annotation is a Union that starts with the
# node's base class. The Union is widened according to the usage flags that
# nodeuses records for the node and for its base class:
#   maybe    -> MaybeSentinel
#   sequence -> FlattenSentinel[Base] and RemovalSentinel
#   optional -> RemovalSentinel (only when not already added via sequence)
generated_code.extend(
    ["", "", "class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):"]
)
for node in sorted(nodebases, key=lambda cls: cls.__name__):
    name = node.__name__
    if name.startswith("Base"):
        continue
    base_name = nodebases[node].__name__
    valid_return_types: List[str] = [f'"{base_name}"']
    node_uses = nodeuses[node]
    base_uses = nodeuses[nodebases[node]]

    if node_uses.maybe or base_uses.maybe:
        valid_return_types.append("MaybeSentinel")

    if node_uses.sequence or base_uses.sequence:
        valid_return_types += [f'FlattenSentinel["{base_name}"]', "RemovalSentinel"]
    elif node_uses.optional or base_uses.optional:
        valid_return_types.append("RemovalSentinel")

    return_annotation = f"Union[{', '.join(valid_return_types)}]"
    generated_code.extend(
        [
            "",
            "    @mark_no_op",
            f'    def leave_{name}(self, original_node: "{name}", updated_node: "{name}") -> {return_annotation}:',
            "        return updated_node",
        ]
    )


if __name__ == "__main__":
    # When run as a script, write the generated source to stdout, with the
    # entries of generated_code separated by newlines.
    print(*generated_code, sep="\n")