File: convert_namedtuple_to_dataclass.py

package info (click to toggle)
python-libcst 1.8.6-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,240 kB
  • sloc: python: 78,096; makefile: 15; sh: 2
file content (75 lines) | stat: -rw-r--r-- 2,808 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import List, Optional, Sequence

import libcst as cst
from libcst.codemod import VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.metadata import (
    ProviderT,
    QualifiedName,
    QualifiedNameProvider,
    QualifiedNameSource,
)


class ConvertNamedTupleToDataclassCommand(VisitorBasedCodemodCommand):
    """
    Convert NamedTuple class declarations to Python 3.7 dataclasses.

    Every class that has ``typing.NamedTuple`` among its bases loses that base
    and gains a ``@dataclass(frozen=True)`` decorator (appended after any
    existing decorators). An import of ``dataclasses.dataclass`` is requested,
    and removal of the ``NamedTuple`` import is requested if it becomes unused.

    This only performs a conversion at the class declaration level.
    It does not perform type annotation conversions, nor does it convert
    NamedTuple-specific attributes and methods.
    """

    DESCRIPTION: str = (
        "Convert NamedTuple class declarations to Python 3.7 dataclasses using the @dataclass decorator."
    )
    # Qualified-name metadata lets bases be matched against typing.NamedTuple
    # by what they resolve to rather than by their literal spelling.
    METADATA_DEPENDENCIES: Sequence[ProviderT] = (QualifiedNameProvider,)

    # The 'NamedTuple' we are interested in
    qualified_namedtuple: QualifiedName = QualifiedName(
        name="typing.NamedTuple", source=QualifiedNameSource.IMPORT
    )

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """
        Replace a ``typing.NamedTuple`` base with ``@dataclass(frozen=True)``.

        :param original_node: the class as it appears in the parsed tree; the
            only node that carries the qualified-name metadata.
        :param updated_node: the class with any changes to its children applied.
        :returns: ``updated_node`` unchanged when no base is ``typing.NamedTuple``,
            otherwise ``updated_node`` with that base removed and the decorator
            appended.
        """
        original_bases = original_node.bases
        # Metadata is attached to the original tree, so the NamedTuple check must
        # run on ``original_node.bases``. When the updated tree has the same
        # number of bases, emit its versions so changes to a base are not lost;
        # otherwise positions cannot be paired up and we fall back to the
        # original base nodes.
        emitted_bases: Sequence[cst.Arg] = (
            updated_node.bases
            if len(updated_node.bases) == len(original_bases)
            else original_bases
        )

        new_bases: List[cst.Arg] = []
        namedtuple_base: Optional[cst.Arg] = None
        last_base_removed = False
        for original_base, emitted_base in zip(original_bases, emitted_bases):
            # Compare the base class's qualified name against typing.NamedTuple
            is_namedtuple = QualifiedNameProvider.has_name(
                self, original_base.value, self.qualified_namedtuple
            )
            if is_namedtuple:
                namedtuple_base = original_base
            else:
                # Keep all bases that are not typing.NamedTuple
                new_bases.append(emitted_base)
            # After the loop this tells whether the final base was dropped.
            last_base_removed = is_namedtuple

        # We still want to return the updated node in case some of its children have been modified
        if namedtuple_base is None:
            return updated_node

        # Dropping the final base would leave the new last base with its own
        # separator, e.g. ``class Foo(Base, ):``. Give it the comma of the base
        # that used to end the list so the original formatting is kept.
        if last_base_removed and new_bases:
            new_bases[-1] = new_bases[-1].with_changes(comma=emitted_bases[-1].comma)

        AddImportsVisitor.add_needed_import(self.context, "dataclasses", "dataclass")
        # Must be given the original node: the removal check relies on metadata.
        RemoveImportsVisitor.remove_unused_import_by_node(
            self.context, namedtuple_base.value
        )

        call = cst.ensure_type(
            cst.parse_expression(
                "dataclass(frozen=True)", config=self.module.config_for_parsing
            ),
            cst.Call,
        )
        return updated_node.with_changes(
            # Reset parentheses so they are emitted only if bases/keywords remain.
            lpar=cst.MaybeSentinel.DEFAULT,
            rpar=cst.MaybeSentinel.DEFAULT,
            bases=new_bases,
            # Use the updated decorators so changes made to them are preserved.
            decorators=[*updated_node.decorators, cst.Decorator(decorator=call)],
        )