1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
|
import os
import re
from pathlib import Path
from typing import Dict, Optional, Union
from xsdata.codegen import opener
from xsdata.codegen.parsers import DefinitionsParser, SchemaParser
from xsdata.logger import logger
from xsdata.models.wsdl import Definitions
from xsdata.models.xsd import Schema
class Downloader:
    """
    Helper class to download a schema or a definitions with all their imports
    locally. The imports paths will be adjusted if necessary.

    :param output: Output path
    """

    __slots__ = ("output", "base_path", "downloaded")

    def __init__(self, output: Path):
        self.output = output
        # Shared parent of every uri fetched so far; set by the first
        # ``adjust_base_path`` call and only ever moved to a shallower path.
        self.base_path: Optional[Path] = None
        # uri/location -> local file it was written to, or None while the
        # download is still in progress (used for circular protection).
        self.downloaded: Dict[str, Optional[Path]] = {}

    def wget(self, uri: str, location: Optional[str] = None):
        """
        Download handler for any uri input with circular protection.

        :param uri: The uri to fetch
        :param location: The raw location attribute value the uri was
            referenced with, if any
        """
        if uri in self.downloaded or (location and location in self.downloaded):
            return

        # Register before recursing into the includes so cycles stop here.
        self.downloaded[uri] = None
        if location:
            self.downloaded[location] = None

        self.adjust_base_path(uri)

        logger.info("Fetching %s", uri)

        input_stream = opener.open(uri).read()  # nosec
        if uri.endswith("wsdl"):
            self.parse_definitions(uri, input_stream)
        else:
            self.parse_schema(uri, input_stream)

        self.write_file(uri, location, input_stream.decode())

    def parse_schema(self, uri: str, content: bytes):
        """Convert content to a schema instance and process all sub imports."""
        parser = SchemaParser(location=uri)
        schema = parser.from_bytes(content, Schema)
        self.wget_included(schema)

    def parse_definitions(self, uri: str, content: bytes):
        """Convert content to a definitions instance and process all sub
        imports, including the imports of every embedded schema."""
        parser = DefinitionsParser(location=uri)
        definitions = parser.from_bytes(content, Definitions)
        self.wget_included(definitions)

        for schema in definitions.schemas:
            self.wget_included(schema)

    def wget_included(self, definition: "Union[Schema, Definitions]"):
        """Download every included item of the given definition that has a
        location, passing along its ``schema_location`` attribute when the
        item has one."""
        for included in definition.included():
            if included.location:
                schema_location = getattr(included, "schema_location", None)
                self.wget(included.location, schema_location)

    @staticmethod
    def _common_path(*paths) -> str:
        """Return ``os.path.commonpath(paths)``, or "" when the paths cannot
        be compared (mixed absolute/relative paths, different drives)."""
        try:
            return os.path.commonpath(paths)
        except ValueError:
            # Treat as "no common path" instead of aborting the download.
            return ""

    def adjust_base_path(self, uri: str):
        """
        Adjust base path for every new uri loaded.

        Example runs:
            - file:///schemas/air_v48_0/Air.wsdl -> file:///schemas/air_v48_0
            - file:///schemas/common_v48_0/CommonReqRsp.xsd -> file:///schemas
        """
        if self.base_path is None:
            self.base_path = Path(uri).parent
            logger.info("Setting base path to %s", self.base_path)
            return

        common_path = self._common_path(str(self.base_path), uri)
        # The common path is a leading run of the base path's components, so
        # a parts-wise ``<`` means it is strictly shallower than the base path.
        if common_path and Path(common_path) < self.base_path:
            self.base_path = Path(common_path)
            logger.info("Adjusting base path to %s", self.base_path)

    def adjust_imports(self, path: Path, content: str) -> str:
        """
        Try to adjust the import locations for external locations that are
        not relative to the first requested uri.

        Every ``location="..."`` / ``schemaLocation="..."`` value that maps to
        an already written file is replaced with that file's path relative to
        ``path``, using forward slashes. Other values are left untouched.

        :param path: Directory the content will be written to
        :param content: The document text
        :return: The adjusted document text
        """

        def relocate(match) -> str:
            local_file = self.downloaded.get(match.group(1))
            if not isinstance(local_file, Path):
                return match.group(0)

            relative = os.path.relpath(local_file, path).replace("\\", "/")
            return f'ocation="{relative}"'

        # ``[^"]*`` stops at the closing quote so following attributes are not
        # captured, and a single substitution pass guarantees a rewritten value
        # is never rewritten again by a later match.
        return re.sub(r'ocation="([^"]*)"', relocate, content)

    def write_file(self, uri: str, location: Optional[str], content: str):
        """
        Write the given uri and it's content according to the base path and if
        the uri is relative to first requested uri.

        Keep track of all the written file paths, in case we have to
        modify the location attribute in an upcoming schema/definition
        import.
        """
        common_path = self._common_path(self.base_path or "", uri)
        if common_path:
            file_path = self.output.joinpath(Path(uri).relative_to(common_path))
        else:
            # Outside the base path: flatten into the output root.
            file_path = self.output.joinpath(Path(uri).name)

        content = self.adjust_imports(file_path.parent, content)
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.write_text(content, encoding="utf-8")

        logger.info("Writing %s", file_path)
        self.downloaded[uri] = file_path

        if location:
            self.downloaded[location] = file_path
|