1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
import base64
import os
import pytest
from kombu.serialization import registry
from kombu.utils.encoding import bytes_to_str
from celery.exceptions import SecurityError
from celery.security.certificate import Certificate, CertStore
from celery.security.key import PrivateKey
from celery.security.serialization import DEFAULT_SEPARATOR, SecureSerializer, register_auth
from . import CERT1, CERT2, KEY1, KEY2
from .case import SecurityCase
class test_secureserializer(SecurityCase):
    """Round-trip and rejection tests for ``SecureSerializer``."""

    def _get_s(self, key, cert, certs, serializer="json"):
        """Build a ``SecureSerializer`` for the given key, cert and store.

        :param key: key material handed to ``PrivateKey``.
        :param cert: certificate material handed to ``Certificate``.
        :param certs: iterable of certificate material; each entry is
            wrapped in ``Certificate`` and added to a fresh ``CertStore``.
        :param serializer: name forwarded as the ``serializer`` keyword.
        :returns: the configured ``SecureSerializer`` instance.
        """
        cert_store = CertStore()
        for entry in certs:
            cert_store.add_cert(Certificate(entry))
        private_key = PrivateKey(key)
        own_cert = Certificate(cert)
        return SecureSerializer(
            private_key, own_cert, cert_store, serializer=serializer
        )

    @pytest.mark.parametrize(
        "data", [1, "foo", b"foo", {"foo": 1}, {"foo": DEFAULT_SEPARATOR}]
    )
    @pytest.mark.parametrize("serializer", ["json", "pickle"])
    def test_serialize(self, data, serializer):
        """Serialize then deserialize each sample and expect it unchanged.

        The last sample carries ``DEFAULT_SEPARATOR`` as a value so the
        round trip is also exercised with the separator inside the data.
        """
        secure = self._get_s(KEY1, CERT1, [CERT1], serializer=serializer)
        payload = secure.serialize(data)
        assert secure.deserialize(payload) == data

    def test_deserialize(self):
        """Input that was not produced by ``serialize`` must be rejected."""
        secure = self._get_s(KEY1, CERT1, [CERT1])
        with pytest.raises(SecurityError):
            secure.deserialize('bad data')

    def test_unmatched_key_cert(self):
        """A serializer built from KEY1 but CERT2 must fail the round trip."""
        mismatched = self._get_s(KEY1, CERT2, [CERT1, CERT2])
        # Both calls stay inside the context manager, so a SecurityError
        # from either step satisfies the expectation.
        with pytest.raises(SecurityError):
            payload = mismatched.serialize('foo')
            mismatched.deserialize(payload)

    def test_unknown_source(self):
        """The round trip fails when the store does not contain CERT1."""
        # Build both serializers before running any check: one store holds
        # only CERT2, the other is empty.
        serializers = [
            self._get_s(KEY1, CERT1, store_certs)
            for store_certs in ([CERT2], [])
        ]
        for secure in serializers:
            with pytest.raises(SecurityError):
                payload = secure.serialize('foo')
                secure.deserialize(payload)

    def test_self_send(self):
        """Two identically configured serializers can read each other."""
        sender = self._get_s(KEY1, CERT1, [CERT1])
        receiver = self._get_s(KEY1, CERT1, [CERT1])
        payload = sender.serialize('foo')
        assert receiver.deserialize(payload) == 'foo'

    def test_separate_ends(self):
        """Each end uses its own key/cert and stores only the peer's cert."""
        sender = self._get_s(KEY1, CERT1, [CERT2])
        receiver = self._get_s(KEY2, CERT2, [CERT1])
        payload = sender.serialize('foo')
        assert receiver.deserialize(payload) == 'foo'

    def test_register_auth(self):
        """``register_auth`` leaves an 'application/data' decoder registered."""
        register_auth(KEY1, None, CERT1, '')
        decoders = registry._decoders
        assert 'application/data' in decoders

    def test_lots_of_sign(self):
        """Round-trip 1000 random strings, each with a new serializer."""
        for _ in range(1000):
            # 265 random bytes, URL-safe base64 encoded and turned into str.
            raw = os.urandom(265)
            rdata = bytes_to_str(base64.urlsafe_b64encode(raw))
            secure = self._get_s(KEY1, CERT1, [CERT1])
            payload = secure.serialize(rdata)
            assert secure.deserialize(payload) == rdata
|