1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Set
from urllib.parse import urlparse
from toposort import toposort_flatten
from xsdata.codegen.mixins import ContainerHandlerInterface
from xsdata.codegen.models import Class, get_location, get_target_namespace
from xsdata.exceptions import CodeGenerationError
from xsdata.models.config import ObjectType, StructureStyle
from xsdata.models.enums import COMMON_SCHEMA_DIR
from xsdata.utils import collections
from xsdata.utils.graphs import strongly_connected_components
from xsdata.utils.namespaces import to_package_name
from xsdata.utils.package import module_name
class DesignateClassPackages(ContainerHandlerInterface):
    """Designate classes to packages and modules based on the output structure
    style."""

    __slots__ = ()

    def run(self):
        """Dispatch to the grouping strategy matching the configured style.

        Any structure style other than namespaces, single package, clusters
        or namespace clusters falls back to grouping by filenames.
        """
        structure_style = self.container.config.output.structure_style
        if structure_style == StructureStyle.NAMESPACES:
            self.group_by_namespace()
        elif structure_style == StructureStyle.SINGLE_PACKAGE:
            self.group_all_together()
        elif structure_style == StructureStyle.CLUSTERS:
            self.group_by_strong_components()
        elif structure_style == StructureStyle.NAMESPACE_CLUSTERS:
            self.group_by_namespace_clusters()
        else:
            self.group_by_filenames()

    def group_by_filenames(self):
        """Group uris by common path and auto assign package names to all
        classes.

        Each class location becomes a module (named via ``module_name``).
        The package is the configured output package, extended with the
        location's parent directory relative to the common path of its group.
        """
        package = self.container.config.output.package
        class_map = collections.group_by(self.container, key=get_location)
        groups = self.group_common_paths(class_map.keys())

        for keys in groups:
            # A lone location has no sibling to compare against, so its own
            # directory acts as the common path and the suffix ends up empty.
            if len(keys) == 1:
                common_path = os.path.dirname(keys[0])
            else:
                common_path = os.path.commonpath(keys)

            for key in keys:
                items = class_map[key]
                suffix = ".".join(Path(key).parent.relative_to(common_path).parts)
                package_name = f"{package}.{suffix}" if suffix else package
                self.assign(items, package_name, module_name(key))

    def group_by_namespace(self):
        """Group classes by their target namespace."""
        groups = collections.group_by(self.container, key=get_target_namespace)
        for namespace, classes in groups.items():
            parts = self.combine_ns_package(namespace)
            # The last part names the module, the rest form the package.
            module = parts.pop()
            package = ".".join(parts)
            self.assign(classes, package, module)

    def group_all_together(self):
        """Group all classes together in the same module."""
        package_parts = self.container.config.output.package.split(".")
        # The last dotted part of the configured package names the module.
        module = package_parts.pop()
        package = ".".join(package_parts)
        self.assign(self.container, package, module)

    def group_by_strong_components(self):
        """Find circular imports and cluster their classes together."""
        package = self.container.config.output.package
        for group in self.strongly_connected_classes():
            classes = self.sorted_classes(group)
            # The module is named after the first class of the sorted group.
            module = classes[0].name
            self.assign(classes, package, module)

    def group_by_namespace_clusters(self):
        """Cluster strongly connected classes under their namespace package.

        Every cluster gets its own module, named after the first class of
        the sorted cluster, inside the package derived from the cluster's
        target namespace.

        Raises:
            CodeGenerationError: If a cluster contains classes from more
                than one target namespace.
        """
        for group in self.strongly_connected_classes():
            classes = self.sorted_classes(group)
            if len(set(map(get_target_namespace, classes))) > 1:
                raise CodeGenerationError(
                    "Found strongly connected classes from different "
                    "namespaces, grouping them is impossible!"
                )

            parts = self.combine_ns_package(classes[0].target_namespace)
            module = classes[0].name
            self.assign(classes, ".".join(parts), module)

    def sorted_classes(self, qnames: Set[str]) -> List[Class]:
        """Return the classes for the given qnames in topological order.

        Only dependencies that belong to ``qnames`` are used as edges for
        the sort.

        Args:
            qnames: The qualified names of the classes to sort.

        Returns:
            The matching classes, as ordered by ``toposort_flatten``.
        """
        # Resolve every qname once and reuse the instance for both the
        # edges and the final result.
        lookup = {qname: self.container.first(qname) for qname in qnames}
        edges = {
            qname: set(obj.dependencies()).intersection(qnames)
            for qname, obj in lookup.items()
        }
        return [lookup[qname] for qname in toposort_flatten(edges)]

    def strongly_connected_classes(self) -> Iterator[Set[str]]:
        """Compute the strongly connected components of the class graph.

        The graph maps every class qname in the container to its unique
        dependencies, requested with ``dependencies(True)``.

        Returns:
            The result of ``strongly_connected_components`` for that graph.
        """
        edges = {obj.qname: list(set(obj.dependencies(True))) for obj in self.container}
        return strongly_connected_components(edges)

    @classmethod
    def assign(cls, classes: Iterable[Class], package: str, module: str):
        """Set the package and module on the classes and their inner classes.

        Args:
            classes: The classes to update; inner classes are updated
                recursively with the same values.
            package: The package name to assign.
            module: The module name to assign.
        """
        for obj in classes:
            obj.package = package
            obj.module = module
            cls.assign(obj.inner, package, module)

    @classmethod
    def group_common_paths(cls, paths: Iterable[str]) -> List[List[str]]:
        """Partition the paths into groups of locations sharing a common path.

        Paths are visited in sorted order. Paths under the common schemas
        directory always land in the first bucket; every other path joins
        the current group while it shares a path prefix with the previously
        visited one, otherwise it opens a new group.

        Args:
            paths: The location strings to group.

        Returns:
            A list of groups, each one a list of paths.
        """
        prev = ""
        index = 0
        groups = defaultdict(list)
        common_schemas_dir = COMMON_SCHEMA_DIR.as_uri()

        for path in sorted(paths):
            if path.startswith(common_schemas_dir):
                # Bucket 0 is reserved for the common schemas; they do not
                # take part in the prefix comparison below.
                groups[0].append(path)
                continue

            path_parsed = urlparse(path)
            try:
                common_path = os.path.commonpath((prev, path))
            except ValueError:
                # commonpath refuses to compare absolute with relative paths
                # (e.g. the initial empty ``prev`` against "/abs/a.xsd") or
                # paths on different drives: nothing is shared in that case.
                common_path = ""

            # NOTE(review): on POSIX two urls that only share their scheme
            # yield "http:" (with the colon) from commonpath, which does not
            # equal ``scheme`` ("http") -- confirm whether the second check
            # is expected to ever match.
            if not common_path or common_path == path_parsed.scheme:
                index += 1

            prev = path
            groups[index].append(path)

        return list(groups.values())

    def combine_ns_package(self, namespace: Optional[str]) -> List[str]:
        """Combine the output package with a package derived from a namespace.

        The first configured package substitution whose search pattern fully
        matches the namespace provides the namespace part; without one the
        result of ``to_package_name`` is used instead.

        Args:
            namespace: The target namespace, may be empty or None.

        Returns:
            The non-empty dotted parts of the combined package name.
        """
        result = self.container.config.output.package.split(".")

        substitution = None
        if namespace:
            substitution = collections.first(
                re.sub(sub.search, sub.replace, namespace)
                for sub in self.container.config.substitutions.substitution
                if sub.type == ObjectType.PACKAGE
                and re.fullmatch(sub.search, namespace) is not None
            )

        if substitution:
            result.extend(substitution.split("."))
        else:
            result.extend(to_package_name(namespace).split("."))

        # Drop empty parts, e.g. from an empty package or namespace name.
        return list(filter(None, result))
|