File: rsa_key.py

package info (click to toggle)
python-authlib 1.6.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,016 kB
  • sloc: python: 26,998; makefile: 53; sh: 14
file content (127 lines) | stat: -rw-r--r-- 4,581 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKeyWithSerialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmp1
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmq1
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_iqmp
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_recover_prime_factors

from authlib.common.encoding import base64_to_int
from authlib.common.encoding import int_to_base64

from ..rfc7517 import AsymmetricKey


class RSAKey(AsymmetricKey):
    """Key class of the ``RSA`` key type.

    Converts between the JWK dictionary form of an RSA key (members ``n`` and
    ``e`` for public keys; additionally ``d``, ``p``, ``q``, ``dp``, ``dq`` and
    ``qi`` for private keys) and ``cryptography`` RSA key objects.
    """

    kty = "RSA"
    PUBLIC_KEY_CLS = RSAPublicKey
    PRIVATE_KEY_CLS = RSAPrivateKeyWithSerialization

    # JWK members emitted/accepted for each half of the key pair.
    PUBLIC_KEY_FIELDS = ["e", "n"]
    PRIVATE_KEY_FIELDS = ["d", "dp", "dq", "e", "n", "p", "q", "qi"]
    # Members that must be present in any imported dict.
    REQUIRED_JSON_FIELDS = ["e", "n"]
    SSH_PUBLIC_PREFIX = b"ssh-rsa"

    def dumps_private_key(self):
        """Serialize ``self.private_key`` into JWK private-key members.

        Every integer is encoded with ``int_to_base64``.

        :return: dict with the ``n``, ``e``, ``d``, ``p``, ``q``, ``dp``,
            ``dq`` and ``qi`` members.
        """
        numbers = self.private_key.private_numbers()
        return {
            "n": int_to_base64(numbers.public_numbers.n),
            "e": int_to_base64(numbers.public_numbers.e),
            "d": int_to_base64(numbers.d),
            "p": int_to_base64(numbers.p),
            "q": int_to_base64(numbers.q),
            "dp": int_to_base64(numbers.dmp1),
            "dq": int_to_base64(numbers.dmq1),
            "qi": int_to_base64(numbers.iqmp),
        }

    def dumps_public_key(self):
        """Serialize ``self.public_key`` into JWK public-key members.

        :return: dict with the ``n`` and ``e`` members, each encoded with
            ``int_to_base64``.
        """
        numbers = self.public_key.public_numbers()
        return {"n": int_to_base64(numbers.n), "e": int_to_base64(numbers.e)}

    def load_private_key(self):
        """Build a ``cryptography`` RSA private key from ``self._dict_data``.

        If all CRT members (``p``, ``q``, ``dp``, ``dq``, ``qi``) are present
        they are used as-is. Otherwise the primes are recovered from ``n``,
        ``e`` and ``d``, and the CRT values are computed from them.

        :return: the private key object created by
            ``RSAPrivateNumbers.private_key``.
        :raises ValueError: if the multi-prime ``oth`` member is present, or
            (via ``has_all_prime_factors``) if only some CRT members are
            present.
        """
        obj = self._dict_data

        if "oth" in obj:  # pragma: no cover
            # https://tools.ietf.org/html/rfc7518#section-6.3.2.7
            raise ValueError('"oth" is not supported yet')

        public_numbers = RSAPublicNumbers(
            base64_to_int(obj["e"]), base64_to_int(obj["n"])
        )

        if has_all_prime_factors(obj):
            numbers = RSAPrivateNumbers(
                d=base64_to_int(obj["d"]),
                p=base64_to_int(obj["p"]),
                q=base64_to_int(obj["q"]),
                dmp1=base64_to_int(obj["dp"]),
                dmq1=base64_to_int(obj["dq"]),
                iqmp=base64_to_int(obj["qi"]),
                public_numbers=public_numbers,
            )
        else:
            d = base64_to_int(obj["d"])
            # The documented signature is rsa_recover_prime_factors(n, e, d).
            # Pass the exponents in that order rather than relying on the
            # current implementation only using their product.
            p, q = rsa_recover_prime_factors(public_numbers.n, public_numbers.e, d)
            numbers = RSAPrivateNumbers(
                d=d,
                p=p,
                q=q,
                dmp1=rsa_crt_dmp1(d, p),
                dmq1=rsa_crt_dmq1(d, q),
                iqmp=rsa_crt_iqmp(p, q),
                public_numbers=public_numbers,
            )

        return numbers.private_key(default_backend())

    def load_public_key(self):
        """Build a ``cryptography`` RSA public key from the ``e``/``n`` members
        of ``self._dict_data``.

        :return: the public key object created by
            ``RSAPublicNumbers.public_key``.
        """
        numbers = RSAPublicNumbers(
            base64_to_int(self._dict_data["e"]), base64_to_int(self._dict_data["n"])
        )
        return numbers.public_key(default_backend())

    @classmethod
    def generate_key(cls, key_size=2048, options=None, is_private=False) -> "RSAKey":
        """Generate a new RSA key with public exponent 65537.

        :param key_size: modulus size in bits. Must be at least 512 and a
            multiple of 8.
        :param options: forwarded unchanged to ``import_key``.
        :param is_private: when false, only the public half of the generated
            key is imported.
        :return: an ``RSAKey`` built via ``cls.import_key``.
        :raises ValueError: if ``key_size`` fails either check above.
        """
        if key_size < 512:
            raise ValueError("key_size must not be less than 512")
        if key_size % 8 != 0:
            raise ValueError("Invalid key_size for RSAKey")
        raw_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )
        if not is_private:
            raw_key = raw_key.public_key()
        return cls.import_key(raw_key, options=options)

    @classmethod
    def import_dict_key(cls, raw, options=None):
        """Create an ``RSAKey`` from a JWK dict.

        :param raw: JWK members. ``REQUIRED_JSON_FIELDS`` are checked with
            ``check_required_fields``.
        :param options: passed to the constructor.
        :return: the new key instance, with ``raw`` stored as its dict data.
        :raises ValueError: (via ``has_all_prime_factors``) if ``d`` is
            present together with only some of the CRT members.
        """
        cls.check_required_fields(raw)
        key = cls(options=options)
        key._dict_data = raw
        if "d" in raw and not has_all_prime_factors(raw):
            # A private key given only as n/e/d: rebuild the key object and
            # regenerate the dict from it. NOTE(review): presumably this fills
            # in p/q/dp/dq/qi via load_private_key + dumps_private_key; the
            # base-class load_raw_key/load_dict_key are not visible here.
            key.load_raw_key()
            key.load_dict_key()
        return key


def has_all_prime_factors(obj):
    """Report whether a JWK dict carries the full set of RSA CRT members.

    The members checked are ``p``, ``q``, ``dp``, ``dq`` and ``qi``.

    :param obj: mapping of JWK member names to values.
    :return: ``True`` when all five members are present, ``False`` when none
        of them are.
    :raises ValueError: when only some of the five members are present.
    """
    crt_members = ("p", "q", "dp", "dq", "qi")
    present_count = sum(1 for member in crt_members if member in obj)

    if present_count == len(crt_members):
        return True

    # A partial set cannot be used for CRT and is rejected outright.
    if present_count:
        raise ValueError(
            "RSA key must include all parameters if any are present besides d"
        )

    return False