File: test_deep_replace.py

package info (click to toggle)
python-libcst 1.8.6-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,240 kB
  • sloc: python: 78,096; makefile: 15; sh: 2
file content (137 lines) | stat: -rw-r--r-- 4,535 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from textwrap import dedent
from typing import Optional

import libcst as cst
from libcst.testing.utils import UnitTest


class DeepReplaceTest(UnitTest):
    """Tests for ``deep_replace``, ``deep_remove`` and ``with_deep_changes``.

    Each test parses a small module, locates a node somewhere in the tree,
    applies one of the deep-modification helpers and compares the resulting
    module's rendered ``code`` against the expected source.
    """

    @staticmethod
    def _innermost_function(module: cst.Module) -> cst.FunctionDef:
        """Return the third-level ``FunctionDef`` of a ``def a(): def b(): def c():`` module.

        Walks ``module.body[0]`` -> first statement of its indented body ->
        first statement of that body, checking each step with
        ``cst.ensure_type`` so a fixture of the wrong shape fails loudly.
        """
        outer_fun = cst.ensure_type(module.body[0], cst.FunctionDef)
        middle_fun = cst.ensure_type(
            cst.ensure_type(outer_fun.body, cst.IndentedBlock).body[0], cst.FunctionDef
        )
        return cst.ensure_type(
            cst.ensure_type(middle_fun.body, cst.IndentedBlock).body[0], cst.FunctionDef
        )

    def test_deep_replace_simple(self) -> None:
        """Replacing a top-level ``pass`` with ``Break`` renders as ``break``."""
        old_code = """
            pass
        """
        new_code = """
            break
        """

        module = cst.parse_module(dedent(old_code))
        pass_stmt = cst.ensure_type(module.body[0], cst.SimpleStatementLine).body[0]
        new_module = cst.ensure_type(
            module.deep_replace(pass_stmt, cst.Break()), cst.Module
        )
        self.assertEqual(new_module.code, dedent(new_code))

    def test_deep_replace_complex(self) -> None:
        """Replacing a nested function with a freshly built one.

        The replacement uses a ``SimpleStatementSuite`` body, so it renders on
        a single line (``def d(): break``) at the original nesting depth.
        """
        old_code = """
            def a():
                def b():
                    def c():
                        pass
        """
        new_code = """
            def a():
                def b():
                    def d(): break
        """

        module = cst.parse_module(dedent(old_code))
        inner_fun = self._innermost_function(module)
        new_module = cst.ensure_type(
            module.deep_replace(
                inner_fun,
                cst.FunctionDef(
                    name=cst.Name("d"),
                    params=cst.Parameters(),
                    body=cst.SimpleStatementSuite(body=(cst.Break(),)),
                ),
            ),
            cst.Module,
        )
        self.assertEqual(new_module.code, dedent(new_code))

    def test_deep_replace_identity(self) -> None:
        """Passing the root module itself as ``old_node`` swaps in the new module."""
        old_code = """
            pass
        """
        new_code = """
            break
        """

        module = cst.parse_module(dedent(old_code))
        # Narrow the result to Module, matching the other tests, before
        # reading ``.code``.
        new_module = cst.ensure_type(
            module.deep_replace(
                module,
                cst.Module(
                    # The leading EmptyLine reproduces the newline that
                    # dedent() leaves at the start of ``new_code``.
                    header=(cst.EmptyLine(),),
                    body=(cst.SimpleStatementLine(body=(cst.Break(),)),),
                ),
            ),
            cst.Module,
        )
        self.assertEqual(new_module.code, dedent(new_code))

    def test_deep_remove_complex(self) -> None:
        """Removing the only statement of a block leaves ``pass`` in its place."""
        old_code = """
            def a():
                def b():
                    def c():
                        print("Hello, world!")
        """
        new_code = """
            def a():
                def b():
                    pass
        """

        module = cst.parse_module(dedent(old_code))
        inner_fun = self._innermost_function(module)
        new_module = cst.ensure_type(module.deep_remove(inner_fun), cst.Module)
        self.assertEqual(new_module.code, dedent(new_code))

    def test_with_deep_changes_complex(self) -> None:
        """Changing one field of a deeply nested node rewrites only that node."""
        old_code = """
            def a():
                def b():
                    def c():
                        print("Hello, world!")
        """
        new_code = """
            def a():
                def b():
                    def c():
                        print("Goodbye, world!")
        """

        class NodeFinder(cst.CSTVisitor):
            # Locates the target node without a hand-written multi-level tree
            # walk; also serves as a small example of a visitor-based node
            # search.
            def __init__(self) -> None:
                super().__init__()
                self.node: Optional[cst.CSTNode] = None

            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                # Keeps the last SimpleString visited; the fixture contains
                # exactly one.
                self.node = node

        module = cst.parse_module(dedent(old_code))
        node_finder = NodeFinder()
        module.visit(node_finder)
        node = node_finder.node
        # A bare assert (not assertIsNotNone) also narrows the Optional for
        # type checkers.
        assert node is not None, "Expected to find a string node!"
        new_module = cst.ensure_type(
            module.with_deep_changes(node, value='"Goodbye, world!"'), cst.Module
        )
        self.assertEqual(new_module.code, dedent(new_code))