File: verifier.py

package info (click to toggle)
python-signxml 4.0.5%2Bdfsg-2
  • links: PTS
  • area: main
  • in suites: forky, trixie
  • size: 4,896 kB
  • sloc: xml: 9,822; python: 2,370; javascript: 57; makefile: 35; sh: 8
file content (659 lines) | stat: -rw-r--r-- 35,047 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
from base64 import b64decode
from dataclasses import dataclass, replace
from typing import Callable, FrozenSet, List, Optional, Tuple, Union
from warnings import warn

import cryptography.exceptions
from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa, utils
from cryptography.hazmat.primitives.asymmetric.padding import MGF1, PSS, AsymmetricPadding, PKCS1v15
from cryptography.hazmat.primitives.hmac import HMAC
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat, load_der_public_key
from lxml import etree

from .algorithms import (
    CanonicalizationMethod,
    DigestAlgorithm,
    SignatureConstructionMethod,
    SignatureMethod,
    digest_algorithm_implementations,
)
from .exceptions import InvalidCertificate, InvalidDigest, InvalidInput, InvalidSignature, SignXMLException
from .processor import XMLSignatureProcessor
from .util import (
    X509CertChainVerifier,
    _remove_sig,
    add_pem_header,
    bits_to_bytes_unit,
    bytes_to_long,
    ds_tag,
    namespaces,
)


@dataclass(frozen=True)
class SignatureConfiguration:
    """
    A container holding signature settings that will be used to assert properties of the signature.
    """

    require_x509: bool = True
    """
    If ``True``, a valid X.509 certificate-based signature with an established chain of trust is required to
    pass validation. If ``False``, other types of valid signatures (e.g. HMAC or RSA public key) are accepted.

    Ensure you establish trust: to specify an HMAC shared secret that will be required to be used by the signature, use
    the ``hmac_key`` parameter. If ``require_x509`` is ``False`` and no ``hmac_key`` is set, any valid XML signature
    satisfying the other SignatureConfiguration parameters will be accepted, regardless of the key used. You must then
    establish trust by examining the ``VerifyResult`` contents, in particular asserting on the value of ``signature_key``.
    """

    location: str = ".//"
    """
    XPath location where the signature tag will be expected. The default, ``.//``, searches for the signature tag
    anywhere in the document (at any depth below the top level element). Because that accepts a signature placed
    anywhere, it is recommended that you assert the expected location explicitly: if your signature is enveloped at
    the top level (i.e. the ``ds:Signature`` tag is a direct child of the top level element), set this to ``./``. If
    your signature is nested elsewhere in the document, you can reference the full path as
    ``./{ns}Tag1/{ns}Tag2/{ns}Tag3/``. If your signature is enveloping (i.e. the ``ds:Signature`` tag is itself the top
    level tag), the top level element is used directly and this setting is not consulted.
    """

    expect_references: Union[int, bool] = 1
    """
    Number of references to expect in the signature. If this is not 1, an array of VerifyResults is returned.
    If set to a non-integer, any number of references is accepted (otherwise a mismatch raises an error).
    """

    signature_methods: FrozenSet[SignatureMethod] = frozenset(sm for sm in SignatureMethod if "SHA1" not in sm.name)
    """
    Set of acceptable signature methods (signature algorithms). Any signature generated using an algorithm not listed
    here will fail verification. It is recommended that you restrict this set to only methods expected to be used in
    signatures handled by your application.
    """

    digest_algorithms: FrozenSet[DigestAlgorithm] = frozenset(da for da in DigestAlgorithm if "SHA1" not in da.name)
    """
    Set of acceptable digest algorithms. Any signature or reference transform generated using an algorithm not listed
    here will cause verification to fail. It is recommended that you restrict this set to only methods expected to be
    used in signatures handled by your application.
    """

    ignore_ambiguous_key_info: bool = False
    """
    Ignore the presence of a KeyValue (or DEREncodedKeyValue) element when X509Data is present in the signature and
    used for verifying. The presence of both elements is an ambiguity and a security hazard. The public key used to
    sign the document is already encoded in the certificate (which is in X509Data), so the verifier must either ignore
    KeyValue or make sure it matches what's in the certificate. When this is ``False`` (the default), SignXML compares
    the key elements against the certificate's public key and raises an InvalidInput error if they represent different
    keys. Set this to True to bypass that check and validate the signature using X509Data only.
    """


@dataclass(frozen=True)
class VerifyResult:
    """
    Structured outcome of :func:`signxml.XMLVerifier.verify`.

    One instance describes one verified signature reference: the exact bytes that reference covers, those bytes
    re-parsed as XML, the signature element, and the key that verified the signature. Application code should only
    trust data taken from this object. Example usage:

        verified_data = signxml.XMLVerifier().verify(input_data).signed_xml
    """

    #: The binary data as it was signed
    signed_data: bytes

    #: The signed data parsed as XML (or None if parsing failed)
    signed_xml: Optional[etree._Element]

    #: The signature element parsed as XML
    signature_xml: etree._Element

    #: The cryptographic key that was used to verify the signature, in PEM format (for asymmetric keys) or raw
    #: secret bytes (for HMAC keys)
    signature_key: bytes


class XMLVerifier(XMLSignatureProcessor):
    """
    Create a new XML Signature Verifier object, which can be used to verify multiple pieces of data.
    """

    _default_reference_c14n_method = CanonicalizationMethod.CANONICAL_XML_1_0

    def _get_signature(self, root):
        """
        Locate the ``ds:Signature`` element to verify.

        If ``root`` is itself a ``ds:Signature`` element (an enveloping signature) it is returned unchanged; otherwise
        the element is looked up below ``root`` with ``self._find`` using the configured ``location`` XPath.
        """
        is_enveloping = root.tag == ds_tag("Signature")
        if not is_enveloping:
            return self._find(root, "Signature", xpath=self.config.location)
        return root

    def _verify_signature_with_pubkey(
        self,
        *,
        signed_info_c14n: bytes,
        raw_signature: bytes,
        signature_alg: SignatureMethod,
        key_value: Optional[etree._Element] = None,
        der_encoded_key_value: Optional[etree._Element] = None,
        signing_certificate: Optional[x509.Certificate] = None,
    ) -> Tuple[bytes, Union[bytes, dsa.DSAPublicKey, rsa.RSAPublicKey, ec.EllipticCurvePublicKey]]:
        """
        Verify ``raw_signature`` over ``signed_info_c14n`` with an asymmetric (ECDSA, DSA or RSA) public key.

        The initial key comes from ``der_encoded_key_value`` if it is given, otherwise from ``signing_certificate``.
        If ``key_value`` (a ``ds:KeyValue`` element) is given, the key parsed from it for the algorithm family of
        ``signature_alg`` is used instead.

        :param signed_info_c14n: The canonicalized ``SignedInfo`` bytes that were signed.
        :param raw_signature: The base64-decoded ``SignatureValue``.
        :param signature_alg: The signature method declared in ``SignedInfo``.
        :param key_value: Optional ``ds:KeyValue`` element.
        :param der_encoded_key_value: Optional ``dsig11:DEREncodedKeyValue`` element.
        :param signing_certificate: Optional certificate whose public key is used.
        :returns: ``(signed_info_c14n, key)``: the verified bytes and the public key object that verified them.
        :raises InvalidInput: if no key source is given, the DEREncodedKeyValue element has no content, the key type
            does not match ``signature_alg``, the ECDSA curve is not in ``known_ecdsa_curves``, or the algorithm is
            unsupported.
        :raises InvalidSignature: if an (EC)DSA ``SignatureValue`` has the wrong length.
        :raises cryptography.exceptions.InvalidSignature: if the signature does not verify.
        """
        if der_encoded_key_value is not None:
            # An explicit check instead of ``assert``: this is untrusted input, and asserts are stripped under -O.
            if not der_encoded_key_value.text:
                raise InvalidInput("DEREncodedKeyValue element is empty")
            key = load_der_public_key(b64decode(der_encoded_key_value.text))
        elif signing_certificate is not None:
            key = signing_certificate.public_key()
        elif key_value is None:
            raise InvalidInput("Expected one of key_value, der_encoded_key_value, or signing_certificate to be set")

        digest_alg_impl = digest_algorithm_implementations[signature_alg]()
        if signature_alg.name.startswith("ECDSA_"):
            if key_value is not None:
                ec_key_value = self._find(key_value, "dsig11:ECKeyValue")
                named_curve = self._find(ec_key_value, "dsig11:NamedCurve")
                public_key = self._find(ec_key_value, "dsig11:PublicKey")
                # Drop the leading point-format octet (presumably 0x04, an uncompressed point) and split the
                # remainder evenly into the X and Y coordinates.
                key_data = b64decode(public_key.text)[1:]
                x = bytes_to_long(key_data[: len(key_data) // 2])
                y = bytes_to_long(key_data[len(key_data) // 2 :])
                curve_uri = named_curve.get("URI")
                try:
                    curve_class = self.known_ecdsa_curves[curve_uri]
                except KeyError:
                    # Report an unknown curve as bad input rather than leaking a KeyError to the caller.
                    raise InvalidInput(f"Unsupported ECDSA curve {curve_uri}") from None
                ecpn = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve_class())  # type: ignore[abstract]
                key = ecpn.public_key()
            elif not isinstance(key, ec.EllipticCurvePublicKey):
                raise InvalidInput("DER encoded key value does not match specified signature algorithm")
            signature_for_ecdsa = self._encode_dss_signature(raw_signature, key.key_size)
            key.verify(signature_for_ecdsa, data=signed_info_c14n, signature_algorithm=ec.ECDSA(digest_alg_impl))
        elif signature_alg.name.startswith("DSA_"):
            if key_value is not None:
                dsa_key_value = self._find(key_value, "DSAKeyValue")
                p = self._get_long(dsa_key_value, "P")
                q = self._get_long(dsa_key_value, "Q")
                g = self._get_long(dsa_key_value, "G", require=False)
                y = self._get_long(dsa_key_value, "Y")
                dsapn = dsa.DSAPublicNumbers(y=y, parameter_numbers=dsa.DSAParameterNumbers(p=p, q=q, g=g))
                key = dsapn.public_key()
            elif not isinstance(key, dsa.DSAPublicKey):
                raise InvalidInput("DER encoded key value does not match specified signature algorithm")
            # TODO: supply meaningful key_size_bits for signature length assertion
            dss_signature = self._encode_dss_signature(raw_signature, len(raw_signature) * 8 // 2)
            key.verify(dss_signature, data=signed_info_c14n, algorithm=digest_alg_impl)
        elif signature_alg.name.startswith("RSA_") or signature_alg.name.startswith("SHA"):
            if key_value is not None:
                rsa_key_value = self._find(key_value, "RSAKeyValue")
                modulus = self._get_long(rsa_key_value, "Modulus")
                exponent = self._get_long(rsa_key_value, "Exponent")
                key = rsa.RSAPublicNumbers(e=exponent, n=modulus).public_key()
            elif not isinstance(key, rsa.RSAPublicKey):
                raise InvalidInput("DER encoded key value does not match specified signature algorithm")
            # "RSA_*" methods use PKCS#1 v1.5 padding; the "SHA*" (RSA-PSS) methods use PSS with MGF1 and a salt
            # as long as the digest.
            if signature_alg.name.startswith("RSA_"):
                padding: AsymmetricPadding = PKCS1v15()
            else:
                padding = PSS(mgf=MGF1(algorithm=digest_alg_impl), salt_length=digest_alg_impl.digest_size)
            key.verify(raw_signature, data=signed_info_c14n, padding=padding, algorithm=digest_alg_impl)
        else:
            raise InvalidInput(f"Unsupported signature algorithm {signature_alg}")
        return signed_info_c14n, key

    def _encode_dss_signature(self, raw_signature: bytes, key_size_bits: int) -> bytes:
        """
        Convert an XML-DSig (EC)DSA ``SignatureValue`` into the DER form that ``cryptography`` verifies.

        The raw value is treated as two equal-width integers ``r`` and ``s`` laid out back to back; they are decoded
        with ``bytes_to_long`` and re-encoded with ``utils.encode_dss_signature``.

        :param raw_signature: The decoded ``SignatureValue`` bytes.
        :param key_size_bits: Bit size of one integer; the expected length is twice its byte size.
        :raises InvalidSignature: if ``raw_signature`` does not have exactly the expected length.
        """
        expected_length = 2 * bits_to_bytes_unit(key_size_bits)
        if len(raw_signature) != expected_length:
            raise InvalidSignature(
                "Expected %d byte SignatureValue, got %d" % (expected_length, len(raw_signature))
            )
        midpoint = len(raw_signature) // 2
        r_part, s_part = raw_signature[:midpoint], raw_signature[midpoint:]
        return utils.encode_dss_signature(bytes_to_long(r_part), bytes_to_long(s_part))

    def _get_inclusive_ns_prefixes(self, transform_node):
        """
        Return the prefixes named by an ``ec:InclusiveNamespaces/@PrefixList`` child of ``transform_node``.

        :param transform_node: A ``CanonicalizationMethod`` or ``Transform`` element.
        :returns: The list of prefixes, or ``None`` if there is no ``InclusiveNamespaces`` element with a
            ``PrefixList`` attribute.
        """
        inclusive_namespaces = transform_node.find("./ec:InclusiveNamespaces[@PrefixList]", namespaces=namespaces)
        if inclusive_namespaces is None:
            return None
        # PrefixList is a whitespace-separated token list: split on any run of whitespace so that repeated spaces,
        # tabs or newlines do not produce empty-string prefixes.
        return inclusive_namespaces.get("PrefixList").split()

    def _apply_transforms(self, payload, *, transforms_node: Optional[etree._Element], signature: etree._Element):
        """
        Apply the ``ds:Transform`` steps of a Reference to ``payload`` and return the result to be digested.

        The transform list is scanned in three separate passes rather than once in document order:

        1. every enveloped-signature transform removes ``signature`` from its tree (``_remove_sig`` with
           ``idempotent=True``);
        2. every base64 transform replaces ``payload`` with the base64-decoded bytes of ``payload.text``;
        3. every transform whose Algorithm is a known :class:`CanonicalizationMethod` re-serializes and
           canonicalizes ``payload``.

        Transforms whose Algorithm matches none of these are skipped without error. If no canonicalization transform
        was applied and ``payload`` is still not ``str``/``bytes``, the class default reference canonicalization
        method is applied.

        NOTE(review): if a base64 transform is combined with a canonicalization transform, pass 3 would call
        ``_tostring`` on bytes — confirm whether that combination needs to be rejected explicitly.

        :param payload: The resolved Reference target (an element, or already ``str``/``bytes``).
        :param transforms_node: The Reference's ``ds:Transforms`` element, or ``None`` if it has none.
        :param signature: The signature element, presumably inside the same (copied) tree as ``payload`` — the caller
            in this file passes the signature found in its copy of the document.
        :returns: The transformed payload.
        """
        transforms, c14n_applied = [], False
        if transforms_node is not None:
            transforms = self._findall(transforms_node, "Transform")

        # Pass 1: enveloped-signature transforms (mutates the tree that `signature` belongs to).
        for transform in transforms:
            if transform.get("Algorithm") == SignatureConstructionMethod.enveloped.value:
                _remove_sig(signature, idempotent=True)

        # Pass 2: base64 transforms turn the node's text into raw bytes.
        for transform in transforms:
            if transform.get("Algorithm") == "http://www.w3.org/2000/09/xmldsig#base64":
                payload = b64decode(payload.text)

        # Pass 3: canonicalization transforms; any Algorithm that is not a CanonicalizationMethod is skipped.
        for transform in transforms:
            algorithm = transform.get("Algorithm")
            try:
                c14n_algorithm_from_transform = CanonicalizationMethod(algorithm)
            except ValueError:
                continue
            inclusive_ns_prefixes = self._get_inclusive_ns_prefixes(transform)

            # Create a separate copy of the node so we can modify the tree and avoid any c14n inconsistencies from
            # namespaces propagating from parent nodes. The lxml docs recommend using copy.deepcopy for this, but it
            # doesn't seem to preserve namespaces. It would be nice to find a less heavy-handed way of doing this.
            payload = self._fromstring(self._tostring(payload))

            payload = self._c14n(
                payload, algorithm=c14n_algorithm_from_transform, inclusive_ns_prefixes=inclusive_ns_prefixes
            )
            c14n_applied = True

        # A node must always be serialized before digesting; fall back to the default reference c14n method.
        if not c14n_applied and not isinstance(payload, (str, bytes)):
            payload = self._c14n(payload, algorithm=self._default_reference_c14n_method)

        return payload

    def get_cert_chain_verifier(self, ca_pem_file):
        """
        Build the object used to validate a signature's X.509 certificate chain.

        This is a public method, presumably intended as a hook that subclasses can override to customize chain
        validation — confirm before relying on that.

        :param ca_pem_file: The ``ca_pem_file`` value given to :meth:`verify` (may be ``None``).
        :returns: An :class:`X509CertChainVerifier` configured with ``ca_pem_file``.
        """
        chain_verifier = X509CertChainVerifier(ca_pem_file=ca_pem_file)
        return chain_verifier

    def _match_key_values(self, key_value, der_encoded_key_value, signing_cert, signature_alg):
        """
        Ensure any KeyValue / DEREncodedKeyValue in KeyInfo denotes the same public key as the signing certificate.

        X509Data together with a second key element is ambiguous: the signature was verified with the certificate's
        key, so another key element must either be ignored or proven to be the same key. Unless the active
        configuration sets ``ignore_ambiguous_key_info``, each key element that is present is compared with the
        certificate's public key.

        :param key_value: The ``ds:KeyValue`` element, or ``None``.
        :param der_encoded_key_value: The ``dsig11:DEREncodedKeyValue`` element, or ``None``.
        :param signing_cert: The certificate whose public key verified the signature.
        :param signature_alg: The signature method in use.
        :raises InvalidInput: if a present key element represents a different public key than the certificate.
        """
        # Skip the comparison only when the caller explicitly opted into trusting X509Data alone. (The previous
        # guard was inverted and skipped the check for the default configuration.)
        if self.config.ignore_ambiguous_key_info:
            return
        cert_pub_key = signing_cert.public_key()
        # If both X509Data and KeyValue are present, match one against the other and raise an error on mismatch
        if key_value is not None:
            match_result = self._check_key_value_matches_cert_public_key(key_value, cert_pub_key, signature_alg)
            if match_result is False:
                raise InvalidInput(
                    "Both X509Data and KeyValue found and they represent different public keys. "
                    "Use verify(ignore_ambiguous_key_info=True) to ignore KeyValue and validate "
                    "using X509Data only."
                )

        # If both X509Data and DEREncodedKeyValue are present, match one against the other and raise an error on
        # mismatch
        if der_encoded_key_value is not None:
            match_result = self._check_der_key_value_matches_cert_public_key(
                der_encoded_key_value, cert_pub_key, signature_alg
            )
            if match_result is False:
                raise InvalidInput(
                    "Both X509Data and DEREncodedKeyValue found and they represent different "
                    "public keys. Use verify(ignore_ambiguous_key_info=True) to ignore "
                    "DEREncodedKeyValue and validate using X509Data only."
                )

    def check_digest_alg_expected(self, digest_alg):
        """
        Reject a digest algorithm that the active configuration does not allow.

        :param digest_alg: The :class:`DigestAlgorithm` declared by a Reference's ``DigestMethod``.
        :raises InvalidInput: if ``digest_alg`` is not in ``self.config.digest_algorithms``.
        """
        permitted = self.config.digest_algorithms
        if digest_alg in permitted:
            return
        raise InvalidInput(f"Digest algorithm {digest_alg.name} forbidden by configuration")

    def check_signature_alg_expected(self, signature_alg):
        """
        Reject a signature method that the active configuration does not allow.

        :param signature_alg: The :class:`SignatureMethod` declared by the signature's ``SignatureMethod`` element.
        :raises InvalidInput: if ``signature_alg`` is not in ``self.config.signature_methods``.
        """
        permitted = self.config.signature_methods
        if signature_alg in permitted:
            return
        raise InvalidInput(f"Signature method {signature_alg.name} forbidden by configuration")

    def verify(
        self,
        data,
        *,
        x509_cert: Optional[Union[str, x509.Certificate]] = None,
        cert_subject_name: Optional[str] = None,
        cert_resolver: Optional[Callable] = None,
        ca_pem_file: Optional[Union[str, bytes]] = None,
        hmac_key: Optional[bytes] = None,
        validate_schema: bool = True,
        parser=None,
        uri_resolver: Optional[Callable] = None,
        id_attribute: Optional[str] = None,
        expect_config: SignatureConfiguration = SignatureConfiguration(),
        **deprecated_kwargs,
    ) -> Union[VerifyResult, List[VerifyResult]]:
        """
        Verify the XML signature supplied in the data and return a :class:`VerifyResult` (or, when the configuration
        expects more than one reference, a list of :class:`VerifyResult` data structures) representing the data signed
        by the signature, or raise an exception if the signature is not valid. By default, this requires the signature
        to be generated using a valid X.509 certificate. To enable other means of signature validation, set
        ``expect_config`` to a configuration with the **require_x509** parameter set to `False`.

        .. admonition:: See what is signed

         It is important to understand and follow the best practice rule of "See what is signed" when verifying XML
         signatures. The gist of this rule is: if your application neglects to verify that the information it trusts is
         what was actually signed, the attacker can supply a valid signature but point you to malicious data that wasn't
         signed by that signature.

         In SignXML, you can ensure that the information signed is what you expect to be signed by only trusting the
         data returned by ``XMLVerifier.verify()``. The ``signed_xml`` attribute of the return value is the XML node or string
         that was signed. We also recommend that you assert the expected location for the signature within the document:

         .. code-block:: python

             from signxml import XMLVerifier, SignatureConfiguration
             config = SignatureConfiguration(location="./")
             XMLVerifier(...).verify(..., expect_config=config)

         Depending on the canonicalization method used by the signature, comments in the XML data may not be subject to
         signing, so may need to be untrusted. If so, they are excised from the return value of ``verify()``.

         **Recommended reading:** http://www.w3.org/TR/xmldsig-bestpractices/#practices-applications

        .. admonition:: Establish trust

         If you do not supply any keyword arguments to ``verify()``, the default behavior is to trust **any** valid XML
         signature generated using a valid X.509 certificate trusted by your system's CA store. This means anyone can
         get an SSL certificate and generate a signature that you will trust. To establish trust in the signer, use the
         ``x509_cert`` argument to specify a certificate that was pre-shared out-of-band (e.g. via SAML metadata, as
         shown in :ref:`Verifying SAML assertions <verifying-saml-assertions>`), or ``cert_subject_name`` to specify a
         subject name that must be in the signing X.509 certificate given by the signature (verified as if it were a
         domain name), or ``ca_pem_file`` to give a custom CA.

        :param data: Signature data to verify
        :type data: String, file-like object, or XML ElementTree Element API compatible object
        :param x509_cert:
            A trusted external X.509 certificate, given as a PEM-formatted string or cryptography.x509.Certificate
            object, to use for verification. Overrides any X.509 certificate information supplied by the signature. If
            left set to ``None``, requires that the signature supply a valid X.509 certificate chain that validates
            against the known certificate authorities. Implies **require_x509=True**.
        :param cert_subject_name:
            Subject Common Name to check the signing X.509 certificate against. A certificate without a Common Name
            fails this check. Implies **require_x509=True**.
        :param cert_resolver:
            Function to use to resolve trusted X.509 certificates when X509IssuerSerial and X509Digest references are
            found in the signature. The function is called with the keyword arguments ``x509_issuer_name``,
            ``x509_serial_number`` and ``x509_digest``, and is expected to return an iterable of one or more
            strings containing a PEM-formatted certificate and a chain of intermediate certificates, if needed.
            Implies **require_x509=True**.
        :param ca_pem_file:
            Filename of a PEM file containing certificate authority information to use when verifying certificate-based
            signatures.
        :param hmac_key:
            If using HMAC, the shared secret (bytes). When set and ``expect_config`` leaves ``signature_methods`` at its
            default, only HMAC signature methods are accepted; an explicit ``signature_methods`` set that contains any
            non-HMAC method raises :class:`signxml.exceptions.InvalidInput`.
        :param validate_schema: Whether to validate **data** against the XML Signature schema.
        :param parser:
            Custom XML parser instance to use when parsing **data**. The default parser arguments used by SignXML are:
            ``resolve_entities=False``. See https://lxml.de/FAQ.html#how-do-i-use-lxml-safely-as-a-web-service-endpoint.
        :type parser: :class:`lxml.etree.XMLParser` compatible parser
        :param uri_resolver:
            Function to use to resolve reference URIs that are not empty and don't start with "#" (such references are
            only expected in detached signatures; if you don't expect such signatures, leave this unset to prevent them
            from validating). The function is called with a single string argument containing the URI to be resolved,
            and is expected to return a :class:`lxml.etree._Element` node or bytes.
        :param id_attribute:
            Name of the attribute whose value ``URI`` refers to. By default, SignXML will search for "Id", then "ID".
        :param expect_config:
            Expected signature configuration. Pass a :class:`SignatureConfiguration` object to describe expected
            properties of the verified signature. Signatures with unexpected configurations will fail validation.
        :param deprecated_kwargs:
            Direct application of the parameters **require_x509**, **expect_references**, and
            **ignore_ambiguous_key_info** is deprecated. Use **expect_config** instead.

        :raises: :class:`signxml.exceptions.InvalidSignature` if the signature or certificate subject check fails;
            :class:`signxml.exceptions.InvalidInput` for malformed, ambiguous or disallowed input;
            :class:`signxml.exceptions.InvalidCertificate` if ``cert_resolver`` returns no certificates;
            :class:`signxml.exceptions.InvalidDigest` if a reference digest does not match.
        """
        self.hmac_key = hmac_key
        if hmac_key is not None:
            if expect_config.signature_methods == SignatureConfiguration().signature_methods:  # default value
                # A shared secret only makes sense for HMAC, so narrow the default allow-list accordingly.
                expect_config = replace(
                    expect_config,
                    signature_methods=frozenset(sm for sm in SignatureMethod if sm.name.startswith("HMAC")),
                )
            elif any(not sm.name.startswith("HMAC") for sm in expect_config.signature_methods):
                raise InvalidInput("When hmac_key is set, all expected signature methods must use HMAC")

        self.config = expect_config
        if deprecated_kwargs:
            self.config = replace(expect_config, **deprecated_kwargs)
        self._parser = parser

        # Each of these options only has meaning for a certificate-based signature. Forcing the X.509 code path keeps
        # them (in particular the cert_subject_name check) from being silently skipped when the signature carries a
        # bare KeyValue instead of X509Data.
        if x509_cert or cert_resolver or cert_subject_name is not None:
            self.config = replace(self.config, require_x509=True)

        if x509_cert and str(type(x509_cert)) == "<class 'OpenSSL.crypto.X509'>":
            warn(
                "SignXML received a PyOpenSSL object as x509_cert input. Please pass a Cryptography.X509 object instead.",
                DeprecationWarning,
            )
            x509_cert = x509_cert.to_cryptography()  # type: ignore[union-attr]

        self.x509_cert = x509_cert

        if id_attribute is not None:
            self.id_attributes = (id_attribute,)

        root = self.get_root(data)
        signature_ref = self._get_signature(root)

        # We could do a deep copy here, but it wouldn't preserve root level namespaces
        signature = self._fromstring(self._tostring(signature_ref))

        if validate_schema:
            self.validate_schema(signature)

        signed_info = self._find(signature, "SignedInfo")
        c14n_method = self._find(signed_info, "CanonicalizationMethod")
        c14n_algorithm = CanonicalizationMethod(c14n_method.get("Algorithm"))
        inclusive_ns_prefixes = self._get_inclusive_ns_prefixes(c14n_method)
        signature_method = self._find(signed_info, "SignatureMethod")
        signature_value = self._find(signature, "SignatureValue")
        signature_alg = SignatureMethod(signature_method.get("Algorithm"))
        self.check_signature_alg_expected(signature_alg)
        raw_signature = b64decode(signature_value.text)
        x509_data = signature.find("ds:KeyInfo/ds:X509Data", namespaces=namespaces)
        key_value = signature.find("ds:KeyInfo/ds:KeyValue", namespaces=namespaces)
        der_encoded_key_value = signature.find("ds:KeyInfo/dsig11:DEREncodedKeyValue", namespaces=namespaces)
        signed_info_c14n = self._c14n(
            signed_info, algorithm=c14n_algorithm, inclusive_ns_prefixes=inclusive_ns_prefixes
        )

        if x509_data is not None or self.config.require_x509:
            if self.x509_cert is None:
                if x509_data is None:
                    raise InvalidInput("Expected a X.509 certificate based signature")
                certs = [cert.text for cert in self._findall(x509_data, "X509Certificate")]
                if len(certs) == 0:
                    x509_iss = x509_data.find("ds:X509IssuerSerial/ds:X509IssuerName", namespaces=namespaces)
                    x509_sn = x509_data.find("ds:X509IssuerSerial/ds:X509SerialNumber", namespaces=namespaces)
                    x509_digest = x509_data.find("dsig11:X509Digest", namespaces=namespaces)
                    if cert_resolver and any(i is not None for i in (x509_iss, x509_sn, x509_digest)):
                        # The resolver is documented to return an iterable; materialize it so len() and repeated
                        # iteration below work for generators too.
                        cert_chain = list(
                            cert_resolver(
                                x509_issuer_name=x509_iss.text if x509_iss is not None else None,
                                x509_serial_number=x509_sn.text if x509_sn is not None else None,
                                x509_digest=x509_digest.text if x509_digest is not None else None,
                            )
                        )
                        if len(cert_chain) == 0:
                            raise InvalidCertificate("No certificate found for given X509 data")
                        if not all(isinstance(c, x509.Certificate) for c in cert_chain):
                            cert_chain = [x509.load_pem_x509_certificate(add_pem_header(cert)) for cert in cert_chain]
                    else:
                        msg = "Expected to find an X509Certificate element in the signature"
                        msg += " (X509SubjectName, X509SKI are not supported)"
                        raise InvalidInput(msg)
                else:
                    cert_chain = [x509.load_pem_x509_certificate(add_pem_header(cert)) for cert in certs]

                cert_verifier = self.get_cert_chain_verifier(ca_pem_file=ca_pem_file)

                signing_cert = cert_verifier.verify(cert_chain)
            elif isinstance(self.x509_cert, x509.Certificate):
                signing_cert = self.x509_cert
            else:
                signing_cert = x509.load_pem_x509_certificate(add_pem_header(self.x509_cert))

            if cert_subject_name is not None:
                cn_attributes = signing_cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)
                # A certificate with no CN cannot match; report that as a mismatch instead of raising IndexError.
                if not cn_attributes or cn_attributes[0].value != cert_subject_name:
                    raise InvalidSignature("Certificate subject common name mismatch")

            try:
                verified_signed_info_c14n, key_used = self._verify_signature_with_pubkey(
                    signed_info_c14n=signed_info_c14n,
                    raw_signature=raw_signature,
                    signing_certificate=signing_cert,
                    signature_alg=signature_alg,
                )
            except cryptography.exceptions.InvalidSignature as e:
                raise InvalidSignature(f"Signature verification failed: {e}") from e

            self._match_key_values(
                key_value=key_value,
                der_encoded_key_value=der_encoded_key_value,
                signing_cert=signing_cert,
                signature_alg=signature_alg,
            )
        elif signature_alg.name.startswith("HMAC_"):
            if self.hmac_key is None:
                raise InvalidInput('Parameter "hmac_key" is required when verifying a HMAC signature')

            signer = HMAC(key=self.hmac_key, algorithm=digest_algorithm_implementations[signature_alg]())
            signer.update(signed_info_c14n)
            try:
                signer.verify(raw_signature)
                verified_signed_info_c14n = signed_info_c14n
            except cryptography.exceptions.InvalidSignature:
                raise InvalidSignature("Signature mismatch (HMAC)")
            key_used = self.hmac_key
        else:
            if key_value is None and der_encoded_key_value is None:
                raise InvalidInput("Expected to find either KeyValue or X509Data XML element in KeyInfo")

            # Wrap cryptography's exception exactly as the X.509 branch does, so callers only need to catch the
            # documented signxml InvalidSignature.
            try:
                verified_signed_info_c14n, key_used = self._verify_signature_with_pubkey(
                    signed_info_c14n=signed_info_c14n,
                    raw_signature=raw_signature,
                    key_value=key_value,
                    der_encoded_key_value=der_encoded_key_value,
                    signature_alg=signature_alg,
                )
            except cryptography.exceptions.InvalidSignature as e:
                raise InvalidSignature(f"Signature verification failed: {e}") from e

        # References are read from the verified (canonicalized) SignedInfo bytes, not from the original tree.
        verified_signed_info = self._fromstring(verified_signed_info_c14n)
        verify_results: List[VerifyResult] = []
        for idx, reference in enumerate(self._findall(verified_signed_info, "Reference")):
            verify_results.append(
                self._verify_reference(reference, idx, root, uri_resolver, c14n_algorithm, signature, key_used)
            )

        if type(self.config.expect_references) is int and len(verify_results) != self.config.expect_references:
            msg = "Expected to find {} references, but found {}"
            raise InvalidSignature(msg.format(self.config.expect_references, len(verify_results)))

        return verify_results if self.config.expect_references > 1 else verify_results[0]

    def _verify_reference(self, reference, index, root, uri_resolver, c14n_algorithm, signature, signature_key_used):
        """
        Check the digest of one ``ds:Reference`` and build the :class:`VerifyResult` for it.

        :param reference: A ``ds:Reference`` element taken from the verified ``SignedInfo``.
        :param index: Position of the reference, used in the error message.
        :param root: The document root; it is copied before transforms run so the caller's tree is not mutated.
        :param uri_resolver: Passed through to ``_resolve_reference``.
        :param c14n_algorithm: Unused here; kept for interface compatibility.
        :param signature: The signature element stored in the result.
        :param signature_key_used: Raw HMAC key bytes, or a public key object (serialized to PEM SubjectPublicKeyInfo).
        :raises InvalidDigest: if the computed digest differs from ``DigestValue``.
        :raises InvalidInput: if the digest algorithm is not allowed by the configuration.
        """
        # Transforms such as enveloped-signature removal edit the tree, so run them on a private copy.
        root_copy = self._fromstring(self._tostring(root))
        signature_in_copy = self._get_signature(root_copy)
        transforms_node = self._find(reference, "Transforms", require=False)
        digest_alg_uri = self._find(reference, "DigestMethod").get("Algorithm")
        digest_value_node = self._find(reference, "DigestValue")
        resolved = self._resolve_reference(root_copy, reference, uri_resolver=uri_resolver)
        payload_c14n = self._apply_transforms(resolved, transforms_node=transforms_node, signature=signature_in_copy)
        digest_alg = DigestAlgorithm(digest_alg_uri)
        self.check_digest_alg_expected(digest_alg)

        expected_digest = b64decode(digest_value_node.text)
        actual_digest = self._get_digest(payload_c14n, digest_alg)
        if expected_digest != actual_digest:
            raise InvalidDigest(f"Digest mismatch for reference {index} ({reference.get('URI')})")

        # Hand back only what was signed, re-parsed from the canonicalized bytes, so that untrusted comments or
        # fragmented text nodes from the original document never reach the caller.
        try:
            signed_xml = self._fromstring(payload_c14n)
        except etree.XMLSyntaxError:
            signed_xml = None

        signature_key = (
            signature_key_used
            if isinstance(signature_key_used, bytes)
            else signature_key_used.public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo)
        )
        return VerifyResult(payload_c14n, signed_xml, signature, signature_key=signature_key)

    def validate_schema(self, signature):
        """Validate ``signature`` against the schemas returned by ``self.schemas()``.

        Returns as soon as one schema accepts the element; later schemas are not tried. If every schema
        rejects it, the exception raised by the last schema is re-raised. If ``self.schemas()`` yields
        nothing, ``SignXMLException`` is raised.
        """
        failure = None
        for candidate in self.schemas():
            try:
                candidate.assertValid(signature)
            except Exception as exc:
                failure = exc
            else:
                return
        if failure is None:
            raise SignXMLException("Invalid state")
        raise failure

    def _check_key_value_matches_cert_public_key(self, key_value, public_key, signature_alg: SignatureMethod):
        """Return whether a ``KeyValue`` element describes the same public key as ``public_key``.

        Only the child matching the family of ``signature_alg`` is read (``dsig11:ECKeyValue``,
        ``DSAKeyValue`` or ``RSAKeyValue``). A match requires every component of the key to agree:
        curve, x and y for ECDSA; p, q, g and y for DSA; modulus and exponent for RSA.

        :param key_value: ``KeyValue`` XML element from the signature's ``KeyInfo``.
        :param public_key: ``cryptography`` public key to compare against (presumably the certificate's
            key; confirm against the caller).
        :param signature_alg: Signature method; its name prefix selects the key family.
        :raises NotImplementedError: if the algorithm family and the type of ``public_key`` are not one of
            the supported combinations.
        """
        if signature_alg.name.startswith("ECDSA_") and isinstance(public_key, ec.EllipticCurvePublicKey):
            ec_key_value = self._find(key_value, "dsig11:ECKeyValue")
            named_curve = self._find(ec_key_value, "dsig11:NamedCurve")
            pub_key = self._find(ec_key_value, "dsig11:PublicKey")
            # Drop the leading point-format byte (presumably 0x04, uncompressed), then split the rest into X|Y.
            key_data = b64decode(pub_key.text)[1:]
            half = len(key_data) // 2
            x = bytes_to_long(key_data[:half])
            y = bytes_to_long(key_data[half:])
            # NOTE(review): an unknown NamedCurve URI presumably raises KeyError here.
            curve_class = self.known_ecdsa_curves[named_curve.get("URI")]

            pubk_numbers = public_key.public_numbers()
            pubk_curve = pubk_numbers.curve
            # ``curve_class`` comes from the known_ecdsa_curves table, while ``pubk_curve`` is a curve
            # instance; a plain ``==`` between a class and an instance is never True. Compare by name and
            # key size instead, which both curve classes and instances expose.
            return (
                curve_class.name == pubk_curve.name
                and curve_class.key_size == pubk_curve.key_size
                and x == pubk_numbers.x
                and y == pubk_numbers.y
            )

        elif signature_alg.name.startswith("DSA_") and isinstance(public_key, dsa.DSAPublicKey):
            dsa_key_value = self._find(key_value, "DSAKeyValue")
            p = self._get_long(dsa_key_value, "P")
            q = self._get_long(dsa_key_value, "Q")
            g = self._get_long(dsa_key_value, "G", require=False)
            # p, q and g are domain parameters that many keys can share; only y identifies the key itself.
            y = self._get_long(dsa_key_value, "Y")

            pubk_numbers = public_key.public_numbers()
            pubk_params = pubk_numbers.parameter_numbers

            return p == pubk_params.p and q == pubk_params.q and g == pubk_params.g and y == pubk_numbers.y

        elif signature_alg.name.startswith("RSA_") and isinstance(public_key, rsa.RSAPublicKey):
            rsa_key_value = self._find(key_value, "RSAKeyValue")
            n = self._get_long(rsa_key_value, "Modulus")
            e = self._get_long(rsa_key_value, "Exponent")

            pubk_numbers = public_key.public_numbers()

            return n == pubk_numbers.n and e == pubk_numbers.e

        raise NotImplementedError()

    def _check_der_key_value_matches_cert_public_key(self, der_encoded_key_value, public_key, signature_alg):
        """Return whether a base64 DER-encoded public key describes the same key as ``public_key``.

        The element text is base64-decoded and loaded with ``load_der_public_key``. The resulting key is
        then compared, component by component, with ``public_key`` for the family named by
        ``signature_alg``: curve, x and y for ECDSA; p, q, g and y for DSA; n and e for RSA.

        :param der_encoded_key_value: XML element whose text is a base64 DER-encoded public key.
        :param public_key: ``cryptography`` public key to compare against.
        :param signature_alg: Signature method; its name prefix selects the key family.
        :raises NotImplementedError: if the algorithm family and the two key types are not a supported
            combination (including when the two keys are of different types).
        """
        # TODO: Add a test case for this functionality
        der_public_key = load_der_public_key(b64decode(der_encoded_key_value.text))

        if (
            signature_alg.name.startswith("ECDSA_")
            and isinstance(der_public_key, ec.EllipticCurvePublicKey)
            and isinstance(public_key, ec.EllipticCurvePublicKey)
        ):
            der_numbers = der_public_key.public_numbers()
            pubk_numbers = public_key.public_numbers()
            # Compare curves by name and key size (as cryptography's EllipticCurvePublicNumbers equality
            # does) rather than with ``==``. For curve objects ``==`` is an identity check, which fails for
            # two separately constructed instances of the same curve.
            return (
                der_numbers.curve.name == pubk_numbers.curve.name
                and der_numbers.curve.key_size == pubk_numbers.curve.key_size
                and der_numbers.x == pubk_numbers.x
                and der_numbers.y == pubk_numbers.y
            )

        elif (
            signature_alg.name.startswith("DSA_")
            and isinstance(der_public_key, dsa.DSAPublicKey)
            and isinstance(public_key, dsa.DSAPublicKey)
        ):
            der_numbers = der_public_key.public_numbers()
            pubk_numbers = public_key.public_numbers()
            der_params = der_numbers.parameter_numbers
            pubk_params = pubk_numbers.parameter_numbers
            # p, q and g are shared domain parameters; y is what identifies the individual key.
            return (
                der_params.p == pubk_params.p
                and der_params.q == pubk_params.q
                and der_params.g == pubk_params.g
                and der_numbers.y == pubk_numbers.y
            )

        elif (
            signature_alg.name.startswith("RSA_")
            and isinstance(der_public_key, rsa.RSAPublicKey)
            and isinstance(public_key, rsa.RSAPublicKey)
        ):
            der_numbers = der_public_key.public_numbers()
            pubk_numbers = public_key.public_numbers()
            return der_numbers.n == pubk_numbers.n and der_numbers.e == pubk_numbers.e

        raise NotImplementedError()

    def _get_long(self, element, query, require=True):
        result = self._find(element, query, require=require)
        if result is not None:
            result = bytes_to_long(b64decode(result.text))
        return result