File: test_gather_exports.py

package info (click to toggle)
python-libcst 1.4.0-1.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,928 kB
  • sloc: python: 76,235; makefile: 10; sh: 2
file content (150 lines) | stat: -rw-r--r-- 4,417 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from libcst import parse_module
from libcst.codemod import CodemodContext, CodemodTest
from libcst.codemod.visitors import GatherExportsVisitor
from libcst.testing.utils import UnitTest


class TestGatherExportsVisitor(UnitTest):
    """Tests for how ``GatherExportsVisitor`` collects names listed in ``__all__``.

    Each test feeds a small module snippet through the visitor. It then compares
    the visitor's ``explicit_exported_objects`` with the expected set of names.
    """

    def gather_exports(self, code: str) -> GatherExportsVisitor:
        """Run a fresh ``GatherExportsVisitor`` over ``code`` and return it.

        ``code`` is an indented triple-quoted snippet. It is passed through
        ``CodemodTest.make_fixture_data`` before being parsed into a module.
        """
        visitor = GatherExportsVisitor(CodemodContext())
        module = parse_module(CodemodTest.make_fixture_data(code))
        module.visit(visitor)
        return visitor

    def _assert_exports(self, code: str, expected: "set[str]") -> None:
        """Assert that gathering exports from ``code`` yields exactly ``expected``."""
        exported = self.gather_exports(code).explicit_exported_objects
        self.assertEqual(exported, expected)

    def test_gather_noop(self) -> None:
        """Collections bound to names other than ``__all__`` export nothing."""
        snippet = """
            from foo import bar

            from typing import List

            bar(["foo", "bar"])

            list_of_str = ["foo", "bar", "baz"]

            set_of_str = {"foo", "bar", "baz"}

            tuple_of_str = ("foo", "bar", "baz")

            another: List[str] = ["foobar", "foobarbaz"]
        """
        self._assert_exports(snippet, set())

    def test_gather_exports_simple(self) -> None:
        """A list literal assigned to ``__all__`` is collected."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__ = ["bar", "baz"]
        """
        self._assert_exports(snippet, {"bar", "baz"})

    def test_gather_exports_simple2(self) -> None:
        """Names added with ``__all__ +=`` are merged with the initial list."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__ = ["bar"]
            __all__ += ["baz"]
        """
        self._assert_exports(snippet, {"bar", "baz"})

    def test_gather_exports_simple_set(self) -> None:
        """A set literal assigned to ``__all__`` is collected."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__ = {"bar", "baz"}
        """
        self._assert_exports(snippet, {"bar", "baz"})

    def test_gather_exports_simple_tuple(self) -> None:
        """A tuple literal assigned to ``__all__`` is collected."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__ = ("bar", "baz")
        """
        self._assert_exports(snippet, {"bar", "baz"})

    def test_gather_exports_simple_annotated(self) -> None:
        """An annotated ``__all__: List[str] = [...]`` assignment is collected."""
        snippet = """
            from foo import bar
            from biz import baz

            from typing import List

            __all__: List[str] = ["bar", "baz"]
        """
        self._assert_exports(snippet, {"bar", "baz"})

    def test_gather_exports_ignore_invalid_1(self) -> None:
        """Bare names (not string literals) inside ``__all__`` are not collected."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__ = [bar, baz]
        """
        self._assert_exports(snippet, set())

    def test_gather_exports_ignore_invalid_2(self) -> None:
        """A nested list inside ``__all__`` is skipped; its string siblings are kept."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__ = ["bar", "baz", ["biz"]]
        """
        self._assert_exports(snippet, {"bar", "baz"})

    def test_gather_exports_ignore_valid_1(self) -> None:
        """Implicitly concatenated string literals count as one joined name."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__ = ["bar", "b""a""z"]
        """
        self._assert_exports(snippet, {"bar", "baz"})

    def test_gather_exports_ignore_valid_2(self) -> None:
        """In a tuple-unpacking assignment, only the value paired with ``__all__`` counts."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__, _ = ["bar", "baz"], ["biz"]
        """
        self._assert_exports(snippet, {"bar", "baz"})

    def test_gather_exports_ignore_valid_3(self) -> None:
        """A chained assignment ``__all__ = exported = [...]`` is still collected."""
        snippet = """
            from foo import bar
            from biz import baz

            __all__ = exported = ["bar", "baz"]
        """
        self._assert_exports(snippet, {"bar", "baz"})