1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
|
import logging
import os
from typing import Any, List, Tuple
from xml.etree import ElementTree as stdlibElementTree
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.hashes import Hash
from lxml import etree
from .algorithms import CanonicalizationMethod, DigestAlgorithm, digest_algorithm_implementations
from .exceptions import InvalidInput
from .util import namespaces
# Module-level logger named after this module; the canonicalization helper logs its output here at DEBUG level.
logger = logging.getLogger(__name__)
class XMLProcessor:
    """Base class with helpers for parsing, serializing and schema-loading XML via lxml.

    Subclasses list the XSD file names they need in ``schema_files``. The files are resolved against the
    ``schemas`` directory that sits next to this module and are compiled on first use by :meth:`schemas`.
    """

    # Compiled-schema cache. ``schemas()`` maintains one list per class, never one shared through inheritance.
    _schemas: List[Any] = []
    # Schema file names, relative to ``_schema_dir``.
    schema_files: List[Any] = []
    # ``_parser`` is an optional override; ``_default_parser`` is created lazily by the ``parser`` property.
    _default_parser, _parser = None, None
    _schema_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "schemas"))

    @classmethod
    def schemas(cls):
        """Return the compiled schemas for ``cls.schema_files``, loading them on first use.

        The cache is kept per class. Looking ``_schemas`` up through normal attribute resolution would make a
        subclass that does not redeclare ``_schemas`` read (and append to) its parent's list, so it could be
        handed schemas compiled from a different ``schema_files`` list. Checking ``cls.__dict__`` gives every
        class its own list instead.

        :returns: The list of compiled ``etree.XMLSchema`` objects cached on ``cls``.
        """
        if "_schemas" not in cls.__dict__:
            cls._schemas = []
        if not cls._schemas:
            for schema_file in cls.schema_files:
                schema_path = os.path.join(cls._schema_dir, schema_file)
                cls._schemas.append(etree.XMLSchema(etree.parse(schema_path)))
        return cls._schemas

    @property
    def parser(self):
        """Return ``_parser`` if one is set, otherwise a lazily created parser with ``resolve_entities=False``."""
        if self._parser is None:
            if self._default_parser is None:
                # Stored on the instance, so the default parser is built at most once per processor object.
                self._default_parser = etree.XMLParser(resolve_entities=False)
            return self._default_parser
        return self._parser

    def _fromstring(self, xml_string, **kwargs):
        """Parse ``xml_string`` with :attr:`parser` and return the root node.

        Extra keyword arguments are forwarded to ``etree.fromstring``.

        :raises InvalidInput: If the parsed tree contains any entity node.
        """
        xml_node = etree.fromstring(xml_string, parser=self.parser, **kwargs)
        # The first entity node found is enough to reject the document.
        for _ in xml_node.iter(etree.Entity):
            raise InvalidInput("Entities are not supported in XML input")
        return xml_node

    def _tostring(self, xml_node, **kwargs):
        """Serialize ``xml_node`` with ``etree.tostring``, forwarding any keyword arguments."""
        return etree.tostring(xml_node, **kwargs)

    def get_root(self, data):
        """Return a freshly parsed root node for ``data``.

        ``data`` may be a ``str``/``bytes`` document, a standard-library ``ElementTree`` element, or any other
        object that ``etree.tostring`` can serialize (presumably an lxml element or tree). In every case the
        result comes from :meth:`_fromstring`, so the entity check there always applies.
        """
        if isinstance(data, (str, bytes)):
            return self._fromstring(data)
        elif isinstance(data, stdlibElementTree.Element):
            # TODO: add debug level logging statement re: performance impact here
            return self._fromstring(stdlibElementTree.tostring(data, encoding="utf-8"))
        else:
            # Create a separate copy of the node so we can modify the tree and avoid any c14n inconsistencies from
            # namespaces propagating from parent nodes. The lxml docs recommend using copy.deepcopy for this, but it
            # doesn't seem to preserve namespaces. It would be nice to find a less heavy-handed way of doing this.
            return self._fromstring(self._tostring(data))
class XMLSignatureProcessor(XMLProcessor):
    """Shared helpers for XML Signature processing: digests, element lookup, c14n and reference resolution."""

    schema_files = ["xmldsig1-schema.xsd"]

    # Curve OID URN -> curve class. See https://tools.ietf.org/html/rfc5656
    known_ecdsa_curves = {
        "urn:oid:1.2.840.10045.3.1.7": ec.SECP256R1,
        "urn:oid:1.3.132.0.34": ec.SECP384R1,
        "urn:oid:1.3.132.0.35": ec.SECP521R1,
        "urn:oid:1.3.132.0.1": ec.SECT163K1,
        "urn:oid:1.2.840.10045.3.1.1": ec.SECP192R1,
        "urn:oid:1.3.132.0.33": ec.SECP224R1,
        "urn:oid:1.3.132.0.26": ec.SECT233K1,
        "urn:oid:1.3.132.0.27": ec.SECT233R1,
        # 1.3.132.0.16 is sect283k1 (nistk283 in RFC 5656); sect283r1 is 1.3.132.0.17.
        "urn:oid:1.3.132.0.16": ec.SECT283K1,
        "urn:oid:1.3.132.0.17": ec.SECT283R1,
        "urn:oid:1.3.132.0.36": ec.SECT409K1,
        "urn:oid:1.3.132.0.37": ec.SECT409R1,
        "urn:oid:1.3.132.0.38": ec.SECT571K1,
    }
    # Reverse table: curve name (as reported by an instance of the curve class) -> OID URN.
    known_ecdsa_curve_oids = {
        curve().name: oid for oid, curve in known_ecdsa_curves.items()  # type: ignore[abstract]
    }

    # When True, non-exclusive c14n output has ' xmlns=""' stripped (legacy behavior, see _c14n).
    excise_empty_xmlns_declarations = False

    # Attribute local names tried, in order, when resolving same-document "#fragment" reference URIs.
    id_attributes: Tuple[str, ...] = ("Id", "ID", "id", "xml:id")

    def _get_digest(self, data, algorithm: DigestAlgorithm):
        """Return the digest of ``data`` computed with the implementation registered for ``algorithm``."""
        algorithm_implementation = digest_algorithm_implementations[algorithm]()
        hasher = Hash(algorithm=algorithm_implementation)
        hasher.update(data)
        return hasher.finalize()

    def _find(self, element, query, require=True, xpath=""):
        """Find the first element matching ``query`` under ``element``.

        :param element: Node to search from.
        :param query: Element name, optionally written as ``prefix:name``; the prefix defaults to ``ds``.
        :param require: When true, a missing element is an error instead of returning ``None``.
        :param xpath: Path expression prepended to the qualified name (for example to search descendants).
        :returns: The matching element, or ``None`` if nothing matched and ``require`` is false.
        :raises InvalidInput: If ``require`` is true and nothing matched.
        """
        namespace = "ds"
        if ":" in query:
            namespace, _, query = query.partition(":")
        result = element.find(f"{xpath}{namespace}:{query}", namespaces=namespaces)

        if require and result is None:
            raise InvalidInput(f"Expected to find XML element {query} in {element.tag}")
        return result

    def _findall(self, element, query, xpath=""):
        """Return every element matching ``query`` under ``element``.

        ``query`` and ``xpath`` are interpreted exactly as in :meth:`_find`.
        """
        namespace = "ds"
        if ":" in query:
            namespace, _, query = query.partition(":")
        return element.findall(f"{xpath}{namespace}:{query}", namespaces=namespaces)

    def _c14n(self, nodes, algorithm: CanonicalizationMethod, inclusive_ns_prefixes=None):
        """Canonicalize ``nodes`` and return the concatenated bytes.

        :param nodes: A single node or a list of nodes; each is serialized independently, in order.
        :param algorithm: Canonicalization method. Its URI value selects exclusive c14n (by prefix) and
            comment retention (by the ``#WithComments`` suffix).
        :param inclusive_ns_prefixes: Passed through to ``etree.tostring``.
        :returns: The canonical serialization as ``bytes``.
        """
        exclusive = algorithm.value.startswith("http://www.w3.org/2001/10/xml-exc-c14n#")
        with_comments = algorithm.value.endswith("#WithComments")

        if not isinstance(nodes, list):
            nodes = [nodes]

        c14n = b"".join(
            etree.tostring(
                node,
                method="c14n",
                exclusive=exclusive,
                with_comments=with_comments,
                inclusive_ns_prefixes=inclusive_ns_prefixes,
            )
            for node in nodes
        )
        if exclusive is False and self.excise_empty_xmlns_declarations is True:
            # Incorrect legacy behavior. See also:
            # - https://github.com/XML-Security/signxml/issues/193
            # - http://www.w3.org/TR/xml-c14n, "namespace axis"
            # - http://www.w3.org/TR/xml-c14n2/#sec-Namespace-Processing
            c14n = c14n.replace(b' xmlns=""', b"")
        logger.debug("Canonicalized string (exclusive=%s, with_comments=%s): %s", exclusive, with_comments, c14n)
        return c14n

    def _resolve_reference(self, doc_root, reference, uri_resolver=None):
        """Resolve the ``URI`` attribute of ``reference`` to the content it points at.

        :param doc_root: Root of the document that same-document references are resolved against.
        :param reference: Element carrying the ``URI`` attribute.
        :param uri_resolver: Optional callable taking the URI and returning the resolved content (or ``None``
            when it cannot resolve it). Only used for URIs that are not same-document references.
        :returns: ``doc_root`` for an empty URI, the uniquely matching element for ``#fragment`` URIs, or
            whatever ``uri_resolver`` returned for any other URI.
        :raises InvalidInput: If the URI is missing, is an XPointer expression, matches several elements,
            matches none, or is external and no resolver is configured or the resolver returns ``None``.
        """
        uri = reference.get("URI")
        if uri is None:
            raise InvalidInput("References without URIs are not supported")
        elif uri == "":
            return doc_root
        elif uri.startswith("#xpointer("):
            raise InvalidInput("XPointer references are not supported")
        elif uri.startswith("#"):
            fragment = uri.lstrip("#")
            # ID attribute names are tried in order; the first name that yields a match decides the result.
            for id_attribute in self.id_attributes:
                # The fragment is bound as an XPath variable, so it is never interpolated into the query text.
                xpath_query = f"//*[@*[local-name() = '{id_attribute}']=$uri]"
                results = doc_root.xpath(xpath_query, uri=fragment)
                if len(results) > 1:
                    raise InvalidInput(f"Ambiguous reference URI {uri} resolved to {len(results)} nodes")
                elif len(results) == 1:
                    return results[0]
            raise InvalidInput(f"Unable to resolve reference URI: {uri}")
        else:
            if uri_resolver is None:
                raise InvalidInput(f"External URI dereferencing is not configured: {uri}")
            result = uri_resolver(uri)
            if result is None:
                raise InvalidInput(f"Unable to resolve reference URI: {uri}")
            return result
|