File: processor.py

package info (click to toggle)
python-signxml 4.0.5+dfsg-2
  • links: PTS
  • area: main
  • in suites: forky, trixie
  • size: 4,896 kB
  • sloc: xml: 9,822; python: 2,370; javascript: 57; makefile: 35; sh: 8
file content (160 lines) | stat: -rw-r--r-- 6,732 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import logging
import os
from typing import Any, List, Tuple
from xml.etree import ElementTree as stdlibElementTree

from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.hashes import Hash
from lxml import etree

from .algorithms import CanonicalizationMethod, DigestAlgorithm, digest_algorithm_implementations
from .exceptions import InvalidInput
from .util import namespaces

logger = logging.getLogger(__name__)


class XMLProcessor:
    """Base class with lxml-backed helpers for parsing, serializing and schema loading.

    Subclasses list XSD file names in ``schema_files``; they are resolved against the ``schemas`` directory
    located next to this module.
    """

    # Compiled schemas, populated lazily by schemas().
    _schemas: List[Any] = []
    # Schema file names, joined onto _schema_dir when loading.
    schema_files: List[Any] = []
    # _parser, when set, takes precedence over the lazily created _default_parser.
    _default_parser, _parser = None, None
    _schema_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "schemas"))

    @classmethod
    def schemas(cls):
        """Return the compiled ``etree.XMLSchema`` objects for this class's ``schema_files``.

        The result is cached on the class it was requested from. Only the class's own ``__dict__`` is
        consulted, so a subclass with different ``schema_files`` never reuses (or appends to) a list that was
        populated for a parent or sibling class.
        """
        cached = vars(cls).get("_schemas")
        if not cached:
            # Build the complete list before publishing it on the class.
            cached = [
                etree.XMLSchema(etree.parse(os.path.join(cls._schema_dir, schema_file)))
                for schema_file in cls.schema_files
            ]
            cls._schemas = cached
        return cached

    @property
    def parser(self):
        """Return ``_parser`` if one is set, otherwise a default parser created with ``resolve_entities=False``.

        The default parser is created on first use and stored on the instance.
        """
        if self._parser is None:
            if self._default_parser is None:
                self._default_parser = etree.XMLParser(resolve_entities=False)
            return self._default_parser
        return self._parser

    def _fromstring(self, xml_string, **kwargs):
        """Parse *xml_string* with :attr:`parser` and return the resulting node.

        Extra keyword arguments are forwarded to ``etree.fromstring``.

        :raises InvalidInput: if the parsed tree contains any entity node.
        """
        xml_node = etree.fromstring(xml_string, parser=self.parser, **kwargs)
        # Only the presence of an entity node matters, so stop at the first one.
        if next(xml_node.iter(etree.Entity), None) is not None:
            raise InvalidInput("Entities are not supported in XML input")
        return xml_node

    def _tostring(self, xml_node, **kwargs):
        """Serialize *xml_node* with ``etree.tostring``, forwarding any keyword arguments."""
        return etree.tostring(xml_node, **kwargs)

    def get_root(self, data):
        """Return a freshly parsed node for *data*.

        *data* may be a ``str``/``bytes`` document, a standard library ``ElementTree`` element, or any other
        object that ``etree.tostring`` can serialize (presumably an lxml element). In every case the input is
        (re-)parsed through :meth:`_fromstring`, so the entity check applies and the returned tree is
        independent of the one passed in.
        """
        if isinstance(data, (str, bytes)):
            return self._fromstring(data)
        elif isinstance(data, stdlibElementTree.Element):
            # TODO: add debug level logging statement re: performance impact here
            return self._fromstring(stdlibElementTree.tostring(data, encoding="utf-8"))
        else:
            # Create a separate copy of the node so we can modify the tree and avoid any c14n inconsistencies from
            # namespaces propagating from parent nodes. The lxml docs recommend using copy.deepcopy for this, but it
            # doesn't seem to preserve namespaces. It would be nice to find a less heavy-handed way of doing this.
            return self._fromstring(self._tostring(data))


class XMLSignatureProcessor(XMLProcessor):
    """Helpers shared by XML signature code: digests, element lookup, canonicalization, reference resolution."""

    schema_files = ["xmldsig1-schema.xsd"]

    # See https://tools.ietf.org/html/rfc5656
    known_ecdsa_curves = {
        "urn:oid:1.2.840.10045.3.1.7": ec.SECP256R1,
        "urn:oid:1.3.132.0.34": ec.SECP384R1,
        "urn:oid:1.3.132.0.35": ec.SECP521R1,
        "urn:oid:1.3.132.0.1": ec.SECT163K1,
        "urn:oid:1.2.840.10045.3.1.1": ec.SECP192R1,
        "urn:oid:1.3.132.0.33": ec.SECP224R1,
        "urn:oid:1.3.132.0.26": ec.SECT233K1,
        "urn:oid:1.3.132.0.27": ec.SECT233R1,
        "urn:oid:1.3.132.0.16": ec.SECT283R1,
        "urn:oid:1.3.132.0.36": ec.SECT409K1,
        "urn:oid:1.3.132.0.37": ec.SECT409R1,
        "urn:oid:1.3.132.0.38": ec.SECT571K1,
    }
    # Reverse mapping: curve name -> OID URN. The loop variable is deliberately not called "ec" so it does not
    # shadow the imported module.
    known_ecdsa_curve_oids = {
        curve().name: oid for oid, curve in known_ecdsa_curves.items()  # type: ignore[abstract]
    }

    excise_empty_xmlns_declarations = False

    # Attribute names tried, in order, when resolving same-document ("#...") reference URIs.
    id_attributes: Tuple[str, ...] = ("Id", "ID", "id", "xml:id")

    def _get_digest(self, data, algorithm: DigestAlgorithm):
        """Return the raw digest of *data* computed with the implementation registered for *algorithm*."""
        algorithm_implementation = digest_algorithm_implementations[algorithm]()
        hasher = Hash(algorithm=algorithm_implementation)
        hasher.update(data)
        return hasher.finalize()

    def _find(self, element, query, require=True, xpath=""):
        """Return the first element matching *query* under *element*, or ``None``.

        *query* is a tag name, optionally prefixed as ``prefix:name``; without a prefix, ``ds`` is used.
        Prefixes are resolved through the module's ``namespaces`` mapping. *xpath* is prepended verbatim to
        the search path.

        :raises InvalidInput: if *require* is true and nothing matches.
        """
        namespace = "ds"
        if ":" in query:
            namespace, _, query = query.partition(":")
        result = element.find(f"{xpath}{namespace}:{query}", namespaces=namespaces)

        if require and result is None:
            raise InvalidInput(f"Expected to find XML element {query} in {element.tag}")
        return result

    def _findall(self, element, query, xpath=""):
        """Return all elements matching *query* under *element*; see :meth:`_find` for the query format."""
        namespace = "ds"
        if ":" in query:
            namespace, _, query = query.partition(":")
        return element.findall(f"{xpath}{namespace}:{query}", namespaces=namespaces)

    def _c14n(self, nodes, algorithm: CanonicalizationMethod, inclusive_ns_prefixes=None):
        """Canonicalize *nodes* (a single node or a list of nodes) and return the concatenated bytes.

        Exclusive canonicalization and comment retention are derived from the URI in ``algorithm.value``.
        *inclusive_ns_prefixes* is passed through to ``etree.tostring``.
        """
        exclusive = algorithm.value.startswith("http://www.w3.org/2001/10/xml-exc-c14n#")
        with_comments = algorithm.value.endswith("#WithComments")

        if not isinstance(nodes, list):
            nodes = [nodes]

        # Join once instead of repeatedly concatenating bytes.
        c14n = b"".join(
            etree.tostring(
                node,
                method="c14n",
                exclusive=exclusive,
                with_comments=with_comments,
                inclusive_ns_prefixes=inclusive_ns_prefixes,
            )
            for node in nodes
        )
        if exclusive is False and self.excise_empty_xmlns_declarations is True:
            # Incorrect legacy behavior. See also:
            # - https://github.com/XML-Security/signxml/issues/193
            # - http://www.w3.org/TR/xml-c14n, "namespace axis"
            # - http://www.w3.org/TR/xml-c14n2/#sec-Namespace-Processing
            c14n = c14n.replace(b' xmlns=""', b"")
        logger.debug("Canonicalized string (exclusive=%s, with_comments=%s): %s", exclusive, with_comments, c14n)
        return c14n

    def _resolve_reference(self, doc_root, reference, uri_resolver=None):
        """Resolve the ``URI`` attribute of *reference* and return the referenced node or resource.

        - ``URI=""`` returns *doc_root*.
        - ``URI="#name"`` returns the single element under *doc_root* having an attribute whose local name
          is one of :attr:`id_attributes` and whose value is ``name``. The attribute names are tried in
          order; the first one that yields a match wins.
        - Any other URI is handed to *uri_resolver* and its result is returned.

        :raises InvalidInput: if the URI attribute is missing, is an XPointer expression, matches several
            elements for one attribute name, cannot be resolved, or is external while no *uri_resolver* was
            supplied.
        """
        uri = reference.get("URI")
        if uri is None:
            raise InvalidInput("References without URIs are not supported")
        elif uri == "":
            return doc_root
        elif uri.startswith("#xpointer("):
            raise InvalidInput("XPointer references are not supported")
        elif uri.startswith("#"):
            # Drop exactly one leading "#". str.lstrip("#") would strip every leading "#", letting a malformed
            # URI such as "##foo" resolve to the element whose Id is "foo".
            fragment = uri[1:]
            for id_attribute in self.id_attributes:
                xpath_query = f"//*[@*[local-name() = '{id_attribute}']=$uri]"
                results = doc_root.xpath(xpath_query, uri=fragment)
                if len(results) > 1:
                    raise InvalidInput(f"Ambiguous reference URI {uri} resolved to {len(results)} nodes")
                elif len(results) == 1:
                    return results[0]
            raise InvalidInput(f"Unable to resolve reference URI: {uri}")
        else:
            if uri_resolver is None:
                raise InvalidInput(f"External URI dereferencing is not configured: {uri}")
            result = uri_resolver(uri)
            if result is None:
                raise InvalidInput(f"Unable to resolve reference URI: {uri}")
            return result