File: test_symmetric_key_ratchet.py

package info (click to toggle)
python-doubleratchet 1.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 496 kB
  • sloc: python: 2,194; makefile: 13
file content (121 lines) | stat: -rw-r--r-- 3,824 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from typing import Set

from doubleratchet import (
    Chain,
    ChainNotAvailableException,
    SymmetricKeyRatchet
)
from doubleratchet.recommended import HashFunction, kdf_hkdf

from .test_recommended_kdfs import generate_unique_random_data


__all__ = ["test_symmetric_key_ratchet"]


# pytest is optional for this module. When it can be imported, the module-level
# ``pytestmark`` applies the asyncio marker to every test defined here.
try:
    import pytest
except ImportError:
    pass
else:
    pytestmark = pytest.mark.asyncio


class KDF(kdf_hkdf.KDF):
    """
    HKDF-based KDF used by the symmetric-key ratchet test: SHA-512 with a test-specific info string.
    """

    @staticmethod
    def _get_info() -> bytes:
        # Fixed, test-specific info string handed to the HKDF base class.
        return b"test_symmetric_key_ratchet info"

    @staticmethod
    def _get_hash_function() -> HashFunction:
        return HashFunction.SHA_512


async def test_symmetric_key_ratchet() -> None:
    """
    Test the symmetric-key ratchet implementation.

    Each iteration creates two ratchets from the same random constant and drives them as the two parties
    of a conversation: ``skr_a`` sends on the chain that ``skr_b`` receives on and vice versa. The test
    checks that

    * the chain length bookkeeping (previous sending, sending and receiving chain lengths) is updated
      when chains are initialized, replaced and advanced,
    * advancing a chain that was never initialized raises :class:`ChainNotAvailableException`,
    * both parties derive identical keys from identical chain keys, and those keys are 32 bytes long,
    * replacing a chain with a 64-byte chain key raises :class:`ValueError` and leaves the existing chain
      in sync with the peer.
    """

    constant_set: Set[bytes] = set()
    key_set: Set[bytes] = set()

    for _ in range(10000):
        # Both ratchets share the same constant, so equal chain keys must produce equal keys on both sides.
        constant = generate_unique_random_data(0, 2 ** 16, constant_set)

        skr_a = SymmetricKeyRatchet.create(KDF, constant)
        skr_b = SymmetricKeyRatchet.create(KDF, constant)

        # Freshly created ratchets have no chains at all.
        assert skr_a.previous_sending_chain_length is None
        assert skr_b.previous_sending_chain_length is None
        assert skr_a.sending_chain_length is None
        assert skr_b.sending_chain_length is None
        assert skr_a.receiving_chain_length is None
        assert skr_b.receiving_chain_length is None

        # Initialize the A -> B direction: A's sending chain and B's receiving chain get the same key.
        key = generate_unique_random_data(32, 32 + 1, key_set)
        skr_a.replace_chain(Chain.SENDING, key)
        skr_b.replace_chain(Chain.RECEIVING, key)

        assert skr_a.previous_sending_chain_length is None
        assert skr_b.previous_sending_chain_length is None
        assert skr_a.sending_chain_length == 0
        assert skr_b.sending_chain_length is None
        assert skr_a.receiving_chain_length is None
        assert skr_b.receiving_chain_length == 0

        # The B -> A direction does not exist yet, so using it must fail. The ``else`` branches use an
        # explicit raise (not ``assert False``) so a missing exception is reported even under ``-O``.
        try:
            await skr_a.next_decryption_key()
        except ChainNotAvailableException as e:
            assert "receiving chain" in str(e)
            assert "never initialized" in str(e)
        else:
            raise AssertionError(
                "next_decryption_key() did not raise for a receiving chain that was never initialized"
            )

        try:
            await skr_b.next_encryption_key()
        except ChainNotAvailableException as e:
            assert "sending chain" in str(e)
            assert "never initialized" in str(e)
        else:
            raise AssertionError(
                "next_encryption_key() did not raise for a sending chain that was never initialized"
            )

        # One message A -> B.
        assert await skr_a.next_encryption_key() == await skr_b.next_decryption_key()

        assert skr_a.sending_chain_length == 1
        assert skr_b.receiving_chain_length == 1

        # Replace the A -> B chains. The length of A's replaced sending chain (1) is expected to show up
        # as A's previous sending chain length below.
        key = generate_unique_random_data(32, 32 + 1, key_set)
        skr_a.replace_chain(Chain.SENDING, key)
        skr_b.replace_chain(Chain.RECEIVING, key)

        # Initialize the B -> A direction for the first time.
        key = generate_unique_random_data(32, 32 + 1, key_set)
        skr_a.replace_chain(Chain.RECEIVING, key)
        skr_b.replace_chain(Chain.SENDING, key)

        # Two messages A -> B on the new chain, one message B -> A.
        assert await skr_a.next_encryption_key() == await skr_b.next_decryption_key()
        assert await skr_a.next_encryption_key() == await skr_b.next_decryption_key()

        assert await skr_b.next_encryption_key() == await skr_a.next_decryption_key()

        # B never had a sending chain replaced, so it has no previous sending chain length.
        assert skr_a.previous_sending_chain_length == 1
        assert skr_b.previous_sending_chain_length is None
        assert skr_a.sending_chain_length == 2
        assert skr_b.sending_chain_length == 1
        assert skr_a.receiving_chain_length == 1
        assert skr_b.receiving_chain_length == 2

        # Keys are 32 bytes long; B consumes the matching key to stay in sync with A.
        assert len(await skr_a.next_encryption_key()) == 32
        await skr_b.next_decryption_key()

        # A 64-byte chain key must be rejected with an error that mentions the expected 32 bytes.
        try:
            skr_a.replace_chain(Chain.SENDING, b"\x00" * 64)
        except ValueError as e:
            assert "chain key" in str(e)
            assert "32 bytes" in str(e)
        else:
            raise AssertionError("replace_chain() accepted a 64-byte chain key")

        # The rejected replacement must not have disturbed A's current sending chain.
        assert await skr_a.next_encryption_key() == await skr_b.next_decryption_key()