File: test_security.py

package info (click to toggle)
celery 5.6.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 8,336 kB
  • sloc: python: 67,264; sh: 795; makefile: 378
file content (118 lines) | stat: -rw-r--r-- 3,833 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import datetime
import os
import socket
import tempfile

import pytest
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from .tasks import add


class test_security:
    """Integration tests for Celery's message-signing ("auth") serializer.

    A throwaway RSA key and self-signed certificate are generated once for
    the class and written to a temporary directory; before every test the
    app is configured to sign task and event messages with them.
    """

    @pytest.fixture(autouse=True, scope='class')
    def class_certs(self, request):
        """Create the key/certificate pair shared by every test in the class.

        Because this fixture is class-scoped, the file locations are
        published on ``request.cls`` (as ``tmpdir``, ``key_name`` and
        ``cert_name``) so the per-test fixture and the tests can read them
        through ``self``.  The temporary directory and its contents are
        removed after the last test of the class, including when key or
        certificate generation fails part way through.
        """
        key_name = 'worker.key'
        cert_name = 'worker.pem'

        with tempfile.TemporaryDirectory() as tmpdir:
            private_key = self.gen_private_key()
            certificate = self.gen_certificate(
                key=private_key,
                common_name='celery security integration',
            )

            pem_key = private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.TraditionalOpenSSL,
                encryption_algorithm=serialization.NoEncryption(),
            )
            pem_cert = certificate.public_bytes(
                encoding=serialization.Encoding.PEM,
            )

            # Distinct handle names so the crypto objects are not shadowed.
            with open(os.path.join(tmpdir, key_name), 'wb') as key_file:
                key_file.write(pem_key)
            with open(os.path.join(tmpdir, cert_name), 'wb') as cert_file:
                cert_file.write(pem_cert)

            request.cls.tmpdir = tmpdir
            request.cls.key_name = key_name
            request.cls.cert_name = cert_name

            # Tests run here; leaving the ``with`` afterwards deletes tmpdir.
            yield

    @pytest.fixture(autouse=True)
    def _prepare_setup(self, manager):
        """Configure ``manager.app`` to sign messages with the generated pair.

        Tasks and events use the ``auth`` serializer, only ``auth`` content
        is accepted, results are accepted as ``json``, and
        ``setup_security()`` is called to apply the settings.
        """
        manager.app.conf.update(
            security_key=os.path.join(self.tmpdir, self.key_name),
            security_certificate=os.path.join(self.tmpdir, self.cert_name),
            security_cert_store=os.path.join(self.tmpdir, '*.pem'),
            task_serializer='auth',
            event_serializer='auth',
            accept_content=['auth'],
            result_accept_content=['json']
        )

        manager.app.setup_security()

    def gen_private_key(self):
        """Return a freshly generated 2048-bit RSA private key (e = 65537)."""
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend(),
        )

    def gen_certificate(self, key, common_name, issuer=None, sign_key=None):
        """Build a CA certificate for *key*, valid from now for one day.

        :param key: private key whose public half the certificate carries.
        :param common_name: subject common name.
        :param issuer: issuer common name; defaults to *common_name*.
        :param sign_key: key used to sign; defaults to *key* (self-signed).
        :returns: the signed ``x509.Certificate`` (SHA-256 signature,
            ``BasicConstraints(ca=True, path_length=0)`` marked critical).
        """
        now = datetime.datetime.now(datetime.timezone.utc)

        certificate = x509.CertificateBuilder().subject_name(
            x509.Name([
                x509.NameAttribute(NameOID.COMMON_NAME, common_name),
            ])
        ).issuer_name(
            x509.Name([
                x509.NameAttribute(
                    NameOID.COMMON_NAME,
                    issuer or common_name
                )
            ])
        ).not_valid_before(
            now
        ).not_valid_after(
            now + datetime.timedelta(days=1)
        ).serial_number(
            x509.random_serial_number()
        ).public_key(
            key.public_key()
        ).add_extension(
            x509.BasicConstraints(ca=True, path_length=0), critical=True
        ).sign(
            private_key=sign_key or key,
            algorithm=hashes.SHA256(),
            backend=default_backend()
        )
        return certificate

    @pytest.mark.xfail(reason="Issue #5269")
    def test_security_task_done(self):
        """A signed ``add`` task is executed and its result (2) retrieved."""
        # AsyncResult.get() signals a timeout with celery.exceptions.TimeoutError,
        # which does not derive from the builtin TimeoutError/socket.timeout;
        # without catching it the diagnostic below could never be reached.
        from celery.exceptions import TimeoutError as CeleryTimeoutError

        t1 = add.apply_async((1, 1))
        try:
            result = t1.get(timeout=10)  # redis backend will timeout
        except (CeleryTimeoutError, socket.timeout, TimeoutError) as e:
            pytest.fail(
                f"Timed out waiting for task result. Task was likely dropped by "
                f"worker due to security misconfig. Exception details: {e}"
            )
        assert result == 2