File: exports.py

package info (click to toggle)
python-xmlschema 4.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 5,208 kB
  • sloc: python: 39,174; xml: 1,282; makefile: 36
file content (375 lines) | stat: -rw-r--r-- 15,038 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
#
# Copyright (c), 2016-2023, SISSA (International School for Advanced Studies).
# All rights reserved.
# This file is distributed under the terms of the MIT License.
# See the file 'LICENSE' in the root directory of the present
# distribution, or http://opensource.org/licenses/MIT.
#
# @author Davide Brunato <brunato@sissa.it>
#
import re
import logging
import pprint
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from collections.abc import Iterable
from typing import Any, Optional, Union
from urllib.parse import unquote, urlsplit
from xml.etree import ElementTree

from xmlschema.aliases import SchemaType
from xmlschema.exceptions import XMLSchemaValueError, XMLResourceOSError
from xmlschema.names import XSD_SCHEMA, XSD_IMPORT, XSD_INCLUDE, XSD_REDEFINE, XSD_OVERRIDE
from xmlschema.utils.logger import logged
from xmlschema.utils.paths import LocationPath
from xmlschema.utils.urls import is_remote_url, normalize_url, match_location
from xmlschema.translation import gettext as _
from xmlschema.resources import XMLResource

# Logger named 'xmlschema' used for all the messages of this module.
logger = logging.getLogger('xmlschema')

# Regex capturing (group 1) the quoted value of a `schemaLocation` attribute.
FIND_PATTERN = r'\bschemaLocation\s*=\s*[\'"]([^\'"]*)[\'"]'
# Regex template matching a `schemaLocation` attribute whose quoted value is a
# given location, optionally surrounded by whitespace. The `{0}` placeholder has
# to be filled with an `re.escape()`d location. NOTE: because of `\b` it also
# matches prefixed attribute names (e.g. `xsi:schemaLocation`).
REPLACE_PATTERN = r'\bschemaLocation\s*=\s*[\'"]\s*{0}\s*[\'"]'


@dataclass
class XsdSource:
    """
    Class for keeping track of an XSD schema source.

    Besides the `path` and `resource` fields, an instance keeps a working copy
    of the source text (that can be rewritten by location replacements), two
    processing flags and the list of location substitutions computed for it.
    """
    path: LocationPath
    resource: XMLResource

    def __init__(self, path: LocationPath, resource: XMLResource) -> None:
        # Relative path used for saving the source into the target directory
        self.path = path
        self.resource = resource
        # Working copy of the source text, rewritten by replace_location()
        self.text = resource.get_text()
        # Flag set by the export/download loops when the source is visited
        self.processed = False
        # Flag set when `text` has been rewritten
        self.modified = False
        # Computed (original location, replacement location) pairs, if any
        self.substitutions: Optional[list[tuple[str, str]]] = None

    @property
    def schema_locations(self) -> set[str]:
        """
        Extract schema locations from XSD resource tree. Only the non-empty
        `schemaLocation` attributes of the top-level xs:import, xs:include,
        xs:redefine and xs:override elements are considered. A new set is
        built at each access.
        """
        locations = set()
        for child in self.resource.root:
            if child.tag in (XSD_IMPORT, XSD_INCLUDE, XSD_REDEFINE, XSD_OVERRIDE):
                schema_location = child.get('schemaLocation', '').strip()
                if schema_location:
                    locations.add(schema_location)

        return locations

    def replace_location(self, location: str, repl_location: str) -> None:
        """
        Replace the `schemaLocation` attributes whose value is *location*
        (ignoring surrounding whitespace) with *repl_location* in the source
        text, marking the source as modified. Nothing is done if the two
        locations are equal.
        """
        if location == repl_location:
            return

        logger.debug("Replace location %r with %r", location, repl_location)
        repl = f'schemaLocation="{repl_location}"'
        pattern = REPLACE_PATTERN.format(re.escape(location))
        # A callable replacement is inserted literally: with a string replacement
        # any backslash or group reference in repl_location would be processed
        # by re.sub() as an escape sequence.
        self.text = re.sub(pattern, lambda _match: repl, self.text)
        self.modified = True

    def get_location_path(self, location: str,
                          ref: Union[SchemaType, XMLResource],
                          modify: bool = True) -> LocationPath:
        """
        Return a relative location path for the referred XSD schema, replacing
        the original location in the schema source, if necessary.

        :param location: the schema location as it appears in the source text.
        :param ref: the referred schema or resource. Its `filepath` is used when \
        a relative location points outside the export root.
        :param modify: if `True` the location is also replaced in the source \
        text, otherwise the substitution is only recorded.
        :return: the path of the referred source, relative to the export root.
        :raises XMLSchemaValueError: if the computed path is not relative.
        """
        parts: Any

        if is_remote_url(location):
            # A remote URL is mapped to <scheme>/<netloc>/<path>
            parts = urlsplit(unquote(location))
            path = LocationPath(parts.scheme). \
                joinpath(parts.netloc). \
                joinpath(parts.path.lstrip('/'))
        else:
            if location.startswith('file:/'):
                path = LocationPath(unquote(urlsplit(location).path))
            else:
                path = LocationPath(unquote(location))

            if not path.is_absolute():
                path = self.path.parent.joinpath(path).normalize()
                if not path.parts or path.parts[0] != '..':
                    # A relative path that stays inside the export root:
                    # the original location is kept unchanged.
                    return path

                # Use the absolute resource path
                path = LocationPath(ref.filepath)  # type: ignore[arg-type]

            if path.drive:
                # Map a drive to a plain directory name (e.g. 'C:' -> 'C')
                drive = path.drive.split(':')[0]
                path = LocationPath(drive).joinpath('/'.join(path.parts[1:]))

            # A local absolute path is mapped to file/<path>
            path = LocationPath('file').joinpath(path.as_posix().lstrip('/'))

        if path.is_absolute():
            raise XMLSchemaValueError(f'Replacing path {path} is not relative!')

        # Obtain the replacement location, relative to the directory of this source
        parts = path.parent.parts
        dir_parts = self.path.parent.parts

        # Count the leading directory parts in common
        k = 0
        for item1, item2 in zip(parts, dir_parts):
            if item1 != item2:
                break
            k += 1

        if not k:
            # No common directory: go up to the export root, then down to the path
            prefix = '/'.join(['..'] * len(dir_parts))
            repl_path = LocationPath(prefix).joinpath(path)
        else:
            repl_path = LocationPath('/'.join(parts[k:])).joinpath(path.name)
            if k < len(dir_parts):
                prefix = '/'.join(['..'] * (len(dir_parts) - k))
                repl_path = LocationPath(prefix).joinpath(repl_path)

        repl_location = repl_path.as_posix()
        if location != repl_location:
            if self.substitutions is None:
                self.substitutions = []
            self.substitutions.append((location, repl_location))

            if modify:
                self.replace_location(location, repl_location)

        return path


def save_sources(target: Union[str, Path],
                 sources: Iterable[XsdSource],
                 save_locations: bool = False) -> dict[str, str]:
    """
    Save XSD sources to a target directory.

    :param target: the target directory. It must be an empty directory or a \
    not existing path whose parent is an existing directory.
    :param sources: the XSD sources to save, all already processed.
    :param save_locations: if `True` also write an `__init__.py` file into the \
    target directory, containing a LOCATION_MAP dictionary with the collected \
    location substitutions.
    :return: a dictionary that maps original locations to replacement locations.
    :raises XMLSchemaValueError: if the target is not usable or if a source \
    path would be written outside the target directory.
    """
    target_path = Path(target) if isinstance(target, str) else target
    if target_path.is_dir():
        if list(target_path.iterdir()):
            msg = _("target directory {} is not empty")
            raise XMLSchemaValueError(msg.format(target))
    elif target_path.exists():
        msg = _("target {} is not a directory")
        raise XMLSchemaValueError(msg.format(target))
    elif not target_path.parent.exists():
        msg = _("target parent directory {} does not exist")
        raise XMLSchemaValueError(msg.format(target_path.parent))
    elif not target_path.parent.is_dir():
        msg = _("target parent {} is not a directory")
        raise XMLSchemaValueError(msg.format(target_path.parent))

    location_map = {}

    for src in sources:
        assert src.processed

        filepath = target_path.joinpath(src.path)

        # Safety check: raise error if filepath is not inside the target path
        try:
            filepath.resolve(strict=False).relative_to(target_path.resolve(strict=False))
        except ValueError:
            msg = _("target directory {} violation for exported path {}, {}")
            raise XMLSchemaValueError(msg.format(target, str(src.path), str(filepath)))

        if not filepath.parent.exists():
            filepath.parent.mkdir(parents=True)

        encoding = 'utf-8'  # default encoding for XML 1.0

        if src.text.startswith('<?'):
            # Get the encoding from XML declaration. The XML grammar allows
            # whitespace around the '=' of the encoding declaration.
            xml_declaration = src.text.split('\n', maxsplit=1)[0]
            re_match = re.search(r'\bencoding\s*=\s*["\']([^"\']+)', xml_declaration)
            if re_match is not None:
                encoding = re_match.group(1).lower()

        if src.modified:
            logger.info("Write modified XSD source to %s", filepath)
        else:
            logger.info("Write unchanged XSD source to %s", filepath)

        if src.substitutions:
            # The first substitution seen for a location wins, conflicting
            # ones are only reported.
            for location, repl_location in src.substitutions:
                if location not in location_map:
                    location_map[location] = repl_location
                elif repl_location != location_map[location]:
                    logger.warning("Substitution collision for location %r: %r != %r",
                                   location, repl_location, location_map[location])

        with filepath.open(mode='w', encoding=encoding) as fp:
            fp.write(src.text)

    if save_locations:
        # Python source files are UTF-8 by default: write the map explicitly in
        # UTF-8, independently of the platform default encoding.
        with target_path.joinpath('__init__.py').open('w', encoding='utf-8') as fp:
            logger.info("Write LOCATION_MAP to %s", fp.name)
            fp.write(f'LOCATION_MAP = {pprint.pformat(location_map)}')

    return location_map


@logged
def export_schema(schema: SchemaType,
                  target: Union[str, Path],
                  save_remote: bool = False,
                  remove_residuals: bool = True,
                  exclude_locations: Optional[list[str]] = None,
                  loglevel: Optional[Union[str, int]] = None) -> dict[str, str]:
    """
    Export XSD sources used by a schema instance to a target directory.
    Don't use this function directly, use XMLSchema.export() method instead.

    :param schema: the schema instance to export.
    :param target: the target directory.
    :param save_remote: if `True` also export the schemas referred by remote \
    locations, otherwise remote locations are skipped.
    :param remove_residuals: if `True` the remote schema locations of a source \
    that remain unmatched are cleared (replaced by an empty string).
    :param exclude_locations: a list of locations to skip.
    :param loglevel: for setting a different logging level for the export call.
    :return: the map of substituted locations.
    """
    if loglevel is not None:
        logger.info("Export schema using loglevel %r", loglevel)

    name = schema.name or 'schema.xsd'
    exports = {schema: XsdSource(LocationPath(name), schema.source)}

    if exclude_locations is None:
        exclude_locations = []

    logger.debug("Start export of schema %r", name)

    # Repeat passes until no new referred schema is added to `exports`
    while True:
        current_length = len(exports)

        # Iterate over a snapshot: schemas added during this pass are
        # processed by the next one.
        for xsd_schema in list(exports):
            schema_source = exports[xsd_schema]
            if schema_source.processed:
                continue  # Skip already processed schemas

            schema_source.processed = True
            logger.debug("Process schema instance %r", xsd_schema)

            schema_locations = schema_source.schema_locations

            imports_items = [(x.url, x) for x in xsd_schema.imports.values()
                             if x is not None and x.meta_schema is not None]

            for location, ref_schema in chain(xsd_schema.includes.items(), imports_items):
                if not location:
                    continue
                elif location in exclude_locations or not save_remote and is_remote_url(location):
                    logger.debug("Location %r is excluded by argument", location)
                    continue

                # Find matching schema location
                location_match = match_location(location, schema_locations)
                if location_match is None:
                    logger.debug("Unmatched location %r, skip ...", location)
                    continue

                location = location_match
                logger.debug("Matched location %r", location)
                schema_locations.remove(location)

                path = schema_source.get_location_path(location, ref_schema)
                if ref_schema not in exports:
                    exports[ref_schema] = XsdSource(path, ref_schema.source)

            if remove_residuals:
                # Deactivate residual redundant imports from remote URLs: remote
                # locations left unmatched, not included by the current schema
                # and not explicitly excluded.
                for location in schema_locations:
                    if is_remote_url(location) and location not in xsd_schema.includes \
                            and location not in exclude_locations:
                        logger.debug("Clear residual remote location %r", location)
                        schema_source.replace_location(location, '')

        if current_length == len(exports):
            break

    return save_sources(target, exports.values())


@logged
def download_schemas(url: str,
                     target: Union[str, Path],
                     save_remote: bool = True,
                     save_locations: bool = True,
                     modify: bool = False,
                     defuse: str = 'remote',
                     timeout: int = 300,
                     exclude_locations: Optional[list[str]] = None,
                     loglevel: Optional[Union[str, int]] = None) -> dict[str, str]:
    """
    Download one or more schemas from a URL and save them in a target directory. All the
    referred locations in schema sources are downloaded and stored in the target directory.

    :param url: The URL of the schema to download, usually a remote one.
    :param target: the target directory to save the schema.
    :param save_remote: if to save remote schemas, defaults to `True`.
    :param save_locations: for default save a LOCATION_MAP dictionary to a `__init__.py`, \
    that can be imported in your code to provide a *uri_mapper* argument for build the \
    schema instance. Provide `False` to skip the package file creation in the target \
    directory.
    :param modify: provide `True` to modify original schemas, defaults to `False`.
    :param defuse: when to defuse XML data before loading, defaults to `'remote'`.
    :param timeout: the timeout in seconds for the connection attempt in case of remote data.
    :param exclude_locations: provide a list of locations to skip.
    :param loglevel: for setting a different logging level for schema downloads call.
    :return: a dictionary containing the map of modified locations.
    :raises XMLSchemaValueError: if the resource referred by *url* is not an XSD schema.
    """
    if loglevel is not None:
        logger.info("Download schemas using loglevel %r", loglevel)

    resource = XMLResource(url, defuse=defuse, timeout=timeout)
    logger.info("Downloaded XML resource from %s", url)
    if resource.root.tag != XSD_SCHEMA:
        raise XMLSchemaValueError(f'Resource referred by {url} is not a XSD schema')

    name = resource.name
    downloads = {
        resource: XsdSource(LocationPath(name), resource)  # type: ignore[arg-type]
    }

    if exclude_locations is None:
        exclude_locations = []

    logger.debug("Start download of schema resource %r", name)

    # Repeat passes until no new resource is added to `downloads`
    while True:
        current_length = len(downloads)

        # Iterate over a snapshot: resources downloaded during this pass are
        # processed by the next one.
        for src_resource in list(downloads):
            schema_source = downloads[src_resource]
            if schema_source.processed:
                continue  # Skip already processed schemas

            schema_source.processed = True
            logger.debug("Process schema resource %r", src_resource)

            for location in schema_source.schema_locations:
                if location in exclude_locations or not save_remote and is_remote_url(location):
                    logger.debug("Location %r is excluded by argument", location)
                    continue

                ref_url = normalize_url(location, src_resource.base_url)
                ref_resource = next((x for x in downloads if x.url == ref_url), None)
                if ref_resource is not None:
                    # Already downloaded, but the location used by this source
                    # still has to be mapped (and replaced if modify=True),
                    # otherwise it keeps pointing to the original location.
                    schema_source.get_location_path(location, ref_resource, modify)
                    continue

                try:
                    ref_resource = XMLResource(ref_url, defuse=defuse, timeout=timeout)
                except (OSError, XMLResourceOSError) as err:
                    logger.error('Error accessing resource at URL %s: %s', ref_url, err)
                    continue
                except ElementTree.ParseError as err:
                    logger.error('Error parsing XML resource at URL %s: %s', ref_url, err)
                    continue
                else:
                    logger.info("Downloaded XML resource from %s", ref_url)

                if ref_resource.root.tag != XSD_SCHEMA:
                    logger.error('XML resource at URL %s is not an XSD schema', ref_url)
                    continue

                path = schema_source.get_location_path(location, ref_resource, modify)
                downloads[ref_resource] = XsdSource(path, ref_resource)

        if current_length == len(downloads):
            break

    return save_sources(target, downloads.values(), save_locations)