File: resolver.py

package info (click to toggle)
python-xsdata 24.1-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,936 kB
  • sloc: python: 29,257; xml: 404; makefile: 27; sh: 6
file content (131 lines) | stat: -rw-r--r-- 4,798 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import logging
import re
from typing import Dict, List

from toposort import toposort_flatten

from xsdata.codegen.models import Class, Import, get_slug
from xsdata.exceptions import ResolverValueError
from xsdata.utils import collections

logger = logging.getLogger(__name__)


class DependenciesResolver:
    """Resolve the generation order and the imports of a list of classes.

    The resolver topologically sorts the given classes by their
    dependencies, works out which dependencies are not part of the given
    classes and therefore need to be imported, and assigns import aliases
    where an imported name would collide with a local class or with
    another import.

    Args:
        packages: Mapping of qualified class name to the package/module
            that defines it; used as the ``source`` of every import.
    """

    # NOTE(review): the ``package`` slot is declared but never assigned in
    # this class; presumably kept for backward compatibility — confirm
    # before removing it.
    __slots__ = "packages", "aliases", "imports", "class_list", "class_map", "package"

    def __init__(self, packages: Dict[str, str]):
        self.packages = packages

        # Import qualified name -> alias, rebuilt by every ``process`` call.
        self.aliases: Dict[str, str] = {}
        # Imports required by the most recently processed classes.
        self.imports: List[Import] = []
        # Qualified names of the processed classes and of every dependency
        # they reference, in the order produced by ``toposort_flatten``.
        self.class_list: List[str] = []
        # The processed (local) classes indexed by qualified name.
        self.class_map: Dict[str, Class] = {}

    def process(self, classes: List[Class]):
        """Resolve the dependencies of the given classes.

        Clear the aliases and imports left over from any previous run,
        index and topologically sort the classes, then compute the imports
        (and their aliases) needed for every dependency that is not one of
        the given classes.

        Args:
            classes: The classes to resolve, presumably all targeting the
                same output module.

        Raises:
            ResolverValueError: If two classes share a qualified name, or
                a dependency has no entry in ``packages``.
        """
        self.imports.clear()
        self.aliases.clear()
        self.class_map = self.create_class_map(classes)
        # NOTE(review): toposort_flatten is expected to raise if the
        # dependency graph contains a cycle — confirm callers handle it.
        self.class_list = self.create_class_list(classes)
        self.resolve_imports()

    def sorted_imports(self) -> List[Import]:
        """Return a new list of the resolved imports, sorted by name."""
        return sorted(self.imports, key=lambda x: x.name)

    def sorted_classes(self) -> List[Class]:
        """Return the local classes in generation order with aliases applied.

        Walk the topologically sorted class list, skip the names that refer
        to imported (non-local) classes and apply the import aliases to the
        remaining ones. The aliases are set in place on the class types.
        """
        result = []
        for name in self.class_list:
            obj = self.class_map.get(name)
            if obj is not None:
                self.apply_aliases(obj)
                result.append(obj)
        return result

    def apply_aliases(self, target: Class):
        """Set the import alias on every type referenced by the target class.

        Attribute types, choice types and extension types are updated;
        types without an alias get ``None``, which overwrites any alias left
        from a previous run. Inner classes are handled recursively.
        """
        for attr in target.attrs:
            for attr_type in attr.types:
                attr_type.alias = self.aliases.get(attr_type.qname)

            for choice in attr.choices:
                for choice_type in choice.types:
                    choice_type.alias = self.aliases.get(choice_type.qname)

        for ext in target.extensions:
            ext.type.alias = self.aliases.get(ext.type.qname)

        collections.apply(target.inner, self.apply_aliases)

    def resolve_imports(self):
        """Build the imports for all non-local dependencies.

        Every import gets its source package from ``find_package``; naming
        collisions with the local classes (compared by slug) or between the
        imports themselves are resolved with aliases, which are then
        indexed in ``self.aliases``.

        Raises:
            ResolverValueError: If a dependency has no known package.
        """
        self.imports = [
            Import(qname=qname, source=self.find_package(qname))
            for qname in self.import_classes()
        ]
        # Slugs of the local classes: an import must not shadow any of them.
        protected = {obj.slug for obj in self.class_map.values()}
        self.resolve_conflicts(self.imports, protected)
        self.set_aliases()

    def set_aliases(self):
        """Index the aliases of the current imports by qualified name.

        Imports without an alias are left out.
        """
        self.aliases = {imp.qname: imp.alias for imp in self.imports if imp.alias}

    @classmethod
    def resolve_conflicts(cls, imports: List[Import], protected: set):
        """Assign aliases to imports whose names would collide.

        Imports are grouped by slug (``get_slug``):

        - A group with a single import needs an alias only when its slug is
          protected (i.e. clashes with a local class). The alias prefix is
          the last dot-separated segment of the import source.
        - A group with several imports gets an alias for every member. The
          prefix joins, with ``_``, the ``_``/``.`` separated source parts
          that are absent from a neighbouring member's source: the first
          member is compared with the second, every other member with its
          predecessor.

        Aliases have the form ``"<prefix>:<name>"`` and are set in place.

        Args:
            imports: The imports to check.
            protected: Slugs that an import must not clash with.
        """
        for slug, group in collections.group_by(imports, key=get_slug).items():
            if len(group) == 1:
                if slug in protected:
                    imp = group[0]
                    module = imp.source.split(".")[-1]
                    imp.alias = f"{module}:{imp.name}"
                continue

            for index, cur in enumerate(group):
                # Pick a neighbour to diff against; groups here have >= 2
                # members, so ``group[1]`` always exists.
                cmp = group[index + 1] if index == 0 else group[index - 1]
                parts = re.split("[_.]", cur.source)
                diff = set(parts) - set(re.split("[_.]", cmp.source))

                # Keep the original part order so the prefix is stable.
                # NOTE(review): if both sources split into the same parts,
                # the prefix is empty and the alias becomes ":<name>".
                add = "_".join(part for part in parts if part in diff)
                cur.alias = f"{add}:{cur.name}"

    def find_package(self, qname: str) -> str:
        """
        Return the package name for the given qualified class name.

        :raises ResolverValueError: if name doesn't exist.
        """
        try:
            return self.packages[qname]
        except KeyError:
            # Suppress the KeyError context: the domain error is the
            # whole story for callers.
            raise ResolverValueError(f"Unknown dependency: {qname}") from None

    def import_classes(self) -> List[str]:
        """Return the qualified names of the dependencies to import.

        These are the entries of the sorted class list that are not local
        classes; their topological order is preserved.
        """
        return [qname for qname in self.class_list if qname not in self.class_map]

    @staticmethod
    def create_class_list(classes: List[Class]) -> List[str]:
        """Topologically sort the classes by their dependencies.

        Returns a flat list of qualified names that contains the given
        classes as well as every dependency they reference.
        """
        return toposort_flatten({obj.qname: set(obj.dependencies()) for obj in classes})

    @staticmethod
    def create_class_map(classes: List[Class]) -> Dict[str, Class]:
        """Index the list of classes by qualified name.

        Raises:
            ResolverValueError: If two classes share the same qualified name.
        """
        result: Dict[str, Class] = {}
        for obj in classes:
            if obj.qname in result:
                raise ResolverValueError(f"Duplicate class: `{obj.name}`")
            result[obj.qname] = obj

        return result