File: test_flatten_behavior.py

package info (click to toggle)
python-libcst 1.4.0-1.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,928 kB
  • sloc: python: 76,235; makefile: 10; sh: 2
file content (79 lines) | stat: -rw-r--r-- 2,736 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Type, Union

import libcst as cst
from libcst import FlattenSentinel, parse_expression, parse_module, RemovalSentinel
from libcst._nodes.tests.base import CSTNodeTest
from libcst._types import CSTNodeT
from libcst._visitors import CSTTransformer
from libcst.testing.utils import data_provider


class InsertPrintBeforeReturn(CSTTransformer):
    def leave_Return(
        self, original_node: cst.Return, updated_node: cst.Return
    ) -> Union[cst.Return, RemovalSentinel, FlattenSentinel[cst.BaseSmallStatement]]:
        return FlattenSentinel(
            [
                cst.Expr(parse_expression("print('returning')")),
                updated_node,
            ]
        )


class FlattenLines(CSTTransformer):
    def on_leave(
        self, original_node: CSTNodeT, updated_node: CSTNodeT
    ) -> Union[CSTNodeT, RemovalSentinel, FlattenSentinel[cst.SimpleStatementLine]]:
        if isinstance(updated_node, cst.SimpleStatementLine):
            return FlattenSentinel(
                [
                    cst.SimpleStatementLine(
                        [stmt.with_changes(semicolon=cst.MaybeSentinel.DEFAULT)]
                    )
                    for stmt in updated_node.body
                ]
            )
        else:
            return updated_node


class RemoveReturnWithEmpty(CSTTransformer):
    def leave_Return(
        self, original_node: cst.Return, updated_node: cst.Return
    ) -> Union[cst.Return, RemovalSentinel, FlattenSentinel[cst.BaseSmallStatement]]:
        return FlattenSentinel([])


class FlattenBehavior(CSTNodeTest):
    @data_provider(
        (
            ("return", "print('returning'); return", InsertPrintBeforeReturn),
            (
                "print('returning'); return",
                "print('returning')\nreturn",
                FlattenLines,
            ),
            (
                "print('returning')\nreturn",
                "print('returning')",
                RemoveReturnWithEmpty,
            ),
        )
    )
    def test_flatten_pass_behavior(
        self, before: str, after: str, visitor: Type[CSTTransformer]
    ) -> None:
        # Test doesn't have newline termination case
        before_module = parse_module(before)
        after_module = before_module.visit(visitor())
        self.assertEqual(after, after_module.code)

        # Test does have newline termination case
        before_module = parse_module(before + "\n")
        after_module = before_module.visit(visitor())
        self.assertEqual(after + "\n", after_module.code)