1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
|
#!/usr/bin/python3
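"""Check the fips203 Python module against the NIST ML-KEM keyGen test vectors.

The vectors are read from the installed fips203 crate source under
/usr/share/cargo/registry/fips203-*/tests/nist_vectors/.

From the ffi/python/ directory, run:
PYTHONPATH=. test/nist/keygen.py
"""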
from __future__ import annotations
import fips203
import json
import re
from binascii import a2b_hex, b2a_hex
import glob
from typing import Dict, Union, List, TypedDict
# Locate the NIST keyGen vectors shipped with the installed fips203 crate
# source. The matches are sorted so the choice is deterministic when several
# crate versions are installed; the last one is taken as presumably the
# newest (NOTE(review): lexicographic order, not a true version sort).
_VECTOR_GLOB = (
    "/usr/share/cargo/registry/fips203-*/tests/nist_vectors/"
    "ML-KEM-keyGen-FIPS203/internalProjection.json"
)
_vector_files = sorted(glob.glob(_VECTOR_GLOB))
if not _vector_files:
    # A clear error instead of the IndexError that `glob(...)[0]` would raise.
    raise FileNotFoundError(
        f"no ML-KEM keyGen test vectors found matching {_VECTOR_GLOB}"
    )
with open(_vector_files[-1], encoding="utf-8") as f:
    t = json.load(f)

# Sanity-check the vector file's metadata before running anything. These are
# explicit checks rather than `assert`, so they still run under `python -O`.
_EXPECTED_METADATA = {
    "vsId": 42,
    "algorithm": "ML-KEM",
    "mode": "keyGen",
    "revision": "FIPS203",
    "isSample": False,
}
for _key, _want in _EXPECTED_METADATA.items():
    if t[_key] != _want:
        raise ValueError(
            f"unexpected test vector metadata {_key!r}: "
            f"got {t[_key]!r}, wanted {_want!r}"
        )
# Schema of one entry in a test group's "tests" array of the keyGen JSON.
# d, z, ek and dk are hex strings; d and z are concatenated (d || z) to form
# the keygen seed, and ek/dk are the expected key bytes.
KeyGenTestData = TypedDict(
    "KeyGenTestData",
    {
        "tcId": int,  # test case identifier, used in failure messages
        "deferred": bool,
        "z": str,
        "d": str,
        "ek": str,
        "dk": str,
    },
)
class KeyGenTest:
    """One ML-KEM keyGen test case: a (d, z) seed plus the expected key pair.

    The hex fields of the JSON test case are decoded to raw bytes on
    construction.
    """

    def __init__(self, data: KeyGenTestData):
        # Scalar metadata copied straight from the JSON test case.
        self.tcId = data["tcId"]
        self.deferred = data["deferred"]
        # Hex-encoded fields, decoded in the order d, z, ek, dk.
        self.d, self.z, self.ek, self.dk = (
            a2b_hex(data[field]) for field in ("d", "z", "ek", "dk")
        )

    def run(self, group: TestGroup) -> None:
        """Derive a key pair from d || z and compare it with the expected ek/dk.

        Raises:
            Exception: on the first mismatch (ek is checked before dk).
        """
        encaps_key, decaps_key = fips203.Seed(self.d + self.z).keygen(group.strength)
        self._expect(group, "ek", bytes(encaps_key), self.ek)
        self._expect(group, "dk", bytes(decaps_key), self.dk)

    def _expect(self, group: TestGroup, label: str, got: bytes, wanted: bytes) -> None:
        """Raise a descriptive Exception if ``got`` differs from ``wanted``."""
        if got == wanted:
            return
        raise Exception(
            f"test {self.tcId} (group {group.tgId}, str: {group.strength}) {label} failed:\n"
            f"got: {b2a_hex(got)}\n"
            f"wanted: {b2a_hex(wanted)}"
        )
class TestGroupData(TypedDict, total=True):
    """Schema of one entry in the keyGen JSON's "testGroups" array."""

    tgId: int  # group identifier, used in failure messages
    testType: str  # only "AFT" is accepted by TestGroup
    parameterSet: str  # e.g. "ML-KEM-768"
    tests: List[KeyGenTestData]  # the individual test cases of this group
class TestGroup:
    """A group of keyGen test cases that share one ML-KEM parameter set."""

    # Accepted parameter-set names; the captured number is the strength that
    # KeyGenTest.run() passes to Seed.keygen().
    param_matcher = re.compile(r"^ML-KEM-(?P<strength>512|768|1024)$")

    def __init__(self, d: TestGroupData) -> None:
        """Parse and validate one entry of the JSON "testGroups" array.

        Raises:
            ValueError: if testType is not "AFT" or parameterSet is not one of
                ML-KEM-512, ML-KEM-768, ML-KEM-1024.
        """
        self.tgId: int = d["tgId"]
        self.testType: str = d["testType"]
        # Only "AFT" groups are supported (ACVP's "Algorithm Functional Test";
        # NOTE(review): confirm no other test types appear in these vectors).
        # Explicit raise instead of assert so the check survives `python -O`.
        if self.testType != "AFT":
            raise ValueError(
                f"test group {self.tgId}: unsupported testType {self.testType!r}"
            )
        self.parameterSet: str = d["parameterSet"]
        # fullmatch also rejects a trailing newline, which `match` with a bare
        # `$` anchor would accept.
        m = self.param_matcher.fullmatch(self.parameterSet)
        if m is None:
            raise ValueError(
                f"test group {self.tgId}: unsupported parameterSet {self.parameterSet!r}"
            )
        self.strength: int = int(m["strength"])
        self.tests: List[KeyGenTest] = [KeyGenTest(tc) for tc in d["tests"]]

    def run(self) -> None:
        """Run every test case in this group, stopping at the first failure."""
        for case in self.tests:
            case.run(self)
# Build every group up front, so malformed vector data is rejected before any
# key generation runs. Then execute the groups in file order.
groups: List[TestGroup] = [TestGroup(group_data) for group_data in t["testGroups"]]
for group in groups:
    group.run()
|