File: codegen.py

package info (click to toggle)
python-xmlschema 4.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 5,208 kB
  • sloc: python: 39,174; xml: 1,282; makefile: 36
file content (623 lines) | stat: -rw-r--r-- 21,874 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
#
# Copyright (c), 2016-2021, SISSA (International School for Advanced Studies).
# All rights reserved.
# This file is distributed under the terms of the MIT License.
# See the file 'LICENSE' in the root directory of the present
# distribution, or http://opensource.org/licenses/MIT.
#
# @author Davide Brunato <brunato@sissa.it>
#
# mypy: ignore-errors
"""
This module contains an abstract base class and helper
functions for building XSD based code generators.
"""
import os
import re
import sys
import inspect
import logging
from abc import ABC, ABCMeta
from fnmatch import fnmatch
from pathlib import Path
from typing import Optional

from jinja2 import Environment, ChoiceLoader, FileSystemLoader, \
    TemplateNotFound, TemplateAssertionError
from elementpath import datatypes

import xmlschema
from xmlschema.validators import XsdType, XsdElement, XsdAttribute
from xmlschema.names import XSD_NAMESPACE


# Matches a non-colonized name: a leading word character that is not a
# digit (a letter or an underscore), followed by word characters, dots
# or hyphens.
NCNAME_PATTERN = re.compile(r'^[^\d\W][\w.\-]*$')
# Matches a name with an optional 'prefix:' part. The two parts are
# captured by the named groups 'prefix' and 'local'.
QNAME_PATTERN = re.compile(
    r'^(?:(?P<prefix>[^\d\W][\w\-.\xb7\u0387\u06DD\u06DE]*):)?'
    r'(?P<local>[^\d\W][\w\-.\xb7\u0387\u06DD\u06DE]*)$',
)


def is_shell_wildcard(pathname):
    """Return `True` if *pathname* contains a `*`, a `?` or a `[` character."""
    return any(marker in pathname for marker in ('*', '?', '['))


def xsd_qname(name):
    """Return *name* qualified with the XSD namespace, in `{uri}name` notation."""
    return '{{{}}}{}'.format(XSD_NAMESPACE, name)


def filter_method(func):
    """
    Decorator that flags a method for registration as template filter.

    :param func: the callable to flag.
    :return: the same callable, with the attribute `is_filter` set to `True`.
    """
    setattr(func, 'is_filter', True)
    return func


def test_method(func):
    """Marks a method for registration as template test."""
    func.is_test = True
    return func


# Module logger: used by the generators to report templates that cannot
# be loaded and the files written during rendering.
logger = logging.getLogger('xmlschema-codegen')


class GeneratorMeta(ABCMeta):
    """
    Metaclass for creating code generators.

    When a generator class is created the metaclass:

    * inherits `formal_language` from the base classes, rejecting ambiguous \
    values and any attempt to change an inherited value;
    * converts the paths declared in `searchpaths` to directories (relative \
    paths are anchored to the directory of the module where the class is \
    defined) and appends them to the search paths of the base classes;
    * expands the keys of `builtin_types` to XSD qualified names and merges \
    the result over the mappings of the base classes.
    """

    def __new__(mcs, name, bases, attrs):
        """
        Create a new generator class.

        :param name: the name of the class.
        :param bases: the base classes.
        :param attrs: the namespace of the class body.
        :raises ValueError: if `formal_language` is ambiguous or changed, \
        or if a declared search path is not a directory.
        """
        # Relative search paths are resolved from the directory of the
        # defining module. If the module is unknown or has no file (e.g. in
        # an interactive session) the current directory is used instead.
        module = sys.modules.get(attrs.get('__module__'))
        module_file = getattr(module, '__file__', None)
        base_dir = Path(module_file).parent if module_file else Path.cwd()

        formal_language = None
        searchpaths = []
        builtin_types = {}

        for base in bases:
            if getattr(base, 'formal_language', None):
                if formal_language is None:
                    formal_language = base.formal_language
                elif formal_language != base.formal_language:
                    raise ValueError("ambiguous formal_language from base classes")

            if getattr(base, 'searchpaths', None):
                searchpaths.extend(base.searchpaths)
            if getattr(base, 'builtin_types', None):
                builtin_types.update(base.builtin_types)

        if 'formal_language' not in attrs:
            attrs['formal_language'] = formal_language
        elif formal_language and formal_language != attrs['formal_language']:
            raise ValueError("formal_language cannot be changed")

        try:
            for path in attrs['searchpaths']:
                dirpath = Path(path)
                if not dirpath.is_absolute():
                    dirpath = base_dir.joinpath(path)

                if not dirpath.is_dir():
                    raise ValueError(f"path {str(path)!r} is not a directory!")
                searchpaths.append(dirpath)

        except (KeyError, TypeError):
            # The class does not declare usable search paths (attribute
            # missing or not iterable): the class namespace is left as is.
            pass
        else:
            attrs['searchpaths'] = searchpaths

        try:
            for k, v in attrs['builtin_types'].items():
                builtin_types[xsd_qname(k)] = v
        except (KeyError, AttributeError):
            pass
        finally:
            # Always set, so each class owns a mapping keyed by qualified names.
            attrs['builtin_types'] = builtin_types

        # Go through ABCMeta (not type) so that abstract methods are
        # collected and the per-class ABC state is initialized.
        return super().__new__(mcs, name, bases, attrs)


class AbstractGenerator(ABC, metaclass=GeneratorMeta):
    """
    Abstract base class for code generators based on Jinja2 template engine.

    The methods marked as filters or tests are registered in the Jinja2
    environment of the instance when it is created.

    :param schema: the source or the instance of the XSD schema. A source \
    is loaded as an XSD 1.1 schema.
    :param searchpath: additional search path for custom templates. \
    If provided the search path has priority over searchpaths defined \
    in generator class.
    :param types_map: a dictionary with custom mapping for XSD types. \
    If the schema has a target namespace the keys are qualified with it.
    """
    formal_language: Optional[str] = None
    """The formal language associated to the code generator (eg. Python)."""

    searchpaths: Optional[list[str]] = None
    """
    Directory paths for searching templates, specified with a list or a tuple.
    Each path must be provided as relative from the directory of the module
    where the class is defined. Extends the searchpath defined in base classes.
    """

    # Keys are local names of XSD builtin types: the metaclass expands
    # them to qualified names when the class is created.
    builtin_types = {
        'anyType': '',
        'anySimpleType': '',
    }
    """
    Translation map for XSD builtin types. Updates the builtin_types
    defined in base classes.
    """

    def __init__(self, schema, searchpath=None, types_map=None):
        """
        Build the schema, the types map and the Jinja2 environment.

        :raises ValueError: if no search path is available, neither \
        from the argument nor from the generator class.
        """
        if not isinstance(schema, xmlschema.XMLSchemaBase):
            schema = xmlschema.XMLSchema11(schema)
        self.schema = schema

        # The optional custom path goes first, so it wins on class paths.
        file_loaders = [FileSystemLoader(searchpath)] if searchpath else []
        if self.searchpaths is not None:
            for path in reversed(self.searchpaths):
                file_loaders.append(FileSystemLoader(str(path)))
        if not file_loaders:
            raise ValueError("no search paths defined!")

        if len(file_loaders) > 1:
            loader = ChoiceLoader(file_loaders)
        else:
            loader = file_loaders[0]

        self.types_map = self.builtin_types.copy()
        if types_map:
            target_namespace = self.schema.target_namespace
            if target_namespace:
                ns_part = '{%s}' % target_namespace
                self.types_map.update((ns_part + k, v) for k, v in types_map.items())
            else:
                self.types_map.update(types_map)

        self.filters = {}
        self.tests = {}
        for attr_name in dir(self):
            member = getattr(self, attr_name)
            if not callable(member):
                continue

            if inspect.isfunction(member):
                # static methods: the marker is on the function itself
                marked = member
            elif inspect.isroutine(member) and hasattr(member, '__func__'):
                # class and instance methods: the marker is on the wrapped function
                marked = member.__func__
            else:
                continue

            if getattr(marked, 'is_filter', False):
                self.filters[attr_name] = member
            elif getattr(marked, 'is_test', False):
                self.tests[attr_name] = member

        # Default type mapping filter, named after the language (eg. python_type).
        type_mapping_filter = f'{self.formal_language}_type'.lower().replace(' ', '_')
        if type_mapping_filter not in self.filters:
            self.filters[type_mapping_filter] = self.map_type

        self._env = Environment(loader=loader)
        self._env.filters.update(self.filters)
        self._env.tests.update(self.tests)

    def __repr__(self):
        if self.schema.url:
            return f'{self.__class__.__name__}(schema={self.schema.name!r})'
        return f'{self.__class__.__name__}(namespace={self.schema.target_namespace!r})'

    def list_templates(self, extensions=None, filter_func=None):
        """
        List the templates available in the environment of the generator.

        Both arguments are forwarded unchanged to the `list_templates`
        method of the environment.
        """
        environment = self._env
        return environment.list_templates(extensions, filter_func)

    def matching_templates(self, name):
        """
        List the available templates whose name matches a shell-style pattern.

        :param name: the pattern, matched against template names with `fnmatch`.
        """
        def matches(template_name):
            return fnmatch(template_name, name)

        return self._env.list_templates(filter_func=matches)

    def get_template(self, name, parent=None, global_vars=None):
        """
        Load a template from the environment of the generator.

        The arguments are forwarded positionally to the `get_template`
        method of the environment.
        """
        environment = self._env
        return environment.get_template(name, parent, global_vars)

    def select_template(self, names, parent=None, global_vars=None):
        """
        Select a template from the environment of the generator.

        The arguments are forwarded positionally to the `select_template`
        method of the environment.
        """
        environment = self._env
        return environment.select_template(names, parent, global_vars)

    def render(self, names, parent=None, global_vars=None):
        """
        Render one or more templates, providing the schema as `schema` variable.

        Templates that are not found or that fail a template assertion
        are logged and skipped.

        :param names: a template name or a sequence of template names.
        :param parent: forwarded to the template loading call.
        :param global_vars: forwarded to the template loading call.
        :return: a list with the rendered texts of the loaded templates.
        :raises TypeError: if `names` contains items that are not strings.
        """
        if isinstance(names, str):
            names = [names]
        elif any(not isinstance(x, str) for x in names):
            raise TypeError("'names' argument must contain only strings!")

        results = []
        for template_name in names:
            try:
                template = self._env.get_template(template_name, parent, global_vars)
            except TemplateNotFound as err:
                logger.debug("name %r: %s", template_name, str(err))
                continue
            except TemplateAssertionError as err:
                logger.warning("template %r: %s", template_name, str(err))
                continue

            results.append(template.render(schema=self.schema))
        return results

    def render_to_files(self, names, parent=None, global_vars=None, output_dir='.', force=False):
        """
        Render templates to files, providing the schema as `schema` variable.

        Each output file is named as the template file without its last
        suffix. Templates that are not found or that fail a template
        assertion are logged and skipped.

        :param names: a template name or a sequence of names. Names containing \
        shell wildcards are expanded to the matching available templates.
        :param parent: forwarded to the template loading call.
        :param global_vars: forwarded to the template loading call.
        :param output_dir: the directory where the files are written.
        :param force: if `True` existing files are overwritten, otherwise \
        they are skipped.
        :return: a list with the paths of the written files.
        :raises TypeError: if `names` contains items that are not strings.
        """
        if isinstance(names, str):
            names = [names]
        elif any(not isinstance(x, str) for x in names):
            raise TypeError("'names' argument must contain only strings!")

        template_names = []
        for pattern in names:
            if is_shell_wildcard(pattern):
                template_names += self.matching_templates(pattern)
            else:
                template_names.append(pattern)

        target_dir = Path(output_dir)
        rendered = []

        for template_name in template_names:
            try:
                template = self._env.get_template(template_name, parent, global_vars)
            except TemplateNotFound as err:
                logger.debug("name %r: %s", template_name, str(err))
                continue
            except TemplateAssertionError as err:
                logger.warning("template %r: %s", template_name, str(err))
                continue

            output_file = target_dir.joinpath(Path(template_name).name).with_suffix('')
            if force or not output_file.exists():
                result = template.render(schema=self.schema)
                logger.info("write file %r", str(output_file))
                output_file.write_text(result)
                rendered.append(str(output_file))

        return rendered

    def map_type(self, obj):
        """
        Maps an XSD type to a type declaration of the target language.
        This method is registered as filter with a name dependent on
        the language name (eg. c_type).

        The lookup in the types map is done with the name of the type,
        then with the name of its base type and finally with the name
        of `anyType` or `anySimpleType`, for complex and simple types.

        :param obj: an XSD type or another type-related declaration as \
        an attribute or an element.
        :return: an empty string for non-XSD objects.
        """
        if isinstance(obj, (XsdAttribute, XsdElement)):
            xsd_type = obj.type
        elif isinstance(obj, XsdType):
            xsd_type = obj
        else:
            return ''

        try:
            return self.types_map[xsd_type.name]
        except KeyError:
            pass

        try:
            return self.types_map[xsd_type.base_type.name]
        except (KeyError, AttributeError):
            pass

        fallback = 'anyType' if xsd_type.is_complex() else 'anySimpleType'
        return self.types_map[xsd_qname(fallback)]

    @staticmethod
    @filter_method
    def name(obj, unnamed='none'):
        """
        Get the unqualified name of the provided object. Dots and
        hyphens are replaced by an underscore.

        :param obj: an XSD object or a named object or a string.
        :param unnamed: value returned for objects without a usable name \
        and for names that are not valid. Defaults to 'none'.
        :return: the local name, or an empty string for empty or malformed \
        names and for objects whose local name is not a string.
        """
        try:
            local = obj.local_name
        except AttributeError:
            text = getattr(obj, 'name', obj)
            if not isinstance(text, str):
                return unnamed
            if not text:
                return ''

            if text[0] == '{':
                # expanded name: exactly one closing brace is accepted
                parts = text.split('}')
                if len(parts) != 2:
                    return ''
                local = parts[1]
            elif ':' in text:
                # prefixed name: exactly one colon and a valid prefix
                parts = text.split(':')
                if len(parts) != 2:
                    return ''
                prefix, local = parts
                if NCNAME_PATTERN.match(prefix) is None:
                    return ''
            else:
                local = text
        else:
            if not isinstance(local, str):
                return ''

        if NCNAME_PATTERN.match(local) is None:
            return unnamed
        return local.translate(str.maketrans('.-', '__'))

    @filter_method
    def qname(self, obj, unnamed='none', sep='__'):
        """
        Get the QName of the provided object. Invalid chars for
        identifiers are replaced by an underscore.

        For names in expanded form the namespace is replaced by the first
        non-empty prefix mapped to it in the namespaces of the schema. If
        the namespace is unmapped, or it is mapped only by the default
        (empty) prefix, the local name is used.

        :param obj: an XSD object or a named object or a string.
        :param unnamed: value for unnamed objects. Defaults to 'none'.
        :param sep: the replacement for colon. Defaults to double underscore.
        :return: str
        """
        try:
            qname = obj.prefixed_name
        except AttributeError:
            try:
                obj = obj.name
            except AttributeError:
                pass

            if not isinstance(obj, str):
                return unnamed

            try:
                if obj[0] == '{':
                    namespace, local_name = obj[1:].split('}')
                    qname = local_name
                    for prefix, uri in self.schema.namespaces.items():
                        # An empty prefix would produce ':local', that is
                        # not a valid QName: skip it and look further.
                        if prefix and uri == namespace:
                            qname = f'{prefix}:{local_name}'
                            break
                else:
                    qname = obj
            except IndexError:
                # empty string
                return ''
            except ValueError:
                # malformed expanded name (missing or repeated closing brace)
                return unnamed

        if not qname or QNAME_PATTERN.match(qname) is None:
            return unnamed
        return qname.replace('.', '_').replace('-', '_').replace(':', sep)

    @filter_method
    def namespace(self, obj):
        """
        Get the namespace URI of the provided object.

        :param obj: an object with a target namespace, a QName instance \
        or a string containing an expanded or a prefixed name.
        :return: the namespace URI or an empty string if it cannot be \
        determined.
        """
        try:
            uri = obj.target_namespace
        except AttributeError:
            pass
        else:
            return uri if isinstance(uri, str) else ''

        if isinstance(obj, datatypes.QName):
            return obj.namespace
        if not isinstance(obj, str) or not obj:
            return ''

        if obj[0] == '{':
            # expanded name: exactly one closing brace is accepted
            parts = obj[1:].split('}')
            return parts[0] if len(parts) == 2 else ''

        if ':' in obj:
            # prefixed name: the prefix is looked up in the schema namespaces
            parts = obj.split(':')
            if len(parts) != 2:
                return ''
            return self.schema.namespaces.get(parts[0], '')

        return ''

    @staticmethod
    @filter_method
    def type_name(obj, suffix=None, unnamed='none'):
        """
        Get the unqualified name of the XSD type. Dots and hyphens
        are replaced by an underscore.

        :param obj: an instance of (XsdType|XsdAttribute|XsdElement).
        :param suffix: a suffix to append. In any case a trailing 'Type' \
        or '_type' is removed from the name.
        :param unnamed: value for unnamed XSD types and for other objects. \
        Defaults to 'none'.
        :return: str
        """
        if isinstance(obj, XsdType):
            xsd_type = obj
        elif isinstance(obj, (XsdElement, XsdAttribute)):
            xsd_type = obj.type
        else:
            xsd_type = None

        result = unnamed if xsd_type is None else (xsd_type.local_name or unnamed)

        # Only the first matching ending is removed.
        for ending in ('Type', '_type'):
            if result.endswith(ending):
                result = result[:-len(ending)]
                break

        if suffix:
            result = '{}{}'.format(result, suffix)

        return result.replace('.', '_').replace('-', '_')

    @staticmethod
    @filter_method
    def type_qname(obj, suffix=None, unnamed='none', sep='__'):
        """
        Get the prefixed name of the XSD type. Dots and hyphens are
        replaced by an underscore, the colon by the separator.

        :param obj: an instance of (XsdType|XsdAttribute|XsdElement).
        :param suffix: a suffix to append. In any case a trailing 'Type' \
        or '_type' is removed from the name.
        :param unnamed: value for unnamed XSD types and for other objects. \
        Defaults to 'none'.
        :param sep: the replacement for colon. Defaults to double underscore.
        :return: str
        """
        if isinstance(obj, XsdType):
            xsd_type = obj
        elif isinstance(obj, (XsdElement, XsdAttribute)):
            xsd_type = obj.type
        else:
            xsd_type = None

        result = unnamed if xsd_type is None else (xsd_type.prefixed_name or unnamed)

        # Only the first matching ending is removed.
        for ending in ('Type', '_type'):
            if result.endswith(ending):
                result = result[:-len(ending)]
                break

        if suffix:
            result = '{}{}'.format(result, suffix)

        return result.replace('.', '_').replace('-', '_').replace(':', sep)

    @staticmethod
    @filter_method
    def sort_types(xsd_types, accept_circularity=False):
        """
        Returns a sorted sequence of XSD types usable for building type declarations.

        Simple types come first, then complex types with simple content.
        The other complex types follow, each one placed after the types,
        taken from the same sequence, that are used by its elements.

        :param xsd_types: a sequence with XSD types, or a mapping whose \
        values are XSD types.
        :param accept_circularity: if set to `True` circularities \
        are accepted. Defaults to `False`.
        :return: a list with ordered types.
        :raises ValueError: if a circularity is found and circularities \
        are not accepted.
        """
        if not isinstance(xsd_types, (list, tuple)):
            try:
                xsd_types = list(xsd_types.values())
            except AttributeError:
                pass

        assert all(isinstance(x, XsdType) for x in xsd_types)
        ordered_types = [x for x in xsd_types if x.is_simple()]
        ordered_types += [x for x in xsd_types if x.is_complex() and x.has_simple_content()]

        # Maps each remaining complex type to the pending types it depends on.
        pending = {x: [] for x in xsd_types if x.is_complex() and not x.has_simple_content()}
        for xsd_type, dependencies in pending.items():
            for element in xsd_type.content.iter_elements():
                if element.type in pending:
                    dependencies.append(element.type)

        while pending:
            released = 0
            for candidate in xsd_types:
                if candidate not in pending or pending[candidate]:
                    continue
                del pending[candidate]
                ordered_types.append(candidate)
                released += 1

            # Drop the dependencies on the types released by this pass.
            for xsd_type in pending:
                pending[xsd_type] = [x for x in pending[xsd_type] if x in pending]

            if released:
                continue
            if not accept_circularity:
                raise ValueError(f"circularity found between {list(pending)!r}")
            ordered_types.extend(list(pending))
            break

        assert len(xsd_types) == len(ordered_types)
        return ordered_types

    def is_derived(self, xsd_type, *names, derivation=None):
        """
        Returns `True` if the argument XSD type is derived from any
        of other types expressed by name, otherwise returns `False`.

        Names in expanded form are searched in the global maps. Other
        names are resolved to an expanded form and searched in the global
        maps, falling back to a lookup by the provided name in the types
        of the schema. Empty names and non-string names are ignored.

        :param xsd_type: an XsdComplexType/XsdSimpleType instance.
        :param names: positional argument with the names of other \
        XSD types.
        :param derivation: the type of derivation, that can be \
        *extension* or *restriction*, or both with a space separator. \
        If no value is provided it only checks if it is derived from \
        or if it is the XSD type itself.
        """
        schema = self.schema

        for type_name in names:
            if not (isinstance(type_name, str) and type_name):
                continue  # pragma: no cover

            if type_name[0] == '{':
                other = schema.maps.types.get(type_name)
            else:
                other = None
                try:
                    expanded_name = schema.resolve_qname(type_name)
                except xmlschema.XMLSchemaException:
                    pass
                else:
                    other = schema.maps.types.get(expanded_name)

                if other is None:
                    other = schema.types.get(type_name)

            if other is not None and xsd_type.is_derived(other, derivation):
                return True

        return False

    @test_method
    def derivation(self, xsd_type, *names):
        """Template test: `True` if the XSD type is derived from any of the named types."""
        derived = self.is_derived(xsd_type, *names)
        return derived

    @test_method
    def extension(self, xsd_type, *names):
        """Template test: `True` if the XSD type is an extension of any of the named types."""
        derived = self.is_derived(xsd_type, *names, derivation='extension')
        return derived

    @test_method
    def restriction(self, xsd_type, *names):
        """Template test: `True` if the XSD type is a restriction of any of the named types."""
        derived = self.is_derived(xsd_type, *names, derivation='restriction')
        return derived

    @staticmethod
    @test_method
    def multi_sequence(xsd_type):
        """
        Template test: `True` if the content of the XSD type has at least
        an element that can occur multiple times. Returns `False` if the
        elements of the content cannot be inspected.
        """
        try:
            for element in xsd_type.content.iter_elements():
                if element.is_multiple():
                    return True
        except AttributeError:
            return False
        return False


class PythonGenerator(AbstractGenerator):
    """
    A Python code generator for XSD schemas.

    XSD builtin types are mapped to Python builtin types, to
    `decimal.Decimal` or to classes referenced through the name
    `datatypes`.
    """

    formal_language = 'Python'

    # Relative to the directory of the module that defines this class.
    searchpaths = ['templates/python/']

    # Keys are local names of XSD builtin types, values are the Python
    # type expressions to emit in the generated code.
    builtin_types = {
        'string': 'str',
        'decimal': 'decimal.Decimal',
        'float': 'float',
        'double': 'float',

        'date': 'datatypes.Date10',
        'dateTime': 'datatypes.DateTime10',
        'gDay': 'datatypes.GregorianDay',
        'gMonth': 'datatypes.GregorianMonth',
        'gMonthDay': 'datatypes.GregorianMonthDay',
        'gYear': 'datatypes.GregorianYear10',
        'gYearMonth': 'datatypes.GregorianYearMonth10',
        'time': 'datatypes.Time',
        'duration': 'datatypes.Duration',

        'QName': 'datatypes.QName',
        # NOTATION is a QName-like type: it must not map to a date-time class.
        'NOTATION': 'datatypes.Notation',
        'anyURI': 'datatypes.AnyURI',
        'boolean': 'bool',
        'base64Binary': 'datatypes.Base64Binary',
        'hexBinary': 'datatypes.HexBinary',
        'normalizedString': 'str',
        'token': 'str',
        'language': 'str',
        'Name': 'str',
        'NCName': 'str',
        'ID': 'str',
        'IDREF': 'str',
        'ENTITY': 'str',
        'NMTOKEN': 'str',

        'integer': 'int',
        'long': 'int',
        'int': 'int',
        'short': 'int',
        'byte': 'int',
        'nonNegativeInteger': 'int',
        'positiveInteger': 'int',
        'unsignedLong': 'int',
        'unsignedInt': 'int',
        'unsignedShort': 'int',
        'unsignedByte': 'int',
        'nonPositiveInteger': 'int',
        'negativeInteger': 'int',

        # XSD 1.1 built-in types
        'dateTimeStamp': 'datatypes.DateTimeStamp10',
        'dayTimeDuration': 'datatypes.DayTimeDuration',
        'yearMonthDuration': 'datatypes.YearMonthDuration',
    }