File: test_removal_behavior.py

package info (click to toggle)
python-libcst 1.8.6-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,240 kB
  • sloc: python: 78,096; makefile: 15; sh: 2
file content (108 lines) | stat: -rw-r--r-- 4,032 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Type, Union

import libcst as cst
from libcst import parse_module, RemovalSentinel
from libcst._nodes.tests.base import CSTNodeTest
from libcst._types import CSTNodeT
from libcst._visitors import CSTTransformer
from libcst.testing.utils import data_provider


class IfStatementRemovalVisitor(CSTTransformer):
    """Transformer whose ``on_leave`` asks for every ``cst.If`` to be removed."""

    def on_leave(
        self, original_node: CSTNodeT, updated_node: CSTNodeT
    ) -> Union[CSTNodeT, RemovalSentinel]:
        """Return the removal sentinel for ``cst.If``, else the node unchanged."""
        # Guard clause: anything that is not an ``If`` passes straight through.
        if not isinstance(updated_node, cst.If):
            return updated_node
        return cst.RemoveFromParent()


class ContinueStatementRemovalVisitor(CSTTransformer):
    """Transformer whose ``on_leave`` asks for every ``cst.Continue`` to be removed."""

    def on_leave(
        self, original_node: CSTNodeT, updated_node: CSTNodeT
    ) -> Union[CSTNodeT, RemovalSentinel]:
        """Return the removal sentinel for ``cst.Continue``, else the node unchanged."""
        is_continue = isinstance(updated_node, cst.Continue)
        return cst.RemoveFromParent() if is_continue else updated_node


class SpecificImportRemovalVisitor(CSTTransformer):
    """Transformer that removes two specific imports.

    * an ``import ...`` statement is removed when any of its aliases is the
      plain name ``b``;
    * a ``from ... import ...`` statement is removed when its module is the
      plain name ``e``.

    Every other node is returned unchanged.
    """

    def on_leave(
        self, original_node: CSTNodeT, updated_node: CSTNodeT
    ) -> Union[cst.Import, cst.ImportFrom, CSTNodeT, RemovalSentinel]:
        """Return the removal sentinel for the targeted imports, else the node."""
        if isinstance(updated_node, cst.Import):
            # One matching alias is enough to drop the whole statement.
            should_drop = any(
                isinstance(imported, cst.Name) and imported.value == "b"
                for imported in (alias.name for alias in updated_node.names)
            )
        elif isinstance(updated_node, cst.ImportFrom):
            source = updated_node.module
            should_drop = isinstance(source, cst.Name) and source.value == "e"
        else:
            should_drop = False

        return cst.RemoveFromParent() if should_drop else updated_node


class RemovalBehavior(CSTNodeTest):
    """Checks the code rendered after a transformer removes statements."""

    @data_provider(
        (
            # Top of module doesn't require a pass, empty code is valid.
            ("continue", "", ContinueStatementRemovalVisitor),
            ("if condition: print('hello world')", "", IfStatementRemovalVisitor),
            # Verify behavior within an indented block.
            (
                "while True:\n    continue",
                "while True:\n    pass",
                ContinueStatementRemovalVisitor,
            ),
            (
                "while True:\n    if condition: print('hello world')",
                "while True:\n    pass",
                IfStatementRemovalVisitor,
            ),
            # Verify behavior within a simple statement suite.
            (
                "while True: continue",
                "while True: pass",
                ContinueStatementRemovalVisitor,
            ),
            # Verify with some imports
            (
                "import a\nimport b\n\nfrom c import d\nfrom e import f",
                "import a\n\nfrom c import d",
                SpecificImportRemovalVisitor,
            ),
            # Verify only one pass is generated even if we remove multiple statements
            # from the same indented block.
            (
                "while True:\n    continue\n    continue",
                "while True:\n    pass",
                ContinueStatementRemovalVisitor,
            ),
            # Verify a removal inside the block combined with a removal of a
            # module-level statement that follows the block.
            (
                "while True:\n    continue\ncontinue",
                "while True:\n    pass",
                ContinueStatementRemovalVisitor,
            ),
            # Verify only one pass is generated for a simple statement suite
            # that loses all of its statements.
            (
                "while True: continue ; continue",
                "while True: pass",
                ContinueStatementRemovalVisitor,
            ),
        )
    )
    def test_removal_pass_behavior(
        self, before: str, after: str, visitor: Type[CSTTransformer]
    ) -> None:
        """Apply ``visitor`` to ``before`` and compare the output with ``after``.

        Each case is checked twice: once exactly as given and once with a
        trailing newline appended to both the input and the expected output.

        Args:
            before: Source code to parse; must not end with a newline.
            after: Expected code after the transform; must not end with a
                newline.
            visitor: Transformer class; a fresh instance is created per run.

        Raises:
            ValueError: If ``before`` or ``after`` ends with a newline, since
                the newline-terminated variant is derived here.
        """
        if before.endswith("\n") or after.endswith("\n"):
            raise ValueError("Test cases should not be newline-terminated!")

        # First pass: no newline termination. Second pass: newline-terminated.
        for terminator in ("", "\n"):
            before_module = parse_module(before + terminator)
            after_module = before_module.visit(visitor())
            self.assertEqual(after + terminator, after_module.code)