1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
|
import sys
from io import BytesIO
from unittest import skipIf
import xmlsec
from tests import base
from xmlsec import constants as consts
class TestBase64LineSize(base.TestMemoryLeaks):
def tearDown(self):
xmlsec.base64_default_line_size(64)
super(TestBase64LineSize, self).tearDown()
def test_get_base64_default_line_size(self):
self.assertEqual(xmlsec.base64_default_line_size(), 64)
def test_set_base64_default_line_size_positional_arg(self):
xmlsec.base64_default_line_size(0)
self.assertEqual(xmlsec.base64_default_line_size(), 0)
def test_set_base64_default_line_size_keyword_arg(self):
xmlsec.base64_default_line_size(size=0)
self.assertEqual(xmlsec.base64_default_line_size(), 0)
def test_set_base64_default_line_size_with_bad_args(self):
size = xmlsec.base64_default_line_size()
for bad_size in (None, '', object()):
with self.assertRaises(TypeError):
xmlsec.base64_default_line_size(bad_size)
self.assertEqual(xmlsec.base64_default_line_size(), size)
def test_set_base64_default_line_size_rejects_negative_values(self):
size = xmlsec.base64_default_line_size()
with self.assertRaises(ValueError):
xmlsec.base64_default_line_size(-1)
self.assertEqual(xmlsec.base64_default_line_size(), size)
class TestCallbacks(base.TestMemoryLeaks):
def setUp(self):
super().setUp()
xmlsec.cleanup_callbacks()
def _sign_doc(self):
root = self.load_xml("doc.xml")
sign = xmlsec.template.create(root, c14n_method=consts.TransformExclC14N, sign_method=consts.TransformRsaSha1)
xmlsec.template.add_reference(sign, consts.TransformSha1, uri="cid:123456")
ctx = xmlsec.SignatureContext()
ctx.key = xmlsec.Key.from_file(self.path("rsakey.pem"), format=consts.KeyDataFormatPem)
ctx.sign(sign)
return sign
def _expect_sign_failure(self):
with self.assertRaisesRegex(xmlsec.Error, 'failed to sign'):
self._sign_doc()
def _mismatch_callbacks(self, match_cb=lambda filename: False):
return [
match_cb,
lambda filename: None,
lambda none, buf: 0,
lambda none: None,
]
def _register_mismatch_callbacks(self, match_cb=lambda filename: False):
xmlsec.register_callbacks(*self._mismatch_callbacks(match_cb))
def _register_match_callbacks(self):
xmlsec.register_callbacks(
lambda filename: filename == b'cid:123456',
lambda filename: BytesIO(b'<html><head/><body/></html>'),
lambda bio, buf: bio.readinto(buf),
lambda bio: bio.close(),
)
def _find(self, elem, *tags):
try:
return elem.xpath(
'./' + '/'.join('xmldsig:{}'.format(tag) for tag in tags),
namespaces={
'xmldsig': 'http://www.w3.org/2000/09/xmldsig#',
},
)[0]
except IndexError as e:
raise KeyError(tags) from e
def _verify_external_data_signature(self):
signature = self._sign_doc()
digest = self._find(signature, 'SignedInfo', 'Reference', 'DigestValue').text
self.assertEqual(digest, 'VihZwVMGJ48NsNl7ertVHiURXk8=')
def test_sign_external_data_no_callbacks_fails(self):
self._expect_sign_failure()
def test_sign_external_data_default_callbacks_fails(self):
xmlsec.register_default_callbacks()
self._expect_sign_failure()
def test_sign_external_data_no_matching_callbacks_fails(self):
self._register_mismatch_callbacks()
self._expect_sign_failure()
def test_sign_data_from_callbacks(self):
self._register_match_callbacks()
self._verify_external_data_signature()
def test_sign_data_not_first_callback(self):
bad_match_calls = 0
def match_cb(filename):
nonlocal bad_match_calls
bad_match_calls += 1
return False
for _ in range(2):
self._register_mismatch_callbacks(match_cb)
self._register_match_callbacks()
for _ in range(2):
self._register_mismatch_callbacks()
self._verify_external_data_signature()
self.assertEqual(bad_match_calls, 0)
@skipIf(sys.platform == "win32", "unclear behaviour on windows")
def test_failed_sign_because_default_callbacks(self):
mismatch_calls = 0
def mismatch_cb(filename):
nonlocal mismatch_calls
mismatch_calls += 1
return False
# NB: These first two sets of callbacks should never get called,
# because the default callbacks always match beforehand:
self._register_match_callbacks()
self._register_mismatch_callbacks(mismatch_cb)
xmlsec.register_default_callbacks()
self._register_mismatch_callbacks(mismatch_cb)
self._register_mismatch_callbacks(mismatch_cb)
self._expect_sign_failure()
self.assertEqual(mismatch_calls, 2)
def test_register_non_callables(self):
for idx in range(4):
cbs = self._mismatch_callbacks()
cbs[idx] = None
self.assertRaises(TypeError, xmlsec.register_callbacks, *cbs)
def test_sign_external_data_fails_on_read_callback_wrong_returns(self):
xmlsec.register_callbacks(
lambda filename: filename == b'cid:123456',
lambda filename: BytesIO(b'<html><head/><body/></html>'),
lambda bio, buf: None,
lambda bio: bio.close(),
)
self._expect_sign_failure()
|