File: fips203.py

package info (click to toggle)
rust-fips203-ffi 0.4.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 232 kB
  • sloc: python: 430; ansic: 199; makefile: 51; sh: 4
file content (460 lines) | stat: -rw-r--r-- 15,313 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
'''FIPS 203 (ML-KEM) Asymmetric Post-Quantum Cryptography

This Python module provides an implementation of FIPS 203, the
Module-Lattice-based Key Encapsulation Mechanism Standard.

The underlying mechanism is intended to offer "post-quantum"
asymmetric encryption and decryption.

## Example

The following example shows using the standard ML-KEM algorithm to
produce identical 32-byte shared secrets:

```
from fips203 import ML_KEM_512

(encapsulation_key, decapsulation_key) = ML_KEM_512.keygen()
(ciphertext, shared_secret_1) = encapsulation_key.encaps()
shared_secret_2 = decapsulation_key.decaps(ciphertext)
assert(shared_secret_1 == shared_secret_2)
```


Key generation can also be done deterministically, by passing a
`SEED_SIZE`-byte seed (the concatenation of d and z) to `keygen`:

```
from fips203 import ML_KEM_512, Seed

seed1 = Seed()  # Generate a random seed
(ek1, dk1) = ML_KEM_512.keygen(seed1)

seed2 = Seed(b'\x00'*ML_KEM_512.SEED_SIZE)  # This seed is clearly not a secret!
(ek2, dk2) = ML_KEM_512.keygen(seed2)
```


Encapsulation keys, decapsulation keys, seeds, and ciphertexts can all
be serialized by accessing them as `bytes`, and deserialized by
initializing them with the appropriate size bytes object.

A serialization example:

```
from fips203 import ML_KEM_768

seed = Seed()
(ek,dk) = ML_KEM_768.keygen(seed)
with open('encapskey.bin', 'wb') as f:
    f.write(bytes(ek))
with open('decapskey.bin', 'wb') as f:
    f.write(bytes(dk))
with open('seed.bin', 'wb') as f:
    f.write(bytes(seed)
```

A deserialization example, followed by use:

```
import fips203

with open('encapskey.bin', 'b') as f:
    ekdata = f.read()

ek = fips203.EncapsulationKey(ekdata)
(ct, ss) = ek.Encaps()
```

The expected sizes (in bytes) of the different objects in each
parameter set can be accessed with `EK_SIZE`, `DK_SIZE`, `CT_SIZE`,
`SEED_SIZE`, and `SS_SIZE`:

```
from fips203 import ML_KEM_768

print(f"ML-KEM-768 Ciphertext size (in bytes) is {ML_KEM_768.CT_SIZE}")
```

## Implementation Notes

This is a wrapper around libfips203, built from the Rust fips203-ffi crate.

If that library is not installed in the expected path for libraries on
your system, any attempt to use this module will fail.

This module should have reasonable type annotations and docstrings for
the public interface.  If you discover a problem with type
annotations, or see a way that this kind of documentation could be
improved, please report it!

## See Also

- https://doi.org/10.6028/NIST.FIPS.203.ipd
- https://github.com/integritychain/fips203

## Bug Reporting

Please report issues at https://github.com/integritychain/fips203/issues
'''
from __future__ import annotations

'''__version__ should track package.version from  ../Cargo.toml'''
__version__ = '0.4.3'
__author__ = 'Daniel Kahn Gillmor <dkg@fifthhorseman.net>'
__all__ = [
    'ML_KEM_512',
    'ML_KEM_768',
    'ML_KEM_1024',
    'Ciphertext',
    'EncapsulationKey',
    'DecapsulationKey',
    'Seed',
]

import ctypes
import ctypes.util
import enum
import secrets
from typing import Tuple, Dict, Any, Union, Optional
from abc import ABC
import sys


class _SharedSecret(ctypes.Structure):
    _fields_ = [('data', ctypes.c_uint8 * 32)]

class _Seed(ctypes.Structure):
    _fields_ = [('data', ctypes.c_uint8 * 64)]

class Err(enum.IntEnum):
    OK = 0
    NULL_PTR_ERROR = 1
    SERIALIZATION_ERROR = 2
    DESERIALIZATION_ERROR = 3
    KEYGEN_ERROR = 4
    ENCAPSULATION_ERROR = 5
    DECAPSULATION_ERROR = 6


class Seed():
    '''ML-KEM Seed

    This seed can be used to generate an ML-KEM keypair
    '''
    def __init__(self, data: Optional[bytes] = None) -> None:
        '''If initialized with None, the seed will be randomly populated.'''
        self._seed = _Seed()
        if data is None:
            # FIXME: perhaps use ml_kem_populate_seed instead?
            data = secrets.token_bytes(len(self._seed.data))
        if len(data) != len(self._seed.data):
            raise ValueError(f"Expected {len(self._seed.data)} bytes, "
                             f"got {len(data)}.")
        for i in range(len(data)):
            self._seed.data[i] = data[i]

    def __repr__(self) -> str:
        return '<ML-KEM Seed>'

    def __bytes__(self) -> bytes:
        return bytes(self._seed.data)

    def keygen(self, strength: int) -> Tuple[EncapsulationKey, DecapsulationKey]:
        for kt in ML_KEM_512, ML_KEM_768, ML_KEM_1024:
            if kt._strength == strength:
                return kt.keygen(self)
        raise Exception(f"Unknown strength: {strength}, must be 512, 768, or 1024.")

class Ciphertext():
    '''ML-KEM Ciphertext

    Serialize this object by asking for it as `bytes`.

    You can convert it to a 32-byte shared secret by passing it to the
    Decaps() function of the appropriate Decapsulation Key.

    '''
    def __init__(self, data: Union[bytes, int]) -> None:
        '''Create ML-KEM Ciphertext from bytes (or strength level).'''
        if isinstance(data, bytes):
            self._strength = _ML_KEM.strength_from_length('CT_SIZE', len(data))
        elif isinstance(data, int):
            self._strength = data
        else:
            raise Exception("Initialize ML-KEM Ciphertext object with "
                            f"bytes or a strength level, not {type(data)}")
        self._ffi = _ML_KEM.strength(self._strength)
        self._ct = self._ffi['Ciphertext']()

        if isinstance(data, bytes):
            self._set(data)

    def __repr__(self) -> str:
        return f'<ML-KEM-{self._strength} Ciphertext>'

    def __bytes__(self) -> bytes:
        return bytes(self._ct.data)

    def _set(self, data: bytes) -> None:
        if len(data) != len(self._ct.data):
            raise ValueError(f"Expected {len(self._ct.data)} bytes, "
                             f"got {len(data)}")
        for i in range(len(data)):
            self._ct.data[i] = data[i]


class EncapsulationKey():
    '''ML-KEM Encapsulation Key

    Serialize this object by asking for it as `bytes`.

    Produce a Ciphertext and a 32-byte shared secret by invoking
    Encaps() on it.
    '''
    def __init__(self, data: Union[bytes, int]) -> None:
        '''Create ML-KEM Encapsulation Key from bytes (or strength level).'''
        if isinstance(data, bytes):
            self._strength = _ML_KEM.strength_from_length('EK_SIZE', len(data))
        elif isinstance(data, int):
            self._strength = data
        else:
            raise Exception("Initialize ML-KEM Encapsulation Key with "
                            f"bytes or a strength level, not {type(data)}")
        self._ffi = _ML_KEM.strength(self._strength)
        self._ek = self._ffi['EncapsKey']()
        if isinstance(data, bytes):
            self._set(data)

    def __repr__(self) -> str:
        return f'<ML-KEM-{self._strength} Encapsulation Key>'

    def __bytes__(self) -> bytes:
        return bytes(self._ek.data)

    def _set(self, data: bytes) -> None:
        if len(data) != len(self._ek.data):
            raise ValueError(f"Expected {len(self._ek.data)} bytes, "
                             f"got {len(data)}")
        for i in range(len(data)):
            self._ek.data[i] = data[i]

    def encaps(self) -> Tuple[Ciphertext, bytes]:
        '''Produce a new Ciphertext and corresponding 32-byte shared secret.'''
        ct = Ciphertext(self._strength)
        ss = _SharedSecret()
        ret = Err(self._ffi['encaps'](ctypes.byref(self._ek),
                                      ctypes.byref(ct._ct),
                                      ctypes.byref(ss)))
        if ret is not Err.OK:
            raise Exception(f"ml_kem_{self._strength}_encaps() "
                            f"returned {ret} ({ret.name})")
        return (ct, bytes(ss.data))


class DecapsulationKey():
    '''ML-KEM Decapsulation Key

    Serialize this object by asking for it as `bytes`.

    Produce a 32-byte shared secret from a Ciphertext by invoking
    Decaps() on it.
    '''
    def __init__(self, data: Union[bytes, int]) -> None:
        '''Create ML-KEM Decapsulation Key from bytes (or strength level).'''
        if isinstance(data, bytes):
            self._strength = _ML_KEM.strength_from_length('DK_SIZE', len(data))
        elif isinstance(data, int):
            self._strength = data
        else:
            raise Exception("Initialize ML-KEM Encapsulation Key with bytes "
                            f"or a strength level, not {type(data)}")
        self._ffi = _ML_KEM.strength(self._strength)
        self._dk = self._ffi['DecapsKey']()
        if isinstance(data, bytes):
            self._set(data)

    def __repr__(self) -> str:
        return f'<ML-KEM-{self._strength} Decapsulation Key>'

    def __bytes__(self) -> bytes:
        return bytes(self._dk.data)

    def _set(self, data: bytes) -> None:
        if len(data) != len(self._dk.data):
            raise ValueError(f"Expected {len(self._dk.data)} bytes, "
                             f"got {len(data)}")
        for i in range(len(data)):
            self._dk.data[i] = data[i]

    def decaps(self, ct: Ciphertext) -> bytes:
        '''Get 32-byte shared secret corresponding to the given Ciphertext.'''
        if self._strength != ct._strength:
            raise Exception(f"Cannot decapsulate {ct} with {self}")
        ss = _SharedSecret()
        ret = Err(self._ffi['decaps'](ctypes.byref(self._dk),
                                      ctypes.byref(ct._ct),
                                      ctypes.byref(ss)))
        if ret is not Err.OK:
            raise Exception(f"ml_kem_{self._strength}_decaps() "
                            f"returned {ret} ({ret.name})")
        return bytes(ss.data)


class _ML_KEM():
    params: Dict[int, Dict[str, int]] = {
        512: {
            'EK_SIZE': 800,
            'DK_SIZE': 1632,
            'CT_SIZE': 768,
            },
        768: {
            'EK_SIZE': 1184,
            'DK_SIZE': 2400,
            'CT_SIZE': 1088,
            },
        1024: {
            'EK_SIZE': 1568,
            'DK_SIZE': 3168,
            'CT_SIZE': 1568,
            },
        }
    lib = ctypes.CDLL(ctypes.util.find_library('fips203'))
    if not hasattr(lib, 'ml_kem_512_keygen'):
        if sys.platform == 'darwin':
            lib = ctypes.CDLL('../../target/debug/libfips203.dylib')
        else:
            lib = ctypes.CDLL("../../target/debug/libfips203.so")

    # use Any below because i don't know how to specify the type of the FuncPtr
    ffi: Dict[int, Dict[str, Any]] = {}

    @classmethod
    def strength(cls, level: int) -> Dict[str, Any]:
        if level not in cls.ffi:
            class _EncapsKey(ctypes.Structure):
                _fields_ = [('data', ctypes.c_uint8 *
                             cls.params[level]['EK_SIZE'])]

            class _DecapsKey(ctypes.Structure):
                _fields_ = [('data', ctypes.c_uint8 *
                             cls.params[level]['DK_SIZE'])]

            class _Ciphertext(ctypes.Structure):
                _fields_ = [('data', ctypes.c_uint8 *
                             cls.params[level]['CT_SIZE'])]
            ffi: Dict[str, Any] = {}

            ffi['keygen'] = cls.lib[f'ml_kem_{level}_keygen']
            ffi['keygen'].argtypes = [ctypes.POINTER(_EncapsKey),
                                      ctypes.POINTER(_DecapsKey)]
            ffi['keygen'].restype = ctypes.c_uint8

            ffi['keygen_from_seed'] = cls.lib[f'ml_kem_{level}_keygen_from_seed']
            ffi['keygen_from_seed'].argtypes = [ctypes.POINTER(_Seed),
                                                ctypes.POINTER(_EncapsKey),
                                                ctypes.POINTER(_DecapsKey)]
            ffi['keygen_from_seed'].restype = ctypes.c_uint8

            ffi['encaps'] = cls.lib[f'ml_kem_{level}_encaps']
            ffi['encaps'].argtypes = [ctypes.POINTER(_EncapsKey),
                                      ctypes.POINTER(_Ciphertext),
                                      ctypes.POINTER(_SharedSecret)]
            ffi['encaps'].restype = ctypes.c_uint8

            ffi['decaps'] = cls.lib[f'ml_kem_{level}_decaps']
            ffi['decaps'].argtypes = [ctypes.POINTER(_DecapsKey),
                                      ctypes.POINTER(_Ciphertext),
                                      ctypes.POINTER(_SharedSecret)]
            ffi['decaps'].restype = ctypes.c_uint8

            ffi['EncapsKey'] = _EncapsKey
            ffi['DecapsKey'] = _DecapsKey
            ffi['Ciphertext'] = _Ciphertext

            cls.ffi[level] = ffi

        return cls.ffi[level]

    @classmethod
    def strength_from_length(cls, object_type: str, object_len: int) -> int:
        for strength in cls.params:
            if cls.params[strength][object_type] == object_len:
                return strength
        raise Exception(f"No ML-KEM parameter set has {object_type} "
                        f"of {object_len} bytes")

    @classmethod
    def _keygen(cls, strength: int) -> Tuple[EncapsulationKey,
                                             DecapsulationKey]:
        ek = EncapsulationKey(strength)
        dk = DecapsulationKey(strength)

        ret = Err(cls.strength(strength)['keygen'](ctypes.byref(ek._ek),
                                                   ctypes.byref(dk._dk)))
        if ret is not Err.OK:
            raise Exception(f"ml_kem_{strength}_keygen() returned "
                            f"{ret} ({ret.name})")
        return (ek, dk)


    @classmethod
    def _keygen_from_seed(cls, strength: int, seed: Seed) -> Tuple[EncapsulationKey,
                                                                   DecapsulationKey]:
        ek = EncapsulationKey(strength)
        dk = DecapsulationKey(strength)

        ret = Err(cls.strength(strength)['keygen_from_seed'](
            ctypes.byref(seed._seed),
            ctypes.byref(ek._ek),
            ctypes.byref(dk._dk)
        ))
        if ret is not Err.OK:
            raise Exception(f"ml_kem_{strength}_keygen() returned "
                            f"{ret} ({ret.name})")
        return (ek, dk)

class ML_KEM(ABC):
    '''Abstract base class for all ML-KEM (FIPS 203) parameter sets.'''

    _strength: int
    EK_SIZE: int
    DK_SIZE: int
    CT_SIZE: int
    SS_SIZE: int = 32
    SEED_SIZE: int = 64

    @classmethod
    def keygen(cls, seed: Optional[Seed] = None) -> Tuple[EncapsulationKey, DecapsulationKey]:
        '''Generate a pair of Encapsulation and Decapsulation Keys.

        If a Seed is supplied, do a deterministic generation from the seed.
        Otherwise, randomly generate the key.'''
        if seed is None:
            return _ML_KEM._keygen(cls._strength)
        else:
            return _ML_KEM._keygen_from_seed(cls._strength, seed)


class ML_KEM_512(ML_KEM):
    '''ML-KEM-512 (FIPS 203) Implementation.'''
    _strength: int = 512
    EK_SIZE: int = 800
    DK_SIZE: int = 1632
    CT_SIZE: int = 768


class ML_KEM_768(ML_KEM):
    '''ML-KEM-768 (FIPS 203) Implementation.'''
    _strength: int = 768
    EK_SIZE: int = 1184
    DK_SIZE: int = 2400
    CT_SIZE: int = 1088


class ML_KEM_1024(ML_KEM):
    '''ML-KEM-1024 (FIPS 203) Implementation.'''
    _strength: int = 1024
    EK_SIZE: int = 1568
    DK_SIZE: int = 3168
    CT_SIZE: int = 1568