File: jws_algs.py

package info (click to toggle)
python-authlib 1.6.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,016 kB
  • sloc: python: 26,998; makefile: 53; sh: 14
file content (221 lines) | stat: -rw-r--r-- 6,573 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""authlib.jose.rfc7518.
~~~~~~~~~~~~~~~~~~~~

"alg" (Algorithm) Header Parameter Values for JWS per `Section 3`_.

.. _`Section 3`: https://tools.ietf.org/html/rfc7518#section-3
"""

import hashlib
import hmac

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature

from ..rfc7515 import JWSAlgorithm
from .ec_key import ECKey
from .oct_key import OctKey
from .rsa_key import RSAKey
from .util import decode_int
from .util import encode_int


class NoneAlgorithm(JWSAlgorithm):
    """The "none" JWS algorithm: no digital signature or MAC is performed."""

    name = "none"
    description = "No digital signature or MAC performed"

    def prepare_key(self, raw_data):
        """Return ``None`` regardless of ``raw_data``; no key is used."""
        return None

    def sign(self, msg, key):
        """Return the empty byte string; ``msg`` and ``key`` are ignored."""
        signature = b""
        return signature

    def verify(self, msg, sig, key):
        """Return whether ``sig`` equals the empty byte string.

        ``msg`` and ``key`` are not consulted.
        """
        expected = b""
        return sig == expected


class HMACAlgorithm(JWSAlgorithm):
    """HMAC with SHA-2 digests for JWS. Available algorithms:

    - HS256: HMAC using SHA-256
    - HS384: HMAC using SHA-384
    - HS512: HMAC using SHA-512
    """

    # Digest constructors; ``__init__`` picks one by attribute name.
    SHA256 = hashlib.sha256
    SHA384 = hashlib.sha384
    SHA512 = hashlib.sha512

    def __init__(self, sha_type):
        """Set ``name``, ``description`` and ``hash_alg`` from ``sha_type``.

        :param sha_type: digest size used to build the names, e.g. ``256``;
            it must match one of the ``SHA<size>`` class attributes,
            otherwise the attribute lookup raises ``AttributeError``.
        """
        digest_attr = f"SHA{sha_type}"
        self.name = f"HS{sha_type}"
        self.description = f"HMAC using SHA-{sha_type}"
        self.hash_alg = getattr(self, digest_attr)

    def prepare_key(self, raw_data):
        """Import ``raw_data`` through ``OctKey.import_key`` and return it."""
        return OctKey.import_key(raw_data)

    def sign(self, msg, key):
        """Return the HMAC digest of ``msg`` computed with ``key``.

        :param msg: data to authenticate
        :param key: key object exposing ``get_op_key`` (presumably the
            result of :meth:`prepare_key` -- verify against callers)
        """
        # The stdlib ``hmac`` module is used on purpose: it is faster than
        # the HMAC implementation in ``cryptography``.
        secret = key.get_op_key("sign")
        mac = hmac.new(secret, msg, self.hash_alg)
        return mac.digest()

    def verify(self, msg, sig, key):
        """Return whether ``sig`` matches the HMAC digest of ``msg``.

        The digest is recomputed with the key's "verify" operation key and
        compared to ``sig`` using ``hmac.compare_digest``.
        """
        secret = key.get_op_key("verify")
        expected = hmac.new(secret, msg, self.hash_alg).digest()
        return hmac.compare_digest(sig, expected)


class RSAAlgorithm(JWSAlgorithm):
    """RSASSA-PKCS1-v1_5 with SHA-2 digests for JWS. Available algorithms:

    - RS256: RSASSA-PKCS1-v1_5 using SHA-256
    - RS384: RSASSA-PKCS1-v1_5 using SHA-384
    - RS512: RSASSA-PKCS1-v1_5 using SHA-512
    """

    # Hash classes; ``__init__`` picks one by attribute name.
    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    def __init__(self, sha_type):
        """Set ``name``, ``description``, ``hash_alg`` and ``padding``.

        :param sha_type: digest size used to build the names, e.g. ``256``;
            it must match one of the ``SHA<size>`` class attributes.
        """
        digest_attr = f"SHA{sha_type}"
        self.name = f"RS{sha_type}"
        self.description = f"RSASSA-PKCS1-v1_5 using SHA-{sha_type}"
        self.hash_alg = getattr(self, digest_attr)
        # One PKCS1v15 padding object is created here and reused per call.
        self.padding = padding.PKCS1v15()

    def prepare_key(self, raw_data):
        """Import ``raw_data`` through ``RSAKey.import_key`` and return it."""
        return RSAKey.import_key(raw_data)

    def sign(self, msg, key):
        """Sign ``msg`` with the key's "sign" operation key.

        A fresh hash object is instantiated for every call.
        """
        signer = key.get_op_key("sign")
        digest = self.hash_alg()
        return signer.sign(msg, self.padding, digest)

    def verify(self, msg, sig, key):
        """Return ``True`` if ``sig`` verifies for ``msg``, else ``False``.

        ``InvalidSignature`` raised by the underlying ``verify`` call is
        translated into ``False``; any other exception propagates.
        """
        verifier = key.get_op_key("verify")
        try:
            verifier.verify(sig, msg, self.padding, self.hash_alg())
        except InvalidSignature:
            return False
        else:
            return True


class ECAlgorithm(JWSAlgorithm):
    """ECDSA with SHA-2 digests for JWS. Available algorithms:

    - ES256: ECDSA using P-256 and SHA-256
    - ES384: ECDSA using P-384 and SHA-384
    - ES512: ECDSA using P-521 and SHA-512
    """

    # Hash classes; ``__init__`` picks one by attribute name.
    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    def __init__(self, name, curve, sha_type):
        """Store the algorithm name, its required curve and its hash class.

        :param name: ``alg`` value, e.g. ``"ES256"``
        :param curve: curve name that keys must declare in ``"crv"``
        :param sha_type: digest size, matching a ``SHA<size>`` attribute
        """
        digest_attr = f"SHA{sha_type}"
        self.name = name
        self.curve = curve
        self.description = f"ECDSA using {self.curve} and SHA-{sha_type}"
        self.hash_alg = getattr(self, digest_attr)

    def prepare_key(self, raw_data):
        """Import ``raw_data`` as an ``ECKey`` and check its curve.

        :raises ValueError: when the key's ``"crv"`` differs from the
            curve this algorithm was configured with.
        """
        ec_key = ECKey.import_key(raw_data)
        key_curve = ec_key["crv"]
        if key_curve != self.curve:
            message = (
                f'Key for "{self.name}" not supported, '
                f'only "{self.curve}" allowed'
            )
            raise ValueError(message)
        return ec_key

    def sign(self, msg, key):
        """Sign ``msg`` and return the encoded ``r`` followed by encoded ``s``.

        The signature returned by the backend is split into its ``(r, s)``
        pair with ``decode_dss_signature``; each integer is then encoded
        with ``encode_int`` using ``key.curve_key_size`` and the two results
        are concatenated.
        """
        signer = key.get_op_key("sign")
        der_signature = signer.sign(msg, ECDSA(self.hash_alg()))
        r_value, s_value = decode_dss_signature(der_signature)
        # presumably the curve size in bits -- confirm against ECKey
        bits = key.curve_key_size
        r_bytes = encode_int(r_value, bits)
        s_bytes = encode_int(s_value, bits)
        return r_bytes + s_bytes

    def verify(self, msg, sig, key):
        """Return ``True`` if ``sig`` verifies for ``msg``, else ``False``.

        ``sig`` must be exactly twice ``ceil(curve_key_size / 8)`` long;
        any other length is rejected without calling the backend. The two
        halves are decoded as ``r`` and ``s`` and re-encoded with
        ``encode_dss_signature`` before verification.
        """
        bits = key.curve_key_size
        component_len = (bits + 7) // 8  # round up to whole bytes

        if len(sig) != 2 * component_len:
            return False

        r_bytes, s_bytes = sig[:component_len], sig[component_len:]
        r_value = decode_int(r_bytes)
        s_value = decode_int(s_bytes)
        der_signature = encode_dss_signature(r_value, s_value)

        try:
            verifier = key.get_op_key("verify")
            verifier.verify(der_signature, msg, ECDSA(self.hash_alg()))
        except InvalidSignature:
            return False
        else:
            return True


class RSAPSSAlgorithm(JWSAlgorithm):
    """RSASSA-PSS with SHA-2 digests for JWS. Available algorithms:

    - PS256: RSASSA-PSS using SHA-256 and MGF1 with SHA-256
    - PS384: RSASSA-PSS using SHA-384 and MGF1 with SHA-384
    - PS512: RSASSA-PSS using SHA-512 and MGF1 with SHA-512
    """

    # Hash classes; ``__init__`` picks one by attribute name.
    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    def __init__(self, sha_type):
        """Set ``name``, ``description`` and ``hash_alg`` from ``sha_type``.

        :param sha_type: digest size used to build the names, e.g. ``256``;
            it must match one of the ``SHA<size>`` class attributes.
        """
        digest_attr = f"SHA{sha_type}"
        self.name = f"PS{sha_type}"
        self.description = "RSASSA-PSS using SHA-{} and MGF1 with SHA-{}".format(
            sha_type, sha_type
        )
        self.hash_alg = getattr(self, digest_attr)

    def _pss_padding(self):
        """Build the PSS padding shared by :meth:`sign` and :meth:`verify`.

        MGF1 uses a new instance of the configured hash, and the salt
        length is that hash's ``digest_size``.
        """
        mask_function = padding.MGF1(self.hash_alg())
        return padding.PSS(
            mgf=mask_function,
            salt_length=self.hash_alg.digest_size,
        )

    def prepare_key(self, raw_data):
        """Import ``raw_data`` through ``RSAKey.import_key`` and return it."""
        return RSAKey.import_key(raw_data)

    def sign(self, msg, key):
        """Sign ``msg`` with the key's "sign" operation key using PSS."""
        signer = key.get_op_key("sign")
        return signer.sign(msg, self._pss_padding(), self.hash_alg())

    def verify(self, msg, sig, key):
        """Return ``True`` if ``sig`` verifies for ``msg``, else ``False``.

        ``InvalidSignature`` raised by the underlying ``verify`` call is
        translated into ``False``; any other exception propagates.
        """
        verifier = key.get_op_key("verify")
        try:
            verifier.verify(sig, msg, self._pss_padding(), self.hash_alg())
        except InvalidSignature:
            return False
        else:
            return True


# One instance of every JWS algorithm defined in this module, in the order:
# none, HS*, RS*, ES* (including ES256K), PS*.
JWS_ALGORITHMS = [
    NoneAlgorithm(),  # none
    *(HMACAlgorithm(bits) for bits in (256, 384, 512)),  # HS256, HS384, HS512
    *(RSAAlgorithm(bits) for bits in (256, 384, 512)),  # RS256, RS384, RS512
    ECAlgorithm("ES256", "P-256", 256),
    ECAlgorithm("ES384", "P-384", 384),
    ECAlgorithm("ES512", "P-521", 512),
    ECAlgorithm("ES256K", "secp256k1", 256),  # defined in RFC8812
    *(RSAPSSAlgorithm(bits) for bits in (256, 384, 512)),  # PS256, PS384, PS512
]