File: test_security.py

package info (click to toggle)
celery 5.5.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 8,008 kB
  • sloc: python: 64,346; sh: 795; makefile: 378
file content (110 lines) | stat: -rw-r--r-- 3,477 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import datetime
import os
import tempfile

import pytest
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from .tasks import add


class test_security:
    """Integration test for Celery's ``auth`` (message-signing) serializer.

    A throw-away RSA key and a self-signed certificate are generated once per
    test class and written to a private temporary directory. Every test then
    configures the app to sign task and event messages with them.
    """

    @pytest.fixture(autouse=True, scope='class')
    def class_certs(self, request):
        """Create the signing key and certificate files for the whole class.

        Publishes ``tmpdir``, ``key_name`` and ``cert_name`` as attributes of
        the test class (``request.cls``) so the per-test fixture can build the
        file paths. The directory and everything in it is removed on teardown.
        """
        key_name = 'worker.key'
        cert_name = 'worker.pem'

        private_key = self.gen_private_key()
        certificate = self.gen_certificate(
            key=private_key,
            common_name='celery security integration',
        )

        pem_key = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        pem_cert = certificate.public_bytes(
            encoding=serialization.Encoding.PEM,
        )

        # TemporaryDirectory guarantees cleanup even if setup fails half-way
        # or extra files appear in the directory. The previous
        # mkdtemp + os.remove/os.rmdir sequence leaked the directory on
        # setup errors.
        with tempfile.TemporaryDirectory() as tmpdir:
            with open(os.path.join(tmpdir, key_name), 'wb') as key_file:
                key_file.write(pem_key)
            with open(os.path.join(tmpdir, cert_name), 'wb') as cert_file:
                cert_file.write(pem_cert)

            # A class-scoped fixture runs on its own instance of the class,
            # so the values must be stored on the class itself for the
            # per-test fixture (running on other instances) to see them.
            request.cls.tmpdir = tmpdir
            request.cls.key_name = key_name
            request.cls.cert_name = cert_name

            yield

    @pytest.fixture(autouse=True)
    def _prepare_setup(self, manager):
        """Point the app at the generated key/cert and enable ``auth``."""
        manager.app.conf.update(
            security_key=os.path.join(self.tmpdir, self.key_name),
            security_certificate=os.path.join(self.tmpdir, self.cert_name),
            security_cert_store=os.path.join(self.tmpdir, '*.pem'),
            task_serializer='auth',
            event_serializer='auth',
            accept_content=['auth'],
            result_accept_content=['json']
        )

        manager.app.setup_security()

    def gen_private_key(self):
        """Generate a 2048-bit RSA private key with ``cryptography``."""
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend(),
        )

    def gen_certificate(self, key, common_name, issuer=None, sign_key=None):
        """Generate an X.509 certificate valid for one day from now.

        :param key: private key whose public half is embedded in the cert.
        :param common_name: subject CN.
        :param issuer: issuer CN; defaults to ``common_name`` (self-issued).
        :param sign_key: key used to sign; defaults to ``key`` (self-signed).
        :returns: the signed ``x509.Certificate``.
        """
        now = datetime.datetime.now(datetime.timezone.utc)

        certificate = x509.CertificateBuilder().subject_name(
            x509.Name([
                x509.NameAttribute(NameOID.COMMON_NAME, common_name),
            ])
        ).issuer_name(
            x509.Name([
                x509.NameAttribute(
                    NameOID.COMMON_NAME,
                    issuer or common_name
                )
            ])
        ).not_valid_before(
            now
        ).not_valid_after(
            now + datetime.timedelta(seconds=86400)
        ).serial_number(
            x509.random_serial_number()
        ).public_key(
            key.public_key()
        ).add_extension(
            x509.BasicConstraints(ca=True, path_length=0), critical=True
        ).sign(
            private_key=sign_key or key,
            algorithm=hashes.SHA256(),
            backend=default_backend()
        )
        return certificate

    @pytest.mark.xfail(reason="Issue #5269")
    def test_security_task_done(self):
        t1 = add.delay(1, 1)
        assert t1.get() == 2