File: _jwe_algorithms.py

package info (click to toggle)
python-authlib 1.6.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,016 kB
  • sloc: python: 26,998; makefile: 53; sh: 14
file content (216 lines) | stat: -rw-r--r-- 7,199 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import struct

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash

from authlib.jose.errors import InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError
from authlib.jose.rfc7516 import JWEAlgorithmWithTagAwareKeyAgreement
from authlib.jose.rfc7518 import AESAlgorithm
from authlib.jose.rfc7518 import CBCHS2EncAlgorithm
from authlib.jose.rfc7518 import ECKey
from authlib.jose.rfc7518 import u32be_len_input
from authlib.jose.rfc8037 import OKPKey


class ECDH1PUAlgorithm(JWEAlgorithmWithTagAwareKeyAgreement):
    """ECDH-1PU key agreement for JWE, optionally combined with AES key wrap.

    Specification:
    https://datatracker.ietf.org/doc/html/draft-madden-jose-ecdh-1pu-04

    An instance created with ``key_size=None`` operates in Direct Key
    Agreement mode: the derived key is returned as the CEK. An instance
    created with a key size derives a key of that many bits and hands it to
    ``AESAlgorithm(key_size)`` to wrap or unwrap the CEK.
    """

    EXTRA_HEADERS = ["epk", "apu", "apv", "skid"]
    ALLOWED_KEY_CLS = (ECKey, OKPKey)

    def __init__(self, key_size=None):
        """Configure the algorithm name/description for the given mode.

        :param key_size: ``None`` for direct key agreement, otherwise the
            size used in the ``A{key_size}KW`` suffix and for the derived key.
        """
        self.key_size = key_size
        self.aeskw = AESAlgorithm(key_size)

        if key_size is None:
            self.name = "ECDH-1PU"
            self.description = "ECDH-1PU in the Direct Key Agreement mode"
            return

        self.name = f"ECDH-1PU+A{key_size}KW"
        self.description = (
            f"ECDH-1PU using Concat KDF and CEK wrapped with A{key_size}KW"
        )

    def prepare_key(self, raw_data):
        """Return ``raw_data`` as a key object.

        Instances of ``ALLOWED_KEY_CLS`` are returned untouched; anything
        else is passed to ``ECKey.import_key``.
        """
        if not isinstance(raw_data, self.ALLOWED_KEY_CLS):
            return ECKey.import_key(raw_data)
        return raw_data

    def generate_preset(self, enc_alg, key):
        """Build a preset dict with a fresh ephemeral key and its header.

        A ``"cek"`` entry (from ``enc_alg.generate_cek()``) is added only
        when a key size is configured, i.e. outside direct mode.
        """
        ephemeral = self._generate_ephemeral_key(key)
        preset = {"epk": ephemeral, "header": self._prepare_headers(ephemeral)}
        if self.key_size is not None:
            preset["cek"] = enc_alg.generate_cek()
        return preset

    def compute_shared_key(self, shared_key_e, shared_key_s):
        """Concatenate the ephemeral secret followed by the static secret."""
        return shared_key_e + shared_key_s

    def compute_fixed_info(self, headers, bit_size, tag):
        """Assemble the fixed info handed to the KDF as ``otherinfo``.

        Layout: AlgorithmID, PartyUInfo, PartyVInfo, SuppPubInfo.

        :param headers: header mapping; ``"enc"`` is read in direct mode,
            ``"alg"`` otherwise, plus the optional ``"apu"`` / ``"apv"``.
        :param bit_size: derived key size in bits, packed as big-endian
            unsigned 32-bit integer at the start of SuppPubInfo.
        :param tag: when not ``None``, its length-prefixed form is appended
            to SuppPubInfo.
        """
        # The tag is encoded first so failures surface in the original order.
        tag_part = b"" if tag is None else u32be_len_input(tag)

        # AlgorithmID
        algorithm_header = "enc" if self.key_size is None else "alg"
        algorithm_id = u32be_len_input(headers[algorithm_header])

        # PartyUInfo / PartyVInfo
        # NOTE(review): the ``True`` flag is forwarded to u32be_len_input;
        # presumably it requests base64url decoding - confirm in the helper.
        party_u = u32be_len_input(headers.get("apu"), True)
        party_v = u32be_len_input(headers.get("apv"), True)

        # SuppPubInfo
        supp_pub = struct.pack(">I", bit_size) + tag_part

        return algorithm_id + party_u + party_v + supp_pub

    def compute_derived_key(self, shared_key, fixed_info, bit_size):
        """Run ConcatKDF (SHA-256) over ``shared_key``.

        The output length is ``bit_size // 8`` bytes and ``fixed_info`` is
        supplied as the KDF's ``otherinfo``.
        """
        return ConcatKDFHash(
            algorithm=hashes.SHA256(),
            length=bit_size // 8,
            otherinfo=fixed_info,
            backend=default_backend(),
        ).derive(shared_key)

    def _derive_from_secrets(self, shared_key_e, shared_key_s, headers, bit_size, tag):
        """Combine both secrets, build the fixed info and derive the key."""
        combined = self.compute_shared_key(shared_key_e, shared_key_s)
        info = self.compute_fixed_info(headers, bit_size, tag)
        return self.compute_derived_key(combined, info, bit_size)

    def deliver_at_sender(
        self,
        sender_static_key,
        sender_ephemeral_key,
        recipient_pubkey,
        headers,
        bit_size,
        tag,
    ):
        """Derive the key on the sender side.

        Both the sender's static key and its ephemeral key are exchanged
        against ``recipient_pubkey`` (static exchange first).
        """
        static_secret = sender_static_key.exchange_shared_key(recipient_pubkey)
        ephemeral_secret = sender_ephemeral_key.exchange_shared_key(recipient_pubkey)
        return self._derive_from_secrets(
            ephemeral_secret, static_secret, headers, bit_size, tag
        )

    def deliver_at_recipient(
        self,
        recipient_key,
        sender_static_pubkey,
        sender_ephemeral_pubkey,
        headers,
        bit_size,
        tag,
    ):
        """Derive the key on the recipient side.

        ``recipient_key`` is exchanged against the sender's static public
        key first, then against the sender's ephemeral public key.
        """
        static_secret = recipient_key.exchange_shared_key(sender_static_pubkey)
        ephemeral_secret = recipient_key.exchange_shared_key(sender_ephemeral_pubkey)
        return self._derive_from_secrets(
            ephemeral_secret, static_secret, headers, bit_size, tag
        )

    def _generate_ephemeral_key(self, key):
        """Generate a private key on the same ``crv`` as ``key``."""
        return key.generate_key(key["crv"], is_private=True)

    def _prepare_headers(self, epk):
        """Return ``{"epk": ...}`` holding the public fields of ``epk``."""
        # REQUIRED_JSON_FIELDS contains only public fields
        public_part = {}
        for field in epk.REQUIRED_JSON_FIELDS:
            public_part[field] = epk[field]
        public_part["kty"] = epk.kty
        return {"epk": public_part}

    def _select_ephemeral_key(self, key, preset):
        """Return ``(epk, header)`` from the preset or freshly generated.

        When the preset already carries an ``"epk"`` the header is empty;
        otherwise a new ephemeral key and its ``epk`` header are produced.
        """
        if preset and "epk" in preset:
            return preset["epk"], {}
        fresh = self._generate_ephemeral_key(key)
        return fresh, self._prepare_headers(fresh)

    def _derived_key_bit_size(self, enc_alg):
        """Bit size of the derived key: ``enc_alg.CEK_SIZE`` in direct mode,
        the configured key size otherwise."""
        return enc_alg.CEK_SIZE if self.key_size is None else self.key_size

    def generate_keys_and_prepare_headers(self, enc_alg, key, sender_key, preset=None):
        """Provide the ephemeral key, the CEK and the header additions.

        ``sender_key`` is accepted but not used here.

        :raises InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError:
            if ``enc_alg`` is not a ``CBCHS2EncAlgorithm``.
        """
        if not isinstance(enc_alg, CBCHS2EncAlgorithm):
            raise InvalidEncryptionAlgorithmForECDH1PUWithKeyWrappingError()

        epk, header = self._select_ephemeral_key(key, preset)
        cek = preset["cek"] if preset and "cek" in preset else enc_alg.generate_cek()

        return {"epk": epk, "cek": cek, "header": header}

    def _agree_upon_key_at_sender(
        self, enc_alg, headers, key, sender_key, epk, tag=None
    ):
        """Derive the sender-side key against the recipient's ``wrapKey`` key."""
        bit_size = self._derived_key_bit_size(enc_alg)
        recipient_pubkey = key.get_op_key("wrapKey")
        return self.deliver_at_sender(
            sender_key, epk, recipient_pubkey, headers, bit_size, tag
        )

    def _wrap_cek(self, cek, dk):
        """Wrap ``cek`` with the AES key-wrap helper, using ``dk`` as the KEK."""
        wrapper = self.aeskw
        return wrapper.wrap_cek(cek, wrapper.prepare_key(dk))

    def agree_upon_key_and_wrap_cek(
        self, enc_alg, headers, key, sender_key, epk, cek, tag
    ):
        """Derive the key (bound to ``tag``) and use it to wrap ``cek``."""
        derived = self._agree_upon_key_at_sender(
            enc_alg, headers, key, sender_key, epk, tag
        )
        return self._wrap_cek(cek, derived)

    def wrap(self, enc_alg, headers, key, sender_key, preset=None):
        """Direct key agreement: return the derived key as the CEK.

        :raises RuntimeError: if the instance is configured with a key size.
        """
        # In this class this method is used in direct key agreement mode only
        if self.key_size is not None:
            raise RuntimeError("Invalid algorithm state detected")

        epk, header = self._select_ephemeral_key(key, preset)
        derived = self._agree_upon_key_at_sender(
            enc_alg, headers, key, sender_key, epk
        )
        return {"ek": b"", "cek": derived, "header": header}

    def unwrap(self, enc_alg, ek, headers, key, sender_key, tag=None):
        """Recover the CEK on the recipient side.

        In direct mode the derived key is returned as-is; otherwise it is
        used as the KEK to unwrap ``ek`` through the AES key-wrap helper.

        :raises ValueError: if ``headers`` has no ``"epk"`` entry.
        """
        if "epk" not in headers:
            raise ValueError('Missing "epk" in headers')

        bit_size = self._derived_key_bit_size(enc_alg)

        sender_pubkey = sender_key.get_op_key("wrapKey")
        ephemeral_pubkey = key.import_key(headers["epk"]).get_op_key("wrapKey")
        derived = self.deliver_at_recipient(
            key, sender_pubkey, ephemeral_pubkey, headers, bit_size, tag
        )

        if self.key_size is None:
            return derived

        kek = self.aeskw.prepare_key(derived)
        return self.aeskw.unwrap(enc_alg, ek, headers, kek)


# One instance per draft algorithm, in order:
# ECDH-1PU, ECDH-1PU+A128KW, ECDH-1PU+A192KW, ECDH-1PU+A256KW
JWE_DRAFT_ALG_ALGORITHMS = [
    ECDH1PUAlgorithm(key_size) for key_size in (None, 128, 192, 256)
]


def register_jwe_alg_draft(cls):
    """Register every entry of ``JWE_DRAFT_ALG_ALGORITHMS`` on ``cls``.

    Each algorithm instance is passed, in list order, to
    ``cls.register_algorithm``.
    """
    for draft_algorithm in JWE_DRAFT_ALG_ALGORITHMS:
        cls.register_algorithm(draft_algorithm)