File: test_diffie_hellman_ratchet.py

package info (click to toggle)
python-doubleratchet 1.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 496 kB
  • sloc: python: 2,194; makefile: 13
file content (304 lines) | stat: -rw-r--r-- 13,875 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
from typing import Awaitable, List, Set, Tuple, Type
from warnings import catch_warnings, simplefilter

from doubleratchet import (
    DoSProtectionException,
    DuplicateMessageException
)
from doubleratchet.diffie_hellman_ratchet import DiffieHellmanRatchet
from doubleratchet.recommended import (
    diffie_hellman_ratchet_curve25519 as dhr25519,
    diffie_hellman_ratchet_curve448 as dhr448,
    HashFunction,
    kdf_hkdf
)

from .test_recommended_kdfs import generate_unique_random_data


# The only public (test) entry point of this module.
__all__ = ["test_diffie_hellman_ratchet"]


# pytest is an optional dependency: only when it can be imported is the module-wide ``asyncio`` marker
# attached, so that the coroutine test below is run on an event loop (presumably via pytest-asyncio).
try:
    import pytest
except ImportError:
    pass
else:
    pytestmark = pytest.mark.asyncio


class RootChainKDF(kdf_hkdf.KDF):
    """
    Root chain KDF used by the tests: the ``kdf_hkdf.KDF`` recommendation configured with SHA-512 and a
    test-specific info string.
    """

    @staticmethod
    def _get_hash_function() -> HashFunction:
        # The root chain uses the full-width SHA-512.
        return HashFunction.SHA_512

    @staticmethod
    def _get_info() -> bytes:
        # Same bytes as the ASCII encoding of the info text.
        info = b"test_diffie_hellman_ratchet Root Chain info"
        return info


class MessageChainKDF(kdf_hkdf.KDF):
    """
    Message chain KDF used by the tests: the ``kdf_hkdf.KDF`` recommendation configured with SHA-512/256 and
    a test-specific info string.
    """

    @staticmethod
    def _get_hash_function() -> HashFunction:
        # The message chains use the truncated SHA-512/256 variant.
        return HashFunction.SHA_512_256

    @staticmethod
    def _get_info() -> bytes:
        # Same bytes as the ASCII encoding of the info text.
        info = b"test_diffie_hellman_ratchet Message Chain info"
        return info


# Threshold passed as the last positional argument of ``create``/``from_json`` throughout this test. The hard
# and soft DoS protection checks below skip more than this many messages.
_DOS_PROTECTION_THRESHOLD = 10


async def _assert_raises(expected: Type[BaseException], awaitable: Awaitable[object]) -> BaseException:
    """
    Await ``awaitable`` and assert that doing so raises ``expected``.

    Args:
        expected: The exception type that awaiting is expected to raise.
        awaitable: The awaitable to run, e.g. an un-awaited coroutine.

    Returns:
        The raised exception, for further inspection by the caller.

    Raises:
        AssertionError: if awaiting completes without raising ``expected``. Exceptions of other types
            propagate unchanged.
    """

    try:
        await awaitable
    except expected as e:
        return e

    # Raise explicitly instead of ``assert False`` so the failure message names the missing exception and the
    # check is not stripped under ``python -O``.
    raise AssertionError(f"{expected.__name__} was not raised")


async def _create_ratchets(
    impl: Type[DiffieHellmanRatchet],
    bob_priv: bytes,
    root_chain_key: bytes,
    message_chain_constant: bytes
) -> Tuple[DiffieHellmanRatchet, DiffieHellmanRatchet, bytes]:
    """
    Create fresh ratchets for Alice and Bob and verify an initial message sent from Alice to Bob.

    Args:
        impl: The Diffie-Hellman ratchet implementation under test.
        bob_priv: Bob's ratchet private key, as returned by ``impl._generate_priv()``.
        root_chain_key: The root chain key both parties are initialized with.
        message_chain_constant: The message chain constant both parties are initialized with.

    Returns:
        Alice's ratchet, Bob's ratchet and the ratchet public key found in the header of Alice's initial
        message.
    """
    # pylint: disable=protected-access

    alice_dhr = await impl.create(
        None,
        impl._derive_pub(bob_priv),
        RootChainKDF,
        root_chain_key,
        MessageChainKDF,
        message_chain_constant,
        _DOS_PROTECTION_THRESHOLD
    )
    encryption_key, header = await alice_dhr.next_encryption_key()

    # Bob learns Alice's ratchet public key from the header of her first message.
    bob_dhr = await impl.create(
        bob_priv,
        header.ratchet_pub,
        RootChainKDF,
        root_chain_key,
        MessageChainKDF,
        message_chain_constant,
        _DOS_PROTECTION_THRESHOLD
    )
    decryption_key, skipped_message_keys = await bob_dhr.next_decryption_key(header)
    assert header.previous_sending_chain_length == 0
    assert header.sending_chain_length == 0
    assert len(skipped_message_keys) == 0
    assert len(encryption_key) == len(decryption_key) == 32
    assert encryption_key == decryption_key

    return alice_dhr, bob_dhr, header.ratchet_pub


async def test_diffie_hellman_ratchet() -> None:
    """
    Test the Diffie-Hellman ratchet implementation.

    For both recommended implementations (Curve25519 and Curve448) and 100 sets of random parameters each,
    this checks:
    - key agreement in both directions;
    - the header counters;
    - Diffie-Hellman ratchet steps;
    - skipped message keys and duplicate detection;
    - the hard and soft DoS protection;
    - root chain key size validation;
    - (de)serialization.
    """
    # pylint: disable=protected-access

    impls: List[Type[DiffieHellmanRatchet]] = [dhr25519.DiffieHellmanRatchet, dhr448.DiffieHellmanRatchet]

    for impl in impls:
        root_chain_key_set: Set[bytes] = set()
        message_chain_constant_set: Set[bytes] = set()
        for _ in range(100):
            # Generate random parameters (the sets keep them unique across iterations)
            root_chain_key = generate_unique_random_data(32, 32 + 1, root_chain_key_set)
            message_chain_constant = generate_unique_random_data(0, 2 ** 16, message_chain_constant_set)

            bob_priv = impl._generate_priv()

            # Create instances for Alice and Bob and exchange an initial message
            alice_dhr, bob_dhr, alice_pub = await _create_ratchets(
                impl,
                bob_priv,
                root_chain_key,
                message_chain_constant
            )

            # Test that Bob can send to Alice now
            encryption_key, header = await bob_dhr.next_encryption_key()
            decryption_key, skipped_message_keys = await alice_dhr.next_decryption_key(header)
            assert header.previous_sending_chain_length == 0
            assert header.sending_chain_length == 0
            assert len(skipped_message_keys) == 0
            assert len(encryption_key) == len(decryption_key) == 32
            assert encryption_key == decryption_key
            assert header.ratchet_pub != impl._derive_pub(bob_priv)
            bob_pub = header.ratchet_pub

            # Test that n increases in the header and the ratchet pub stays the same
            encryption_key, header = await bob_dhr.next_encryption_key()
            decryption_key, skipped_message_keys = await alice_dhr.next_decryption_key(header)
            assert header.previous_sending_chain_length == 0
            assert header.sending_chain_length == 1
            assert len(skipped_message_keys) == 0
            assert len(encryption_key) == len(decryption_key) == 32
            assert encryption_key == decryption_key
            assert header.ratchet_pub == bob_pub

            # Test that switching sender/receiver triggers a Diffie-Hellman ratchet step
            encryption_key, header = await alice_dhr.next_encryption_key()
            decryption_key, skipped_message_keys = await bob_dhr.next_decryption_key(header)
            assert header.previous_sending_chain_length == 1
            assert header.sending_chain_length == 0
            assert len(skipped_message_keys) == 0
            assert len(encryption_key) == len(decryption_key) == 32
            assert encryption_key == decryption_key
            assert header.ratchet_pub != alice_pub

            # Test that pn is set correctly in the header
            encryption_key, header = await bob_dhr.next_encryption_key()
            decryption_key, skipped_message_keys = await alice_dhr.next_decryption_key(header)
            assert header.previous_sending_chain_length == 2
            assert header.sending_chain_length == 0
            assert len(skipped_message_keys) == 0
            assert len(encryption_key) == len(decryption_key) == 32
            assert encryption_key == decryption_key
            assert header.ratchet_pub != bob_pub
            bob_pub = header.ratchet_pub

            # Test a few skipped messages (simple case, no Diffie-Hellman ratchet steps)
            skipped_encryption_key_1, skipped_header_1 = await bob_dhr.next_encryption_key()
            skipped_encryption_key_2, skipped_header_2 = await bob_dhr.next_encryption_key()
            skipped_encryption_key_3, skipped_header_3 = await bob_dhr.next_encryption_key()
            encryption_key, header = await bob_dhr.next_encryption_key()
            decryption_key, skipped_message_keys = await alice_dhr.next_decryption_key(header)
            assert header.previous_sending_chain_length == 2
            assert header.sending_chain_length == 4
            assert len(skipped_message_keys) == 3
            assert len(encryption_key) == len(decryption_key) == 32
            assert encryption_key == decryption_key
            assert header.ratchet_pub == bob_pub

            # Check the skipped message keys
            assert skipped_header_1.ratchet_pub == bob_pub
            assert skipped_header_2.ratchet_pub == bob_pub
            assert skipped_header_3.ratchet_pub == bob_pub
            assert skipped_header_1.previous_sending_chain_length == 2
            assert skipped_header_2.previous_sending_chain_length == 2
            assert skipped_header_3.previous_sending_chain_length == 2
            assert skipped_header_1.sending_chain_length == 1
            assert skipped_header_2.sending_chain_length == 2
            assert skipped_header_3.sending_chain_length == 3
            assert skipped_message_keys[(bob_pub, 1)] == skipped_encryption_key_1
            assert skipped_message_keys[(bob_pub, 2)] == skipped_encryption_key_2
            assert skipped_message_keys[(bob_pub, 3)] == skipped_encryption_key_3

            # Test that attempting to acquire one of these keys again raises an exception
            await _assert_raises(DuplicateMessageException, alice_dhr.next_decryption_key(skipped_header_3))
            await _assert_raises(DuplicateMessageException, alice_dhr.next_decryption_key(header))

            # Test the more complicated case of skipped message keys (after a Diffie-Hellman ratchet step)
            skipped_encryption_key, skipped_header = await bob_dhr.next_encryption_key()  # Prepare a message
            encryption_key, header = await alice_dhr.next_encryption_key()  # Perform a DH ratchet step
            await bob_dhr.next_decryption_key(header)
            encryption_key, header = await bob_dhr.next_encryption_key()  # Let Alice decrypt a fresh message
            decryption_key, skipped_message_keys = await alice_dhr.next_decryption_key(header)
            assert len(encryption_key) == len(decryption_key) == 32
            assert encryption_key == decryption_key
            assert len(skipped_message_keys) == 1
            skipped_message_keys_key = (skipped_header.ratchet_pub, skipped_header.sending_chain_length)
            assert skipped_message_keys[skipped_message_keys_key] == skipped_encryption_key

            # Decrypting this message should not raise an exception but mess up the ratchet instead and return
            # a wrong key:
            decryption_key, _ = await alice_dhr.next_decryption_key(skipped_header)
            assert decryption_key != skipped_encryption_key

            # The ratchets are now completely desynchronized, the only option is creating new ratchets. The
            # Double Ratchet mitigates this issue.
            alice_dhr, bob_dhr, _ = await _create_ratchets(
                impl,
                bob_priv,
                root_chain_key,
                message_chain_constant
            )

            # Test the (hard) DoS protection by skipping more than 10 messages:
            for _ in range(25):
                await bob_dhr.next_encryption_key()
            _, header = await bob_dhr.next_encryption_key()
            await _assert_raises(DoSProtectionException, alice_dhr.next_decryption_key(header))

            # Perform a Diffie-Hellman ratchet step
            encryption_key, header = await alice_dhr.next_encryption_key()
            await bob_dhr.next_decryption_key(header)

            # Test the (soft) DoS protection:
            encryption_key, header = await bob_dhr.next_encryption_key()
            with catch_warnings(record=True) as caught_warnings:
                # Record every UserWarning regardless of the ambient filters (e.g. ``-W ignore::UserWarning``
                # or ``-W error::UserWarning``), so the check below does not depend on how the suite is run.
                simplefilter("always", UserWarning)
                decryption_key, skipped_message_keys = await alice_dhr.next_decryption_key(header)
            assert len(caught_warnings) == 1
            assert issubclass(caught_warnings[0].category, UserWarning)
            assert "DoS" in str(caught_warnings[0].message)
            assert len(skipped_message_keys) == 0  # Without DoS protection, this would be 25+
            assert len(encryption_key) == len(decryption_key) == 32
            assert encryption_key == decryption_key

            # Make sure that a root key of a different size than 32 bytes is rejected
            error = await _assert_raises(ValueError, impl.create(
                None,
                impl._derive_pub(bob_priv),
                RootChainKDF,
                b"\x00" * 64,
                MessageChainKDF,
                message_chain_constant,
                _DOS_PROTECTION_THRESHOLD
            ))
            assert "key" in str(error)
            assert "root chain" in str(error)
            assert "32 bytes" in str(error)

            # Test that (de)serializing doesn't influence the functionality
            encryption_key, header = await alice_dhr.next_encryption_key()
            decryption_key, _ = await bob_dhr.next_decryption_key(header)
            assert encryption_key == decryption_key
            alice_dhr = impl.from_json(alice_dhr.json, RootChainKDF, MessageChainKDF,
                                       message_chain_constant, _DOS_PROTECTION_THRESHOLD)
            encryption_key, header = await alice_dhr.next_encryption_key()
            decryption_key, _ = await bob_dhr.next_decryption_key(header)
            assert encryption_key == decryption_key
            bob_dhr = impl.from_json(bob_dhr.json, RootChainKDF, MessageChainKDF,
                                     message_chain_constant, _DOS_PROTECTION_THRESHOLD)
            encryption_key, header = await alice_dhr.next_encryption_key()
            decryption_key, _ = await bob_dhr.next_decryption_key(header)
            assert encryption_key == decryption_key

            # Make sure that a message can be decrypted twice by restoring an old serialized state
            encryption_key, header = await alice_dhr.next_encryption_key()
            bob_dhr_serialized = bob_dhr.json
            decryption_key, _ = await bob_dhr.next_decryption_key(header)
            assert encryption_key == decryption_key
            bob_dhr = impl.from_json(bob_dhr_serialized, RootChainKDF, MessageChainKDF,
                                     message_chain_constant, _DOS_PROTECTION_THRESHOLD)
            decryption_key, _ = await bob_dhr.next_decryption_key(header)
            assert encryption_key == decryption_key