File: imports.py

package info (click to toggle)
python-datamodel-code-generator 0.26.4-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 712 kB
  • sloc: python: 9,525; makefile: 14
file content (127 lines) | stat: -rw-r--r-- 5,728 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

from collections import defaultdict
from functools import lru_cache
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple, Union

from datamodel_code_generator.util import BaseModel


class Import(BaseModel):
    """Description of one imported name.

    Fields:
        from_: Module the name is imported from, or ``None`` when there is
            no module part.
        import_: The imported name itself.
        alias: Optional alias for the imported name.
        reference_path: Optional key under which ``Imports`` remembers this
            import so it can be removed again later.
    """

    from_: Optional[str] = None
    import_: str
    alias: Optional[str] = None
    reference_path: Optional[str] = None

    @classmethod
    @lru_cache()
    def from_full_path(cls, class_path: str) -> Import:
        """Create an ``Import`` from a dotted path such as ``'typing.List'``.

        Everything before the last dot becomes ``from_`` (``None`` when that
        part is empty); everything after it becomes ``import_``.  Results are
        memoised by ``lru_cache``, so equal paths yield the same object.
        """
        # rpartition yields ('', '', class_path) when the path has no dot,
        # which maps onto "no module, bare name".
        module_path, _, name = class_path.rpartition('.')
        return Import(from_=module_path or None, import_=name)


class Imports(DefaultDict[Optional[str], Set[str]]):
    """Set of import statements, grouped by the module they come from.

    The mapping itself goes from a module name to the set of names imported
    from it; the key ``None`` holds plain ``import x`` statements.  Besides
    the mapping, each instance keeps:

    * ``alias`` -- per module, a ``{imported name: alias}`` dict.
    * ``counter`` -- how often each ``(module, name)`` pair has been appended
      minus how often it has been removed.
    * ``reference_paths`` -- the most recently appended ``Import`` for each
      ``reference_path``.
    * ``use_exact`` -- stored as passed in; not read inside this class.
    """

    def __str__(self) -> str:
        """Return the rendered import block (same as :meth:`dump`)."""
        return self.dump()

    def __init__(self, use_exact: bool = False) -> None:
        """Create an empty collection whose missing keys default to ``set()``."""
        super().__init__(set)
        self.alias: DefaultDict[Optional[str], Dict[str, str]] = defaultdict(dict)
        self.counter: Dict[Tuple[Optional[str], str], int] = defaultdict(int)
        self.reference_paths: Dict[str, Import] = {}
        self.use_exact: bool = use_exact

    def _set_alias(self, from_: Optional[str], imports: Set[str]) -> List[str]:
        """Return ``imports`` sorted, with ``name as alias`` where applicable.

        An alias is only rendered when one is recorded for ``from_`` and it
        differs from the name itself.
        """
        # ``self.alias`` is a defaultdict: indexing it would insert an empty
        # dict for every module that is merely rendered.  ``.get`` keeps
        # rendering free of side effects.
        aliases = self.alias.get(from_, {})
        return [
            f'{name} as {aliases[name]}'
            if name in aliases and name != aliases[name]
            else name
            for name in sorted(imports)
        ]

    def create_line(self, from_: Optional[str], imports: Set[str]) -> str:
        """Render the statement(s) importing ``imports`` from ``from_``.

        A truthy ``from_`` gives a single ``from ... import a, b`` line;
        otherwise one ``import x`` line is produced per name.
        """
        if from_:
            return f"from {from_} import {', '.join(self._set_alias(from_, imports))}"
        return '\n'.join(f'import {i}' for i in self._set_alias(from_, imports))

    def dump(self) -> str:
        """Render every stored import, one module per line, in insertion order."""
        return '\n'.join(
            self.create_line(from_, imports) for from_, imports in self.items()
        )

    def append(self, imports: Union[Import, Iterable[Import], None]) -> None:
        """Register one import or an iterable of imports.

        Falsy input (``None``, empty collection) is ignored.  A name that
        contains a dot is stored under the ``None`` key, i.e. as a plain
        ``import`` statement, and its alias is not recorded; any other name
        is stored under its ``from_`` module together with its alias.
        """
        if not imports:
            return
        if isinstance(imports, Import):
            imports = [imports]
        for import_ in imports:
            if import_.reference_path:
                self.reference_paths[import_.reference_path] = import_
            if '.' in import_.import_:
                self[None].add(import_.import_)
                self.counter[(None, import_.import_)] += 1
            else:
                self[import_.from_].add(import_.import_)
                self.counter[(import_.from_, import_.import_)] += 1
                if import_.alias:
                    self.alias[import_.from_][import_.import_] = import_.alias

    def remove(  # pragma: no cover
        self, imports: Union[Import, Iterable[Import]]
    ) -> None:
        """Undo one earlier :meth:`append` for each given import.

        The per-name counter is decremented; only when it reaches exactly
        zero is the name dropped from its module's set (and the module key
        deleted once its set is empty).  For non-dotted names carrying an
        alias, the recorded alias is dropped as well.

        Note: removing a name that was never appended still decrements its
        counter, which then becomes negative.
        """
        if isinstance(imports, Import):
            imports = [imports]
        for import_ in imports:
            name = import_.import_
            is_dotted = '.' in name
            # Dotted names live under the ``None`` key, mirroring ``append``.
            from_ = None if is_dotted else import_.from_
            key = (from_, name)

            self.counter[key] -= 1
            if self.counter[key] != 0:
                continue

            names = self[from_]
            # ``discard`` rather than ``remove``: tolerate a set that was
            # modified directly instead of through ``append``.
            names.discard(name)
            if not names:
                del self[from_]

            if is_dotted or not import_.alias:
                continue
            # The alias may never have been recorded for this name; look it
            # up without creating an entry and without raising KeyError.
            aliases = self.alias.get(from_)
            if aliases is not None:
                aliases.pop(name, None)
                if not aliases:
                    del self.alias[from_]

    def remove_referenced_imports(self, reference_path: str) -> None:
        """Remove the import last appended under ``reference_path``, if any."""
        if reference_path in self.reference_paths:
            self.remove(self.reference_paths[reference_path])


# Ready-made ``Import`` objects, each built from a dotted path via
# ``Import.from_full_path`` (module part -> ``from_``, last component ->
# ``import_``).
# Names from ``typing``; the ``*_BACKPORT`` variants point at
# ``typing_extensions`` instead.
IMPORT_ANNOTATED = Import.from_full_path('typing.Annotated')
IMPORT_ANNOTATED_BACKPORT = Import.from_full_path('typing_extensions.Annotated')
IMPORT_ANY = Import.from_full_path('typing.Any')
IMPORT_LIST = Import.from_full_path('typing.List')
IMPORT_SET = Import.from_full_path('typing.Set')
IMPORT_UNION = Import.from_full_path('typing.Union')
IMPORT_OPTIONAL = Import.from_full_path('typing.Optional')
IMPORT_LITERAL = Import.from_full_path('typing.Literal')
IMPORT_TYPE_ALIAS = Import.from_full_path('typing.TypeAlias')
IMPORT_LITERAL_BACKPORT = Import.from_full_path('typing_extensions.Literal')
IMPORT_SEQUENCE = Import.from_full_path('typing.Sequence')
IMPORT_FROZEN_SET = Import.from_full_path('typing.FrozenSet')
IMPORT_MAPPING = Import.from_full_path('typing.Mapping')
# Abstract container types from ``collections.abc``.
IMPORT_ABC_SEQUENCE = Import.from_full_path('collections.abc.Sequence')
IMPORT_ABC_SET = Import.from_full_path('collections.abc.Set')
IMPORT_ABC_MAPPING = Import.from_full_path('collections.abc.Mapping')
IMPORT_ENUM = Import.from_full_path('enum.Enum')
# Resolves to from_='__future__', import_='annotations'.
IMPORT_ANNOTATIONS = Import.from_full_path('__future__.annotations')
IMPORT_DICT = Import.from_full_path('typing.Dict')
# Other standard-library names.
IMPORT_DECIMAL = Import.from_full_path('decimal.Decimal')
IMPORT_DATE = Import.from_full_path('datetime.date')
IMPORT_DATETIME = Import.from_full_path('datetime.datetime')
IMPORT_TIMEDELTA = Import.from_full_path('datetime.timedelta')
IMPORT_PATH = Import.from_full_path('pathlib.Path')
IMPORT_TIME = Import.from_full_path('datetime.time')
IMPORT_UUID = Import.from_full_path('uuid.UUID')
# Names from the ``pendulum`` package.
IMPORT_PENDULUM_DATE = Import.from_full_path('pendulum.Date')
IMPORT_PENDULUM_DATETIME = Import.from_full_path('pendulum.DateTime')
IMPORT_PENDULUM_DURATION = Import.from_full_path('pendulum.Duration')
IMPORT_PENDULUM_TIME = Import.from_full_path('pendulum.Time')