File: test_kdf_chain.py

package info (click to toggle)
python-doubleratchet 1.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 496 kB
  • sloc: python: 2,194; makefile: 13
file content (127 lines) | stat: -rw-r--r-- 4,156 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import random
from typing import Set

from doubleratchet import KDFChain
from doubleratchet.recommended import HashFunction, kdf_hkdf

from .test_recommended_kdfs import generate_unique_random_data


# Names exported by this module: only the single test coroutine defined below.
__all__ = [
    "test_kdf_chain"
]


# pytest is treated as optional here: if it cannot be imported, the module still loads and
# simply skips the marker setup below.
try:
    import pytest
except ImportError:
    pass
else:
    # `pytestmark` is pytest's module-level marker hook; the mark assigned here is applied to
    # every test in this module. NOTE(review): the `asyncio` mark is presumably the one
    # provided by the pytest-asyncio plugin, needed so that the coroutine test below is run
    # on an event loop - confirm against the project's test requirements.
    pytestmark = pytest.mark.asyncio


class KDF(kdf_hkdf.KDF):
    """
    Concrete ``kdf_hkdf.KDF`` subclass used by the KDF chain test in this module. It fixes the
    hash function to SHA-512 and supplies a test-specific info string.
    """

    @staticmethod
    def _get_hash_function() -> HashFunction:
        # The test below reads the digest size of this same hash function for its size check.
        selected_hash_function = HashFunction.SHA_512
        return selected_hash_function

    @staticmethod
    def _get_info() -> bytes:
        # Fixed info string of this test KDF, returned as ASCII-encoded bytes.
        info_text = "test_kdf_chain info"
        return info_text.encode("ASCII")


async def _derive_and_check(
    kdf_chain: "KDFChain",
    input_data: bytes,
    output_data_length: int,
    seen_outputs: Set[bytes],
    steps: int,
    length_offset: int = 0
) -> bytes:
    """
    Perform a number of derivation steps on a KDF chain and validate every single step.

    Args:
        kdf_chain: The KDF chain to step.
        input_data: The input data passed to each derivation step.
        output_data_length: The number of bytes requested from each derivation step.
        seen_outputs: The outputs observed so far. Each new output is asserted not to be in
            this set and is then added to it.
        steps: The number of derivation steps to perform.
        length_offset: The length the chain is expected to have before the first step performed
            by this call.

    Returns:
        The output data of the last step performed, or empty bytes if ``steps`` is zero.
    """

    output_data = b""

    for step_counter in range(steps):
        output_data = await kdf_chain.step(input_data, output_data_length)

        # Assert correct length and uniqueness of the result
        assert len(output_data) == output_data_length
        assert output_data not in seen_outputs
        seen_outputs.add(output_data)

        # Assert that the chain length is counted correctly
        assert kdf_chain.length == length_offset + step_counter + 1

    return output_data


async def test_kdf_chain() -> None:
    """
    Test the KDF chain implementation.

    For each of 25 random parameter sets, the test performs a reference run of 100 derivation
    steps, repeats that run on a fresh chain to confirm determinism, and finally repeats it once
    more with a JSON serialization round trip after 50 steps to confirm that serialization does
    not alter the chain.
    """

    initial_key_set: Set[bytes] = set()
    input_data_set: Set[bytes] = set()
    output_data_set: Set[bytes] = set()

    # Loop-invariant size limit for the parameter selection below. HKDF (RFC 5869) can produce
    # at most 255 times the digest size of output.
    max_derived_size = 255 * HashFunction.SHA_512.hash_size

    for _ in range(25):
        # Generate random parameters, retrying until they fit the size limit.
        # NOTE(review): key length and output length are summed presumably because each chain
        # step derives the next chain key together with the output - confirm in KDFChain.step.
        while True:
            initial_key = generate_unique_random_data(0, 2 ** 16, initial_key_set)
            input_data = generate_unique_random_data(0, 2 ** 16, input_data_set)

            output_data_length = random.randrange(2, 2 ** 16)

            if len(initial_key) + output_data_length <= max_derived_size:
                break

        # Reference run: create the KDF chain and perform 100 derivation steps. The output of
        # the final step is kept to be able to confirm determinism.
        kdf_chain = KDFChain.create(KDF, initial_key)
        final_step_output_data = await _derive_and_check(
            kdf_chain,
            input_data,
            output_data_length,
            output_data_set,
            100
        )

        # Create another KDF chain with the same parameters and repeat the 100 derivation steps.
        # The set is cleared first, since the same outputs are expected to be derived again.
        output_data_set.clear()
        kdf_chain = KDFChain.create(KDF, initial_key)
        output_data = await _derive_and_check(
            kdf_chain,
            input_data,
            output_data_length,
            output_data_set,
            100
        )

        # Assert determinism
        assert output_data == final_step_output_data

        # Create another KDF chain with the same parameters and repeat only the first 50
        # derivation steps
        output_data_set.clear()
        kdf_chain = KDFChain.create(KDF, initial_key)
        await _derive_and_check(
            kdf_chain,
            input_data,
            output_data_length,
            output_data_set,
            50
        )

        # Serialize and deserialize the KDF chain
        kdf_chain = KDFChain.from_json(kdf_chain.json, KDF)

        # Perform the remaining 50 derivation steps. The restored chain must report the 50
        # steps performed before serialization.
        output_data = await _derive_and_check(
            kdf_chain,
            input_data,
            output_data_length,
            output_data_set,
            50,
            length_offset=50
        )

        # Assert that the serialization didn't modify the chain
        assert output_data == final_step_output_data