1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import List, Optional, Sequence
from libcst import BaseStatement, CSTNode, Module
from libcst._add_slots import add_slots
from libcst._nodes.internal import CodegenState
from libcst.metadata import BaseMetadataProvider
class CodegenPartial:
"""
Provided by :class:`ExperimentalReentrantCodegenProvider`.
Stores enough information to generate either a small patch
(:meth:`get_modified_code_range`) or a new file (:meth:`get_modified_code`) by
replacing the old node at this position.
"""
__slots__ = [
"start_offset",
"end_offset",
"has_trailing_newline",
"_indent_tokens",
"_prev_codegen_state",
]
def __init__(self, state: "_ReentrantCodegenState") -> None:
# store a frozen copy of these values, since they change over time
self.start_offset: int = state.start_offset_stack[-1]
self.end_offset: int = state.char_offset
self.has_trailing_newline: bool = True # this may get updated to False later
self._indent_tokens: Sequence[str] = tuple(state.indent_tokens)
# everything else can be accessed from the codegen state object
self._prev_codegen_state: _ReentrantCodegenState = state
def get_original_module_code(self) -> str:
"""
Equivalent to :meth:`libcst.Module.bytes` on the top-level module that contains
this statement, except that it uses the cached result from our previous code
generation pass, so it's faster.
"""
return self._prev_codegen_state.get_code()
def get_original_module_bytes(self) -> bytes:
"""
Equivalent to :meth:`libcst.Module.bytes` on the top-level module that contains
this statement, except that it uses the cached result from our previous code
generation pass, so it's faster.
"""
return self.get_original_module_code().encode(self._prev_codegen_state.encoding)
def get_original_statement_code(self) -> str:
"""
Equivalent to :meth:`libcst.Module.code_for_node` on the current statement,
except that it uses the cached result from our previous code generation pass,
so it's faster.
"""
return self._prev_codegen_state.get_code()[self.start_offset : self.end_offset]
def get_modified_statement_code(self, node: BaseStatement) -> str:
"""
Gets the new code for ``node`` as if it were in same location as the old
statement being replaced. This means that it inherits details like the old
statement's indentation.
"""
new_codegen_state = CodegenState(
default_indent=self._prev_codegen_state.default_indent,
default_newline=self._prev_codegen_state.default_newline,
indent_tokens=list(self._indent_tokens),
)
node._codegen(new_codegen_state)
if not self.has_trailing_newline:
new_codegen_state.pop_trailing_newline()
return "".join(new_codegen_state.tokens)
def get_modified_module_code(self, node: BaseStatement) -> str:
"""
Gets the new code for the module at the root of this statement's tree, but with
the supplied replacement ``node`` in its place.
"""
original = self.get_original_module_code()
patch = self.get_modified_statement_code(node)
return f"{original[:self.start_offset]}{patch}{original[self.end_offset:]}"
def get_modified_module_bytes(self, node: BaseStatement) -> bytes:
"""
Gets the new bytes for the module at the root of this statement's tree, but with
the supplied replacement ``node`` in its place.
"""
return self.get_modified_module_code(node).encode(
self._prev_codegen_state.encoding
)
@add_slots
@dataclass(frozen=False)
class _ReentrantCodegenState(CodegenState):
provider: BaseMetadataProvider[CodegenPartial]
encoding: str = "utf-8"
indent_size: int = 0
char_offset: int = 0
start_offset_stack: List[int] = field(default_factory=list)
cached_code: Optional[str] = None
trailing_partials: List[CodegenPartial] = field(default_factory=list)
def increase_indent(self, value: str) -> None:
super(_ReentrantCodegenState, self).increase_indent(value)
self.indent_size += len(value)
def decrease_indent(self) -> None:
self.indent_size -= len(self.indent_tokens[-1])
super(_ReentrantCodegenState, self).decrease_indent()
def add_indent_tokens(self) -> None:
super(_ReentrantCodegenState, self).add_indent_tokens()
self.char_offset += self.indent_size
def add_token(self, value: str) -> None:
super(_ReentrantCodegenState, self).add_token(value)
self.char_offset += len(value)
self.trailing_partials.clear()
def before_codegen(self, node: CSTNode) -> None:
if not isinstance(node, BaseStatement):
return
self.start_offset_stack.append(self.char_offset)
def after_codegen(self, node: CSTNode) -> None:
if not isinstance(node, BaseStatement):
return
partial = CodegenPartial(self)
self.provider.set_metadata(node, partial)
self.start_offset_stack.pop()
self.trailing_partials.append(partial)
def pop_trailing_newline(self) -> None:
"""
:class:`libcst.Module` contains a hack where it removes the last token (a
newline) if the original file didn't have a newline.
If this happens, we need to go back through every node at the end of the file,
and fix their `end_offset`.
"""
for tp in self.trailing_partials:
tp.end_offset -= len(self.tokens[-1])
tp.has_trailing_newline = False
super(_ReentrantCodegenState, self).pop_trailing_newline()
def get_code(self) -> str:
# Ideally this would use functools.cached_property, but that's only in
# Python 3.8+.
#
# This is a little ugly to make pyre's attribute refinement checks happy.
cached_code = self.cached_code
if cached_code is not None:
return cached_code
cached_code = "".join(self.tokens)
self.cached_code = cached_code
return cached_code
class ExperimentalReentrantCodegenProvider(BaseMetadataProvider[CodegenPartial]):
"""
An experimental API that allows fast generation of modified code by recording an
initial code-generation pass, and incrementally applying updates. It is a
performance optimization for a few niche use-cases and is not user-friendly.
**This API may change at any time without warning (including in minor releases).**
This is rarely useful. Instead you should make multiple modifications to a single
syntax tree, and generate the code once. However, we can think of a few use-cases
for this API (hence, why it exists):
- When linting a file, you might generate multiple independent patches that a user
can accept or reject. Depending on your architecture, it may be advantageous to
avoid regenerating the file when computing each patch.
- You might want to call out to an external utility (e.g. a typechecker, such as
pyre or mypy) to validate a small change. You may need to generate and test lots
of these patches.
Restrictions:
- For safety and sanity reasons, the smallest/only level of granularity is a
statement. If you need to patch part of a statement, you regenerate the entire
statement. If you need to regenerate an entire module, just call
:meth:`libcst.Module.code`.
- This does not (currently) operate recursively. You can patch an unpatched piece
of code multiple times, but you can't layer additional patches on an already
patched piece of code.
"""
def _gen_impl(self, module: Module) -> None:
state = _ReentrantCodegenState(
default_indent=module.default_indent,
default_newline=module.default_newline,
provider=self,
encoding=module.encoding,
)
module._codegen(state)
|