File: reentrant_codegen.py

package info (click to toggle)
python-libcst 1.4.0-1.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,928 kB
  • sloc: python: 76,235; makefile: 10; sh: 2
file content (209 lines) | stat: -rw-r--r-- 8,391 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass, field
from typing import List, Optional, Sequence

from libcst import BaseStatement, CSTNode, Module
from libcst._add_slots import add_slots
from libcst._nodes.internal import CodegenState
from libcst.metadata import BaseMetadataProvider


class CodegenPartial:
    """
    Provided by :class:`ExperimentalReentrantCodegenProvider`.

    Remembers where one statement sat in the previously generated module, which is
    enough to produce either a small patch (:meth:`get_modified_statement_code`) or
    a whole new file (:meth:`get_modified_module_code`) in which a replacement node
    takes the place of the old statement.
    """

    __slots__ = [
        "start_offset",
        "end_offset",
        "has_trailing_newline",
        "_indent_tokens",
        "_prev_codegen_state",
    ]

    def __init__(self, state: "_ReentrantCodegenState") -> None:
        # The state object itself is kept by reference; anything not snapshotted
        # below is read from it on demand.
        self._prev_codegen_state: _ReentrantCodegenState = state
        # These values keep changing while code generation proceeds, so take a
        # frozen copy of them as they are right now.
        self._indent_tokens: Sequence[str] = tuple(state.indent_tokens)
        self.start_offset: int = state.start_offset_stack[-1]
        self.end_offset: int = state.char_offset
        # Starts out True; may be switched to False afterwards.
        self.has_trailing_newline: bool = True

    def get_original_module_code(self) -> str:
        """
        Returns the text of the whole module that contains this statement (the
        ``str`` counterpart of :meth:`get_original_module_bytes`). The text comes
        from the cached result of the previous code generation pass, so nothing is
        regenerated.
        """
        return self._prev_codegen_state.get_code()

    def get_original_module_bytes(self) -> bytes:
        """
        Equivalent to :meth:`libcst.Module.bytes` on the top-level module that
        contains this statement, except that the cached result of the previous code
        generation pass is reused, so it's faster.
        """
        module_code = self.get_original_module_code()
        return module_code.encode(self._prev_codegen_state.encoding)

    def get_original_statement_code(self) -> str:
        """
        Equivalent to :meth:`libcst.Module.code_for_node` on the current statement,
        except that the cached result of the previous code generation pass is
        reused, so it's faster.
        """
        module_code = self._prev_codegen_state.get_code()
        # The statement is simply the recorded slice of the module text.
        return module_code[self.start_offset : self.end_offset]

    def get_modified_statement_code(self, node: BaseStatement) -> str:
        """
        Generates the code for ``node`` as though it occupied the position of the
        statement being replaced, so it picks up details such as the old statement's
        indentation.
        """
        previous = self._prev_codegen_state
        scratch_state = CodegenState(
            default_indent=previous.default_indent,
            default_newline=previous.default_newline,
            # Hand over a fresh list so the frozen snapshot cannot be mutated.
            indent_tokens=list(self._indent_tokens),
        )
        node._codegen(scratch_state)
        if not self.has_trailing_newline:
            # The old statement had no trailing newline, so drop it here as well.
            scratch_state.pop_trailing_newline()
        return "".join(scratch_state.tokens)

    def get_modified_module_code(self, node: BaseStatement) -> str:
        """
        Generates the code for the module at the root of this statement's tree, with
        the supplied replacement ``node`` substituted for the old statement.
        """
        original = self.get_original_module_code()
        patch = self.get_modified_statement_code(node)
        before = original[: self.start_offset]
        after = original[self.end_offset :]
        return "".join((before, patch, after))

    def get_modified_module_bytes(self, node: BaseStatement) -> bytes:
        """
        Generates the bytes for the module at the root of this statement's tree,
        with the supplied replacement ``node`` substituted for the old statement.
        """
        new_code = self.get_modified_module_code(node)
        return new_code.encode(self._prev_codegen_state.encoding)


@add_slots
@dataclass(frozen=False)
class _ReentrantCodegenState(CodegenState):
    """
    A :class:`CodegenState` that also keeps a running character offset while code
    is generated and attaches a :class:`CodegenPartial` to every statement node.

    NOTE(review): the explicit two-argument ``super(...)`` calls look deliberate;
    presumably ``add_slots`` rebuilds the class, which would break zero-argument
    ``super()``. Confirm before changing them.
    """

    # Provider on which each statement's CodegenPartial is registered.
    provider: BaseMetadataProvider[CodegenPartial]
    # Encoding that CodegenPartial uses when asked for bytes.
    encoding: str = "utf-8"
    # Total length of the indentation added via increase_indent and not yet removed.
    indent_size: int = 0
    # Running character count, advanced by add_token and add_indent_tokens.
    char_offset: int = 0
    # One start offset per statement that is currently being generated.
    start_offset_stack: List[int] = field(default_factory=list)
    # Joined tokens, filled in by the first call to get_code.
    cached_code: Optional[str] = None
    # Partials of statements that finished since the last token was added.
    trailing_partials: List[CodegenPartial] = field(default_factory=list)

    def increase_indent(self, value: str) -> None:
        """Adds an indentation level and grows ``indent_size`` by its length."""
        super(_ReentrantCodegenState, self).increase_indent(value)
        self.indent_size = self.indent_size + len(value)

    def decrease_indent(self) -> None:
        """Removes the innermost indentation level and shrinks ``indent_size``."""
        # Read the innermost indent before delegating; the base-class call is
        # expected to remove it.
        innermost = self.indent_tokens[-1]
        self.indent_size = self.indent_size - len(innermost)
        super(_ReentrantCodegenState, self).decrease_indent()

    def add_indent_tokens(self) -> None:
        """Emits the current indentation and advances ``char_offset`` to match."""
        super(_ReentrantCodegenState, self).add_indent_tokens()
        self.char_offset = self.char_offset + self.indent_size

    def add_token(self, value: str) -> None:
        """Emits ``value`` and advances ``char_offset`` by its length."""
        super(_ReentrantCodegenState, self).add_token(value)
        self.char_offset = self.char_offset + len(value)
        # Something was emitted after the previously finished statements, so they
        # no longer sit at the very end of the output.
        self.trailing_partials.clear()

    def before_codegen(self, node: CSTNode) -> None:
        """Records the current offset as the start of ``node`` if it is a statement."""
        if isinstance(node, BaseStatement):
            self.start_offset_stack.append(self.char_offset)

    def after_codegen(self, node: CSTNode) -> None:
        """Publishes a :class:`CodegenPartial` for ``node`` if it is a statement."""
        if isinstance(node, BaseStatement):
            # Build the partial while this statement's start offset is still on top
            # of the stack.
            finished = CodegenPartial(self)
            self.provider.set_metadata(node, finished)
            self.start_offset_stack.pop()
            self.trailing_partials.append(finished)

    def pop_trailing_newline(self) -> None:
        """
        :class:`libcst.Module` contains a hack where it removes the last token (a
        newline) if the original file didn't have a newline.

        When that happens, every statement that ends at the end of the file has to
        be revisited so that its ``end_offset`` no longer includes that token.
        """
        for trailing in self.trailing_partials:
            last_token = self.tokens[-1]
            trailing.end_offset -= len(last_token)
            trailing.has_trailing_newline = False
        super(_ReentrantCodegenState, self).pop_trailing_newline()

    def get_code(self) -> str:
        """Returns all tokens joined into one string, computing it only once."""
        # Ideally this would use functools.cached_property, but that's only in
        # Python 3.8+.
        #
        # Going through a local variable keeps pyre's attribute refinement checks
        # happy.
        code = self.cached_code
        if code is None:
            code = "".join(self.tokens)
            self.cached_code = code
        return code


class ExperimentalReentrantCodegenProvider(BaseMetadataProvider[CodegenPartial]):
    """
    An experimental API for producing modified code quickly: it records one initial
    code-generation pass and then applies updates incrementally. It is a performance
    optimization for a few niche use-cases and is not user-friendly.

    **This API may change at any time without warning (including in minor releases).**

    This is rarely useful. Normally you should make all of your modifications to a
    single syntax tree and generate the code once. Still, there are a few use-cases
    for this API (which is why it exists):

    - When linting a file, you might generate multiple independent patches that a
      user can accept or reject. Depending on your architecture, it may be
      advantageous to avoid regenerating the file when computing each patch.

    - You might want to call out to an external utility (e.g. a typechecker, such as
      pyre or mypy) to validate a small change. You may need to generate and test
      lots of these patches.

    Restrictions:

    - For safety and sanity reasons, the smallest/only level of granularity is a
      statement. If you need to patch part of a statement, you regenerate the entire
      statement. If you need to regenerate an entire module, just call
      :meth:`libcst.Module.code`.

    - This does not (currently) operate recursively. You can patch an unpatched
      piece of code multiple times, but you can't layer additional patches on an
      already patched piece of code.
    """

    def _gen_impl(self, module: Module) -> None:
        """
        Runs code generation for ``module`` with a state object that registers a
        :class:`CodegenPartial` on this provider for each statement.
        """
        recording_state = _ReentrantCodegenState(
            provider=self,
            encoding=module.encoding,
            default_indent=module.default_indent,
            default_newline=module.default_newline,
        )
        module._codegen(recording_state)