File: designate_class_packages.py

package info (click to toggle)
python-xsdata 24.1-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,936 kB
  • sloc: python: 29,257; xml: 404; makefile: 27; sh: 6
file content (155 lines) | stat: -rw-r--r-- 6,016 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Set
from urllib.parse import urlparse

from toposort import toposort_flatten

from xsdata.codegen.mixins import ContainerHandlerInterface
from xsdata.codegen.models import Class, get_location, get_target_namespace
from xsdata.exceptions import CodeGenerationError
from xsdata.models.config import ObjectType, StructureStyle
from xsdata.models.enums import COMMON_SCHEMA_DIR
from xsdata.utils import collections
from xsdata.utils.graphs import strongly_connected_components
from xsdata.utils.namespaces import to_package_name
from xsdata.utils.package import module_name


class DesignateClassPackages(ContainerHandlerInterface):
    """Assign a package and a module name to every class of the container.

    Which grouping strategy is applied depends on the structure style of the
    output configuration.
    """

    __slots__ = ()

    def run(self):
        """Invoke the grouping strategy of the configured structure style.

        Styles without a dedicated strategy fall back to grouping by
        filenames.
        """
        style = self.container.config.output.structure_style
        strategies = {
            StructureStyle.NAMESPACES: self.group_by_namespace,
            StructureStyle.SINGLE_PACKAGE: self.group_all_together,
            StructureStyle.CLUSTERS: self.group_by_strong_components,
            StructureStyle.NAMESPACE_CLUSTERS: self.group_by_namespace_clusters,
        }
        strategy = strategies.get(style, self.group_by_filenames)
        strategy()

    def group_by_filenames(self):
        """Derive packages and modules from the locations of the classes.

        Locations are first split into groups of related paths. Inside a
        group, the parent directory of each location relative to the root
        of the group is appended, dot separated, to the configured package
        and the module name is derived from the location itself.
        """
        base_package = self.container.config.output.package
        classes_by_location = collections.group_by(self.container, key=get_location)

        for locations in self.group_common_paths(classes_by_location.keys()):
            # A lonely location is rooted at its own directory, otherwise
            # the root is whatever path all the group members share.
            if len(locations) != 1:
                root = os.path.commonpath(locations)
            else:
                root = os.path.dirname(locations[0])

            for location in locations:
                located_classes = classes_by_location[location]
                relative_parts = Path(location).parent.relative_to(root).parts
                suffix = ".".join(relative_parts)

                if suffix:
                    package_name = f"{base_package}.{suffix}"
                else:
                    package_name = base_package

                self.assign(located_classes, package_name, module_name(location))

    def group_by_namespace(self):
        """Put the classes of each target namespace in their own module.

        The last part of the combined namespace package becomes the module
        name, the parts before it form the package name.
        """
        by_namespace = collections.group_by(self.container, key=get_target_namespace)
        for namespace, namespace_classes in by_namespace.items():
            name_parts = self.combine_ns_package(namespace)
            module_part = name_parts.pop()
            self.assign(namespace_classes, ".".join(name_parts), module_part)

    def group_all_together(self):
        """Put every class in one module.

        The configured package is split on its last dot: the trailing part
        is the module name and everything before it is the package name.
        """
        configured = self.container.config.output.package
        package, _, module = configured.rpartition(".")

        self.assign(self.container, package, module)

    def group_by_strong_components(self):
        """Put every strongly connected group of classes in its own module.

        The module is named after the first class of the sorted group and
        all modules live in the configured package.
        """
        package = self.container.config.output.package
        for component in self.strongly_connected_classes():
            ordered = self.sorted_classes(component)
            self.assign(ordered, package, ordered[0].name)

    def group_by_namespace_clusters(self):
        """Put every strongly connected group in a namespace based package.

        The package is derived from the target namespace of the first class
        of the sorted group and the module is named after that class.

        Raises:
            CodeGenerationError: If a strongly connected group contains
                classes with different target namespaces.
        """
        for component in self.strongly_connected_classes():
            ordered = self.sorted_classes(component)
            namespaces = {get_target_namespace(obj) for obj in ordered}
            if len(namespaces) > 1:
                raise CodeGenerationError(
                    "Found strongly connected classes from different "
                    "namespaces, grouping them is impossible!"
                )

            leader = ordered[0]
            package_parts = self.combine_ns_package(leader.target_namespace)
            module = leader.name
            self.assign(ordered, ".".join(package_parts), module)

    def sorted_classes(self, qnames: Set[str]) -> List[Class]:
        """Return the classes of the given qnames in topological order.

        Only the dependencies between the given qnames are taken into
        account for the ordering.
        """
        find = self.container.first
        graph = {}
        for qname in qnames:
            dependencies = set(find(qname).dependencies())
            graph[qname] = dependencies.intersection(qnames)

        return list(map(find, toposort_flatten(graph)))

    def strongly_connected_classes(self) -> Iterator[Set[str]]:
        """Compute the strongly connected components of the class graph.

        The graph maps the qname of every class in the container to the
        unique values of its `dependencies(True)` call.
        """
        graph = {
            candidate.qname: list(set(candidate.dependencies(True)))
            for candidate in self.container
        }
        return strongly_connected_components(graph)

    @classmethod
    def assign(cls, classes: Iterable[Class], package: str, module: str):
        """Set the package and module on the classes and their inner classes.

        The inner classes are processed recursively and always receive the
        same package and module as their outer class.
        """
        for target in classes:
            target.package, target.module = package, module
            cls.assign(target.inner, package, module)

    @classmethod
    def group_common_paths(cls, paths: Iterable[str]) -> List[List[str]]:
        """Split the paths into lists of related paths.

        The paths are visited in sorted order. Paths starting with the uri
        of the common schemas directory are collected in one dedicated
        group. Any other path opens a new group when it has no common path
        with the previously visited one, or when the common path is equal
        to the url scheme of the path; otherwise it joins the current group.
        """
        buckets = defaultdict(list)
        shared_schemas_uri = COMMON_SCHEMA_DIR.as_uri()
        previous = ""
        current = 0

        for path in sorted(paths):
            if path.startswith(shared_schemas_uri):
                buckets[0].append(path)
                continue

            scheme = urlparse(path).scheme
            shared = os.path.commonpath((previous, path))
            if not shared or shared == scheme:
                current += 1

            previous = path
            buckets[current].append(path)

        return list(buckets.values())

    def combine_ns_package(self, namespace: Optional[str]) -> List[str]:
        """Build the package parts for a namespace.

        The parts of the configured package come first. They are followed
        by the parts of the first package substitution whose search pattern
        fully matches the namespace or, when there is no usable
        substitution, by the parts of the `to_package_name` result for the
        namespace. Empty parts are dropped from the result.
        """
        parts = self.container.config.output.package.split(".")

        replacement = None
        if namespace:
            # NOTE(review): collections.first presumably returns None when
            # no substitution matches -- confirm against xsdata.utils.
            replacement = collections.first(
                re.sub(rule.search, rule.replace, namespace)
                for rule in self.container.config.substitutions.substitution
                if rule.type == ObjectType.PACKAGE
                and re.fullmatch(rule.search, namespace) is not None
            )

        if not replacement:
            replacement = to_package_name(namespace)

        parts.extend(replacement.split("."))
        return [part for part in parts if part]