1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, Sequence, Type
from unittest.mock import patch
import libcst as cst
from libcst._nodes.internal import CodegenState, visit_required
from libcst._types import CSTNodeT
from libcst._visitors import CSTTransformer, CSTVisitorT
from libcst.metadata import CodeRange, PositionProvider
from libcst.metadata.position_provider import PositionProvidingCodegenState
from libcst.testing.utils import UnitTest
@dataclasses.dataclass(frozen=True)
class _CSTCodegenPatchTarget:
    """
    Immutable record describing one ``_codegen`` method that is going to be
    temporarily replaced while a test tracks which nodes get rendered.
    """

    # The node class whose `_codegen` attribute is patched.
    type: Type[cst.CSTNode]
    # The attribute name under which that class is exported from `libcst`.
    name: str
    # The `_codegen` callable as it was before patching, so the replacement
    # can still delegate to it.
    old_codegen: Callable[..., None]
class _NOOPVisitor(CSTTransformer):
    """A transformer that overrides nothing; visiting with it should change nothing."""
def _cst_node_equality_func(
    a: cst.CSTNode, b: cst.CSTNode, msg: Optional[str] = None
) -> None:
    """
    Equality hook meant to be registered through ``addTypeEqualityFunc``.

    Compares the two nodes with ``a.deep_equals(b)`` and raises an
    ``AssertionError`` showing both reprs when they differ. If ``msg`` is given,
    it is appended to the error on its own line.
    """
    if a.deep_equals(b):
        return

    trailer = f"\n{msg}" if msg is not None else ""
    raise AssertionError(f"\n{a!r}\nis not deeply equal to \n{b!r}" + trailer)
def parse_expression_as(**config: Any) -> Callable[[str], cst.BaseExpression]:
    """
    Build an expression parser bound to the given parser configuration.

    The keyword arguments are forwarded to ``cst.PartialParserConfig`` each time
    the returned callable is invoked.
    """

    def inner(code: str) -> cst.BaseExpression:
        # Build the config per call, so any configuration error surfaces when
        # parsing rather than when the parser is created.
        parser_config = cst.PartialParserConfig(**config)
        return cst.parse_expression(code, config=parser_config)

    return inner
def parse_statement_as(**config: Any) -> Callable[[str], cst.BaseStatement]:
    """
    Build a statement parser bound to the given parser configuration.

    The keyword arguments are forwarded to ``cst.PartialParserConfig`` each time
    the returned callable is invoked.
    """

    def inner(code: str) -> cst.BaseStatement:
        # Build the config per call, so any configuration error surfaces when
        # parsing rather than when the parser is created.
        parser_config = cst.PartialParserConfig(**config)
        return cst.parse_statement(code, config=parser_config)

    return inner
# We can't use an ABCMeta here, because of metaclass conflicts
class CSTNodeTest(UnitTest):
    """
    Base class for per-node tests: checks codegen output, optional parsing and
    positions, and that a node's children, fields and visit behavior agree.
    """

    def setUp(self) -> None:
        # `self.assertEqual` would otherwise fall back to `CSTNode.__eq__`, which
        # compares by identity. Tests want comparison by value, so register a
        # deep-equality hook for every node class exported by `libcst`.
        #
        # `CSTNode.deep_equals` may be slow, but that is acceptable inside tests.
        node_types = [
            exported
            for exported in cst.__dict__.values()
            if isinstance(exported, type) and issubclass(exported, cst.CSTNode)
        ]
        # The test-only wrapper is not exported by `libcst`; register it as well.
        node_types.append(DummyIndentedBlock)
        for node_type in node_types:
            self.addTypeEqualityFunc(node_type, _cst_node_equality_func)

    def validate_node(
        self,
        node: CSTNodeT,
        code: str,
        parser: Optional[Callable[[str], CSTNodeT]] = None,
        expected_position: Optional[CodeRange] = None,
    ) -> None:
        """
        Run the full set of consistency checks against ``node``.

        ``code`` is the source text the node must render to. When ``parser`` is
        given, parsing ``code`` must yield a node equal to ``node``. When
        ``expected_position`` is given, the computed position must match it.
        """
        node.validate_types_deep()
        self.__assert_codegen(node, code, expected_position)

        if parser is not None:
            self.assertEqual(parser(code), node)

        # The child-related checks are about the wrapped node, not about
        # DummyIndentedBlock itself, so peel off any wrappers first.
        target = node
        while isinstance(target, DummyIndentedBlock):
            target = target.child

        for check in (
            self.__assert_children_match_codegen,
            self.__assert_children_match_fields,
            self.__assert_visit_returns_identity,
        ):
            check(target)

    def assert_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """
        Assert that calling ``get_node`` raises ``cst.CSTValidationError`` with a
        message matching ``expected_re``.
        """
        self.assertRaisesRegex(cst.CSTValidationError, expected_re, get_node)

    def assert_invalid_types(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """
        Assert that building the node and calling ``validate_types_shallow`` on
        it raises ``TypeError`` with a message matching ``expected_re``.
        """
        with self.assertRaisesRegex(TypeError, expected_re):
            built = get_node()
            built.validate_types_shallow()

    def __assert_codegen(
        self,
        node: cst.CSTNode,
        expected: str,
        expected_position: Optional[CodeRange] = None,
    ) -> None:
        """
        Check that rendering ``node`` produces ``expected`` and, if requested,
        that the position computed for it equals ``expected_position``.
        """
        empty_module = cst.Module([])
        self.assertEqual(empty_module.code_for_node(node), expected)

        if expected_position is None:
            return

        # Internal APIs are used on purpose here: we want the position of just
        # the node under test rather than of a whole module.
        #
        # Asking for the position of a node that is not in a module is normally
        # meaningless (which is why it's not supported publicly), but it's
        # exactly what these node-level tests need.
        position_provider = PositionProvider()
        codegen_state = PositionProvidingCodegenState(
            default_indent=empty_module.default_indent,
            default_newline=empty_module.default_newline,
            provider=position_provider,
        )
        node._codegen(codegen_state)
        self.assertEqual(position_provider._computed[node], expected_position)

    def __assert_children_match_codegen(self, node: cst.CSTNode) -> None:
        """
        Check that ``node.children`` lists the same nodes, in the same order, as
        the ones whose ``_codegen`` is invoked while rendering ``node``.
        """
        self.assertSequenceEqual(
            node.children,
            self.__derive_children_from_codegen(node),
            msg=(
                "The list of children we got from `node.children` differs from the "
                "children that were visited by `node._codegen`. This is probably "
                "due to a mismatch between _visit_and_replace_children and "
                "_codegen_impl."
            ),
        )

    def __derive_children_from_codegen(
        self, node: cst.CSTNode
    ) -> Sequence[cst.CSTNode]:
        """
        Render ``node`` with every exported ``CSTNode`` subclass's ``_codegen``
        temporarily wrapped, and return the direct children in the order their
        ``_codegen`` was called.

        Children have to be emitted in lexical order, so the result is expected
        to equal ``node.children``. Since ``node.children`` is derived from
        ``_visit_and_replace_children`` rather than from ``_codegen``, comparing
        the two confirms that both methods agree with each other.
        """
        targets: List[_CSTCodegenPatchTarget] = []
        for export_name, exported in cst.__dict__.items():
            if not isinstance(exported, type):
                continue
            if not issubclass(exported, cst.CSTNode):
                continue
            if not hasattr(exported, "_codegen"):
                continue
            targets.append(
                _CSTCodegenPatchTarget(
                    type=exported, name=export_name, old_codegen=exported._codegen
                )
            )

        direct_children: List[cst.CSTNode] = []
        # Nodes whose `_codegen` is currently executing, outermost first.
        active_nodes: List[cst.CSTNode] = []

        def _make_tracker(target: _CSTCodegenPatchTarget) -> Callable[..., None]:
            original_codegen = target.old_codegen

            def _tracking_codegen(
                visited: CSTNodeT, *args: Any, **kwargs: Any
            ) -> None:
                # A repeated entry for the node already on top of the stack is a
                # `super()._codegen()` call, which must not be tracked twice.
                entered = not active_nodes or active_nodes[-1] is not visited
                if entered:
                    # Exactly one node on the stack means only the root is
                    # active, so this node is a direct child (not the root
                    # itself and not a transitive descendant).
                    if len(active_nodes) == 1:
                        direct_children.append(visited)
                    active_nodes.append(visited)

                original_codegen(visited, *args, **kwargs)

                # Only undo the push that this call performed.
                if entered:
                    active_nodes.pop()

            return _tracking_codegen

        with ExitStack() as active_patches:
            for target in targets:
                replacement = _make_tracker(target)
                active_patches.enter_context(
                    patch(f"libcst.{target.name}._codegen", replacement)
                )
            # Rendering the node triggers `node._codegen()` with tracking active.
            cst.Module([]).code_for_node(node)

        return direct_children

    def __assert_children_match_fields(self, node: cst.CSTNode) -> None:
        """
        Check that ``node.children`` contains exactly the node objects reachable
        from the node's dataclass fields, either directly or as elements of an
        iterable field. Ordering is deliberately ignored.
        """
        ids_from_children = {id(child) for child in node.children}

        ids_from_fields = set()
        for field in dataclasses.fields(node):
            field_value = getattr(node, field.name)
            if isinstance(field_value, cst.CSTNode):
                ids_from_fields.add(id(field_value))
                continue
            if isinstance(field_value, Iterable):
                for element in field_value:
                    if isinstance(element, cst.CSTNode):
                        ids_from_fields.add(id(element))

        # Sets are compared, so order is irrelevant.
        self.assertSetEqual(
            ids_from_children,
            ids_from_fields,
            msg="`node.children` doesn't match what we found through introspection",
        )

    def __assert_visit_returns_identity(self, node: cst.CSTNode) -> None:
        """
        Check that visiting ``node`` with a transformer that overrides nothing
        yields a node equal to the original.
        """
        # TODO: Only equality is checked for now because visit currently clones
        # the node, which was easier to implement. Once that changes, this check
        # should be tightened to identity.
        visited = node.visit(_NOOPVisitor())
        self.assertEqual(node, visited)

    def assert_parses(
        self,
        code: str,
        parser: Callable[[str], cst.CSTNode],
        expect_success: bool,
    ) -> None:
        """
        Run ``parser`` on ``code``. When ``expect_success`` is false, the call
        must raise ``cst.ParserSyntaxError``; otherwise it just has to complete.
        """
        if expect_success:
            parser(code)
            return

        with self.assertRaises(cst.ParserSyntaxError):
            parser(code)
@dataclass(frozen=True)
class DummyIndentedBlock(cst.CSTNode):
    """
    Minimal stand-in for cst.IndentedBlock: it only pushes and pops the
    indentation state around its child, so that cst.IndentWhitespace can be
    tested in isolation.
    """

    # Indentation string pushed onto the codegen state while rendering `child`.
    value: str
    # The single wrapped node.
    child: cst.CSTNode

    def _codegen_impl(self, state: CodegenState) -> None:
        wrapped = self.child
        state.increase_indent(self.value)
        # The block's recorded syntactic position spans exactly its child.
        with state.record_syntactic_position(
            self, start_node=wrapped, end_node=wrapped
        ):
            wrapped._codegen(state)
        state.decrease_indent()

    def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "DummyIndentedBlock":
        visited_child = visit_required(self, "child", self.child, visitor)
        return DummyIndentedBlock(value=self.value, child=visited_child)
|