1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
from typing import List, Optional
from xsdata.codegen.container import ClassContainer
from xsdata.codegen.models import Attr, Class, Extension, get_tag
from xsdata.codegen.utils import ClassUtils
from xsdata.logger import logger
from xsdata.models.enums import Tag
from xsdata.utils import collections
from xsdata.utils.collections import group_by
class ClassValidator:
    """Validate the classes held by a container and resolve name clashes.

    The container maps each qualified name to a list of classes. Whenever a
    list holds more than one class, the validator prunes or merges entries
    so that, where possible, a single definition per name survives.
    """

    __slots__ = "container"

    def __init__(self, container: ClassContainer):
        self.container = container

    def process(self):
        """
        Reduce every same-qname group of classes held by the container.

        Each group is run through the passes below, in order. A pass is only
        attempted while the group still contains more than one class:

        1. Drop classes that extend unknown, non-native types
        2. Resolve duplicates that derive from the same xs tag
        3. Merge an xs:element into its matching xs:complexType
        """
        passes = (
            self.remove_invalid_classes,
            self.handle_duplicate_types,
            self.merge_global_types,
        )
        for group in self.container.data.values():
            for step in passes:
                # Once a single class (or none) remains there is nothing
                # left to reconcile, so the remaining passes are skipped.
                if len(group) < 2:
                    break
                step(group)

    def remove_invalid_classes(self, classes: List[Class]):
        """Drop every class whose extensions point to a missing type.

        An extension counts as missing when its type is not native and its
        qualified name has no entry in the container.

        :param classes: classes sharing one qualified name, edited in place
        """

        def references_missing_type(ext: Extension) -> bool:
            """Tell whether the extension's non-native type is unknown."""
            if ext.type.native:
                return False
            return ext.type.qname not in self.container.data

        rejected = [
            candidate
            for candidate in classes
            if any(references_missing_type(ext) for ext in candidate.extensions)
        ]
        for candidate in rejected:
            classes.remove(candidate)

    @classmethod
    def handle_duplicate_types(cls, classes: List[Class]):
        """Keep one class per xs tag among classes sharing a qualified name.

        Within each tag group, the survivor is picked by ``select_winner``.
        When none qualifies, the last class is kept and a warning is logged.
        The other classes of the group are removed from ``classes``. If the
        survivor comes from an xs:redefine container, each removed class is
        first merged into it.

        :param classes: classes sharing one qualified name, edited in place
        """
        for group in group_by(classes, get_tag).values():
            if len(group) == 1:
                continue

            chosen = cls.select_winner(list(group))
            if chosen == -1:
                logger.warning(
                    "Duplicate type %s, will keep the last defined",
                    group[0].qname,
                )

            # pop(-1) takes the last entry, matching the warning above.
            survivor = group.pop(chosen)
            for loser in group:
                classes.remove(loser)
                if survivor.container == Tag.REDEFINE:
                    cls.merge_redefined_type(loser, survivor)

    @classmethod
    def merge_redefined_type(cls, source: Class, target: Class):
        """
        Carry members of an original definition over to its redefinition.

        In a redefinition, inheriting from the original type is optional.
        This method therefore looks at ``target`` for two kinds of
        self-reference:

        - an extension that points back to the target's own name, which
          leads to copying attributes and extensions from ``source``;
        - an attribute named after the target, which leads to copying
          group attributes from ``source``.

        The copying itself is delegated to ``ClassUtils``.
        """
        # Both lookups run before any copy, because copying can add new
        # attributes or extensions to target.
        self_extension = cls.find_circular_extension(target)
        self_group = cls.find_circular_group(target)

        if self_extension:
            ClassUtils.copy_attributes(source, target, self_extension)
            ClassUtils.copy_extensions(source, target, self_extension)

        if self_group:
            ClassUtils.copy_group_attributes(source, target, self_group)

    @classmethod
    def select_winner(cls, candidates: List[Class]) -> int:
        """
        Return the index of the candidate that should survive deduplication.

        The first candidate that was extracted from an xs:override or
        xs:redefine container wins. When no candidate qualifies, -1 is
        returned, which callers treat as "keep the last one".
        """
        preferred = (Tag.OVERRIDE, Tag.REDEFINE)
        return next(
            (
                position
                for position, candidate in enumerate(candidates)
                if candidate.container in preferred
            ),
            -1,
        )

    @classmethod
    def find_circular_extension(cls, target: Class) -> Optional[Extension]:
        """Return the first extension of ``target`` naming ``target`` itself.

        Only local names are compared. Returns None when no extension
        matches.
        """
        return next(
            (ext for ext in target.extensions if ext.type.name == target.name),
            None,
        )

    @classmethod
    def find_circular_group(cls, target: Class) -> Optional[Attr]:
        """Return the attribute of ``target`` that is named after ``target``.

        The lookup is delegated to ``ClassUtils.find_attr``. Presumably it
        returns None when there is no such attribute (not verified here).
        """
        own_name = target.name
        return ClassUtils.find_attr(target, own_name)

    @classmethod
    def merge_global_types(cls, classes: List[Class]):
        """
        Fold a global xs:element into the xs:complexType it wraps.

        The merge only happens when all of the following hold:

        1. ``classes`` contains a class with the xs:element tag
        2. ``classes`` contains a class with the xs:complexType tag
        3. the element has exactly one extension, and that extension points
           to the element's own qualified name (the complex type)
        4. the element declares no attributes

        When they hold, the complex type takes the element's namespace and
        help text (each only if set) and its substitutions. The element is
        then removed from ``classes``.
        """
        element = collections.first(x for x in classes if x.tag == Tag.ELEMENT)
        complex_type = collections.first(
            x for x in classes if x.tag == Tag.COMPLEX_TYPE
        )

        if element is None or complex_type is None or element is complex_type:
            return

        if element.attrs or len(element.extensions) != 1:
            return

        if element.extensions[0].type.qname != element.qname:
            return

        complex_type.namespace = element.namespace or complex_type.namespace
        complex_type.help = element.help or complex_type.help
        complex_type.substitutions = element.substitutions
        classes.remove(element)
|