"""Tests for the recommended KDF implementations (HKDF-based and separate HMAC-based)."""
import os
import random
from typing import Set, Type
from doubleratchet.recommended import HashFunction, kdf_hkdf, kdf_separate_hmacs
# Names exported by this module: the two test coroutines defined below.
__all__ = [
    "test_kdf_hkdf",
    "test_kdf_separate_hmacs"
]

# pytest is treated as optional: if it cannot be imported, the module still loads, just without the marker.
try:
    import pytest
except ImportError:
    pass
else:
    # Module-level pytest marker; both tests in this file are coroutines and carry the asyncio mark.
    pytestmark = pytest.mark.asyncio
def make_kdf_hkdf(hash_function: HashFunction, info: bytes) -> Type[kdf_hkdf.KDF]:
    """
    Build a :class:`~doubleratchet.recommended.kdf_hkdf.KDF` subclass bound to a fixed configuration.

    Args:
        hash_function: The hash function the subclass returns from ``_get_hash_function``.
        info: The info bytes the subclass returns from ``_get_info``.

    Returns:
        A freshly created subclass whose two static hooks return the values passed in here.
    """

    class KDF(kdf_hkdf.KDF):
        # Both hooks simply hand back the values captured from the enclosing call.

        @staticmethod
        def _get_info() -> bytes:
            return info

        @staticmethod
        def _get_hash_function() -> HashFunction:
            return hash_function

    return KDF
def make_kdf_separate_hmacs(hash_function: HashFunction) -> Type[kdf_separate_hmacs.KDF]:
    """
    Build a :class:`~doubleratchet.recommended.kdf_separate_hmacs.KDF` subclass bound to a hash function.

    Args:
        hash_function: The hash function the subclass returns from ``_get_hash_function``.

    Returns:
        A freshly created subclass whose static hook returns the hash function passed in here.
    """

    class KDF(kdf_separate_hmacs.KDF):
        # The only hook overridden here; it hands back the value captured from the enclosing call.

        @staticmethod
        def _get_hash_function() -> HashFunction:
            return hash_function

    return KDF
def generate_unique_random_data(lower_bound: int, upper_bound: int, data_set: Set[bytes]) -> bytes:
    """
    Draw random bytes of a random length until a value not yet present in ``data_set`` turns up.

    The accepted value is added to ``data_set`` before it is returned, so the caller's set is mutated.

    Args:
        lower_bound: The smallest allowed length in bytes (inclusive).
        upper_bound: The length limit in bytes (exclusive).
        data_set: Values produced earlier; consulted for the uniqueness check and extended with the result.

    Returns:
        Random data that was not contained in ``data_set`` when this function was called.
    """

    candidate = os.urandom(random.randrange(lower_bound, upper_bound))

    # Redraw both the length and the content for as long as the candidate has been seen before.
    while candidate in data_set:
        candidate = os.urandom(random.randrange(lower_bound, upper_bound))

    data_set.add(candidate)
    return candidate
async def test_kdf_hkdf() -> None:
    """
    Exercise the HKDF-based recommended KDF with random parameters for every hash function.

    For each member of ``HashFunction``, 50 derivations are run with previously unseen keys, input data and
    info values. Each result must have the requested length, must differ from every earlier result for the
    same hash function, and must be reproduced exactly by 25 repeated derivations.
    """

    # Exclusive upper bound for the random lengths of keys, input data and info values
    size_limit = 2 ** 16

    for hash_function in HashFunction:
        # Per-hash-function bookkeeping of everything generated or derived so far
        seen_keys: Set[bytes] = set()
        seen_inputs: Set[bytes] = set()
        seen_outputs: Set[bytes] = set()
        seen_infos: Set[bytes] = set()

        for _round in range(50):
            # Fresh random parameters, none of which has been used before for this hash function
            key = generate_unique_random_data(0, size_limit, seen_keys)
            input_data = generate_unique_random_data(0, size_limit, seen_inputs)
            info = generate_unique_random_data(0, size_limit, seen_infos)

            # Between 2 and 255 * hash_size bytes (inclusive); presumably the upper end is the HKDF output
            # limit from RFC 5869 - TODO confirm
            requested_length = random.randrange(2, 255 * hash_function.hash_size + 1)

            # A KDF subclass configured with this round's hash function and info
            configured_kdf = make_kdf_hkdf(hash_function, info)

            derived = await configured_kdf.derive(key, input_data, requested_length)

            # The result has the requested size and has not been produced before
            assert len(derived) == requested_length
            assert derived not in seen_outputs
            seen_outputs.add(derived)

            # Repeating the derivation with identical parameters must yield identical output
            for _repeat in range(25):
                assert (await configured_kdf.derive(key, input_data, requested_length)) == derived
async def test_kdf_separate_hmacs() -> None:
    """
    Exercise the separate HMAC-based recommended KDF with random parameters for every hash function.

    For each member of ``HashFunction``, one KDF subclass is created and 50 derivations are run with
    previously unseen keys and input data. Each result must have the requested length, must differ from every
    earlier result for the same hash function, and must be reproduced exactly by 25 repeated derivations.
    """

    for hash_function in HashFunction:
        # Per-hash-function bookkeeping of everything generated or derived so far
        seen_keys: Set[bytes] = set()
        seen_inputs: Set[bytes] = set()
        seen_outputs: Set[bytes] = set()

        # This KDF only depends on the hash function, so one subclass serves all rounds
        configured_kdf = make_kdf_separate_hmacs(hash_function)

        for _round in range(50):
            # Fresh random parameters; the input data is between 1 and 255 bytes long
            key = generate_unique_random_data(0, 2 ** 16, seen_keys)
            input_data = generate_unique_random_data(1, 2 ** 8, seen_inputs)

            # Request hash_size bytes per byte of input data; presumably the KDF emits one digest per input
            # byte - TODO confirm against kdf_separate_hmacs
            requested_length = len(input_data) * hash_function.hash_size

            derived = await configured_kdf.derive(key, input_data, requested_length)

            # The result has the requested size and has not been produced before
            assert len(derived) == requested_length
            assert derived not in seen_outputs
            seen_outputs.add(derived)

            # Repeating the derivation with identical parameters must yield identical output
            for _repeat in range(25):
                assert (await configured_kdf.derive(key, input_data, requested_length)) == derived
|