File: flatten_class_extensions.py

package info (click to toggle)
python-xsdata 24.1-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,936 kB
  • sloc: python: 29,257; xml: 404; makefile: 27; sh: 6
file content (294 lines) | stat: -rw-r--r-- 11,123 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
from typing import Optional

from xsdata.codegen.mixins import RelativeHandlerInterface
from xsdata.codegen.models import Attr, AttrType, Class, Extension
from xsdata.codegen.utils import ClassUtils
from xsdata.logger import logger
from xsdata.models.enums import DataType, NamespaceType, Tag
from xsdata.utils.constants import DEFAULT_ATTR_NAME


class FlattenClassExtensions(RelativeHandlerInterface):
    """Reduce class extensions by copying or creating new attributes."""

    __slots__ = ()

    def process(self, target: Class):
        """Flatten every extension of the target class.

        The extensions are visited in their current order. A snapshot
        of the list is iterated because the individual handlers may
        remove entries from ``target.extensions`` while the loop runs.
        """
        snapshot = tuple(target.extensions)
        for ext in snapshot:
            self.process_extension(target, ext)

    def process_extension(self, target: Class, extension: Extension):
        """Split the processing of an extension by the kind of its type.

        Extensions of native schema data types and extensions of user
        defined types are handled by two separate methods.
        """
        if not extension.type.native:
            self.process_dependency_extension(target, extension)
            return

        self.process_native_extension(target, extension)

    @classmethod
    def process_native_extension(cls, target: Class, extension: Extension):
        """
        Flatten an extension whose type is a native data type.

        Enumeration targets are handled by ``replace_attributes_type``,
        which applies the native type to all the enum members; every
        other target is handled by ``add_default_attribute``, which
        adds a default value attribute built from the extension.
        """
        handler = (
            cls.replace_attributes_type
            if target.is_enumeration
            else cls.add_default_attribute
        )
        handler(target, extension)

    def process_dependency_extension(self, target: Class, extension: Extension):
        """
        Flatten an extension whose type is a user defined class.

        The extension type is resolved with ``find_dependency``:
            1. Nothing found: log a warning and drop the extension.
            2. Target is an enumeration: use the enum handler.
            3. Source is not complex or is an enumeration: use the
               simple handler.
            4. Anything else: use the complex handler.
        """
        source = self.find_dependency(extension.type)
        if not source:
            logger.warning("Missing extension type: %s", extension.type.name)
            target.extensions.remove(extension)
            return

        if target.is_enumeration:
            self.process_enum_extension(source, target, extension)
            return

        if source.is_complex and not source.is_enumeration:
            self.process_complex_extension(source, target, extension)
        else:
            self.process_simple_extension(source, target, extension)

    def process_enum_extension(
        self, source: Class, target: Class, ext: Optional[Extension]
    ):
        """
        Flatten the extension of an enumeration target class.

        Cases:
            1. Source is an enumeration: merge the members.
            2. Source is a simple type: copy the source attr types.
            3. Source is anything else:
                3.1 Target has exactly one member: restrict the
                    source value attr to that member's default value.
                3.2 Otherwise: unsupported, clear the target members.

        When an extension is given, it is removed afterwards only if
        the target still reports itself as an enumeration.
        """
        if source.is_enumeration:
            self.merge_enumerations(source, target)
        elif source.is_simple_type:
            self.merge_enumeration_types(source, target)
        else:
            members = target.attrs
            if len(members) == 1:
                self.set_default_value(source, target)
            else:
                # Subclassing and overriding the value field of the
                # target enumeration is not an option, mypy doesn't
                # play nicely with it.
                members.clear()

        if not ext:
            return

        if target.is_enumeration:
            target.extensions.remove(ext)

    @classmethod
    def merge_enumerations(cls, source: Class, target: Class):
        """Replace the target members that also exist in the source.

        Every target attr whose name matches a source attr is swapped
        with a clone of that source attr, all the other target attrs
        are kept as they are. The order of the target attrs is kept.
        """
        by_name = {member.name: member for member in source.attrs}
        merged = []
        for member in target.attrs:
            if member.name in by_name:
                merged.append(by_name[member.name].clone())
            else:
                merged.append(member)

        target.attrs = merged

    def merge_enumeration_types(self, source: Class, target: Class):
        """Copy the types of the first source attr to the target members.

        Native types are cloned into every target attr together with
        the restrictions of the source attr. For any other type the
        referenced class is looked up and processed as an enum
        extension of the target, without an extension instance.
        """
        value_attr = source.attrs[0]
        for attr_type in value_attr.types:
            if not attr_type.native:
                base = self.find_dependency(attr_type)
                # It's impossible to have a missing reference now, the
                # source class has passed through AttributeTypeHandler
                # and any missing types have been reset.
                assert base is not None
                self.process_enum_extension(base, target, None)
                continue

            for member in target.attrs:
                member.types.append(attr_type.clone())
                member.restrictions.merge(value_attr.restrictions)

    @classmethod
    def set_default_value(cls, source: Class, target: Class):
        """Restrict the source value attr to the single target member.

        A clone of the source value attr takes the types and the
        default value of the first target attr, is marked as fixed and
        becomes the only attr of the target class.
        """
        restricted = ClassUtils.find_value_attr(source).clone()
        member = target.attrs[0]
        restricted.types = member.types
        restricted.default = member.default
        restricted.fixed = True
        target.attrs = [restricted]

    @classmethod
    def process_simple_extension(cls, source: Class, target: Class, ext: Extension):
        """
        Flatten an extension of a common class eg SimpleType, Restriction.

        Steps:
            1. Target is the source itself: drop the extension.
            2. Only the source is an enumeration: add a default value
               attribute to the target.
            3. Source and target are both enumerations or both not
               enumerations: copy all the source attributes.
            4. Only the target is an enumeration: drop the extension.
        """
        if source is target:
            target.extensions.remove(ext)
            return

        source_is_enum = source.is_enumeration
        target_is_enum = target.is_enumeration

        if source_is_enum and not target_is_enum:
            cls.add_default_attribute(target, ext)
        elif source_is_enum == target_is_enum:
            ClassUtils.copy_attributes(source, target, ext)
        else:
            # Only the target is an enumeration
            target.extensions.remove(ext)

    @classmethod
    def process_complex_extension(cls, source: Class, target: Class, ext: Extension):
        """
        Flatten an extension of a primary class eg ComplexType, Element.

        Outcomes, checked in this order:
            1. ``should_remove_extension`` approves: drop the extension.
            2. ``should_flatten_extension`` approves: copy all the
               source attributes to the target class.
            3. Otherwise: leave the extension in place and store the
               ``id`` of the source class as the type reference.
        """
        if cls.should_remove_extension(source, target, ext):
            target.extensions.remove(ext)
            return

        if cls.should_flatten_extension(source, target):
            ClassUtils.copy_attributes(source, target, ext)
            return

        ext.type.reference = id(source)

    def find_dependency(self, attr_type: AttrType) -> Optional[Class]:
        """
        Look up the class the given type points to, with priority.

        The container is searched by the qualified name of the type,
        once per tag, and the first truthy result is returned.

        Search priority: xs:SimpleType >  xs:ComplexType
        """
        for wanted in (Tag.SIMPLE_TYPE, Tag.COMPLEX_TYPE):
            # Bind the tag as a default to avoid late binding surprises
            match = self.container.find(
                attr_type.qname,
                condition=lambda x, tag=wanted: x.tag == tag,
            )
            if match:
                return match

        return None

    @classmethod
    def should_remove_extension(
        cls, source: Class, target: Class, ext: Extension
    ) -> bool:
        """
        Return whether the extension has to be dropped due to a violation.

        Violations:
            - Circular Reference: source and target are the same class
            - Forward Reference: target is an inner class of the source
            - Unordered sequences, see ``have_unordered_sequences``
            - MRO Violation A(B), C(B) and extensions includes A, B, C:
              source and target share an extension type qname
        """
        is_violation = (
            source is target
            or target in source.inner
            or cls.have_unordered_sequences(source, target, ext)
        )
        if is_violation:
            return True

        # MRO Violation
        target_parents = {parent.type.qname for parent in target.extensions}
        for parent in source.extensions:
            if parent.type.qname in target_parents:
                return True

        return False

    @classmethod
    def should_flatten_extension(cls, source: Class, target: Class) -> bool:
        """
        Return whether the source attrs should be copied to the target.

        The source class must not have any extensions of its own and
        at least one of the following rules has to apply:
            1. Source class is a simple type
            2. Target class has a suffix attr
            3. Source class has a suffix attr and target has its own attrs
        """
        if source.extensions:
            return False

        if source.is_simple_type or target.has_suffix_attr:
            return True

        return bool(source.has_suffix_attr and target.attrs)

    @classmethod
    def have_unordered_sequences(
        cls, source: Class, target: Class, ext: Extension
    ) -> bool:
        """
        Check if the target sequence attrs are out of order in the parent.

        Dataclasses fields ordering follows the python mro pattern, the
        parent fields are always first, and they are updated if the
        subclass is overriding any of them but the overall ordering
        doesn't change!

        The check is skipped, with a negative result, if the extension
        tag is ``Tag.EXTENSION`` or the source has its own extensions.
        Otherwise the names of the non prohibited target attrs with a
        sequence restriction are compared against the source attrs
        with the same names, when there are at least two of them.

        @todo This needs a complete rewrite and most likely it needs to
        @todo move way down in the process chain.
        """
        if ext.tag == Tag.EXTENSION or source.extensions:
            return False

        target_names = []
        for attr in target.attrs:
            if attr.restrictions.sequence is None or attr.is_prohibited:
                continue

            target_names.append(attr.name)

        if len(target_names) < 2:
            return False

        source_names = [
            attr.name for attr in source.attrs if attr.name in target_names
        ]
        return bool(source_names) and source_names != target_names

    @classmethod
    def replace_attributes_type(cls, target: Class, extension: Extension):
        """Set the extension type as the only type of all target attrs.

        Every attr receives its own clone of the extension type, and
        the extension is removed from the target class afterwards.
        """
        new_type = extension.type
        for member in target.attrs:
            member_types = member.types
            member_types.clear()
            member_types.append(new_type.clone())

        target.extensions.remove(extension)

    @classmethod
    def add_default_attribute(cls, target: Class, extension: Extension):
        """Turn the extension into a default value attr of the target.

        Extensions of the ``DataType.ANY_TYPE`` datatype use the
        ``@any_element`` wildcard attr with the any namespace type,
        all the others use the default attr name without a namespace.
        The attr is looked up or created, receives a clone of the
        extension type and the extension restrictions, and finally
        the extension is removed from the target class.
        """
        if extension.type.datatype == DataType.ANY_TYPE:
            name, tag, namespace = "@any_element", Tag.ANY, NamespaceType.ANY_NS
        else:
            name, tag, namespace = DEFAULT_ATTR_NAME, Tag.EXTENSION, None

        value_attr = cls.get_or_create_attribute(target, name, tag)
        value_attr.types.append(extension.type.clone())
        value_attr.restrictions.merge(extension.restrictions)
        value_attr.namespace = namespace
        target.extensions.remove(extension)

    @classmethod
    def get_or_create_attribute(cls, target: Class, name: str, tag: str) -> Attr:
        """Return the target attr with the given name, create it if missing.

        A new attr is built with the given name and tag, with both
        min and max occurs set to one, and it's inserted as the first
        attr of the target class.
        """
        existing = ClassUtils.find_attr(target, name)
        if existing is not None:
            return existing

        created = Attr(name=name, tag=tag)
        created.restrictions.min_occurs = 1
        created.restrictions.max_occurs = 1
        target.attrs.insert(0, created)
        return created