File: test_vectors.py

package info (click to toggle)
python-noiseprotocol 0.3.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,656 kB
  • sloc: python: 1,259; makefile: 25
file content (129 lines) | stat: -rw-r--r-- 5,494 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import json
import logging
import os

import pytest

from noiseprotocol.connection import NoiseConnection, Keypair

logger = logging.getLogger(__name__)

# Vector files to load, given relative to the directory containing this module.
vector_files = [
    'vectors/cacophony.txt',
    'vectors/snow-multipsk.txt',
]

# Field names follow the test vectors specification
# (https://github.com/noiseprotocol/noise_wiki/wiki/Test-vectors).
# They tell the loader how to turn each JSON string field into bytes.

# Plain text field, converted with str.encode().
byte_field = 'protocol_name'
# Fields holding a single hex string each.
hexbyte_fields = (
    'init_prologue',
    'init_static',
    'init_ephemeral',
    'init_remote_static',
    'resp_static',
    'resp_prologue',
    'resp_ephemeral',
    'resp_remote_static',
    'handshake_hash',
)
# Fields holding a list of hex strings.
list_fields = ('init_psks', 'resp_psks')
# Field holding a list of dicts whose values are all hex strings.
dict_field = 'messages'


def _prepare_test_vectors():
    """Load and decode every test vector listed in ``vector_files``.

    Each file is resolved relative to this module's directory, parsed as JSON and
    iterated as a sequence of vector dicts. String fields are converted to bytes:

    * ``byte_field`` is converted with ``str.encode()``;
    * each field in ``hexbyte_fields`` is hex-decoded;
    * every element of each ``list_fields`` list is hex-decoded;
    * every value of each dict in the ``dict_field`` list is hex-decoded.

    All other keys are left unchanged.

    Returns:
        list of dict: the decoded vectors from all files, in file order.
    """
    vectors = []
    for path in vector_files:
        full_path = os.path.join(os.path.dirname(__file__), path)
        # JSON text is UTF-8; don't let the platform locale pick the encoding.
        with open(full_path, encoding='utf-8') as fd:
            # Use the module logger rather than the root-level logging.info(), which
            # may call logging.basicConfig() when the root logger has no handlers.
            logger.info('Reading vectors from file %s', path)
            vectors_list = json.load(fd)

        for vector in vectors_list:
            # Iterate over a snapshot of the items because values are replaced in place.
            for key, value in list(vector.items()):
                if key == byte_field:
                    vector[key] = value.encode()
                elif key in hexbyte_fields:
                    vector[key] = bytes.fromhex(value)
                elif key in list_fields:
                    vector[key] = [bytes.fromhex(item) for item in value]
                elif key == dict_field:
                    vector[key] = [
                        {k: bytes.fromhex(v) for k, v in message.items()}
                        for message in value
                    ]
            vectors.append(vector)
    return vectors


def idfn(vector):
    """Return the pytest parameter ID for a test vector.

    The ID is the vector's ``protocol_name`` entry, returned unchanged. It is bytes
    once the vector has gone through ``_prepare_test_vectors``. A vector without
    that entry raises ``KeyError``.
    """
    protocol = vector['protocol_name']
    return protocol


@pytest.mark.filterwarnings('ignore: This implementation of ed448')
@pytest.mark.filterwarnings('ignore: One of ephemeral keypairs')
class TestVectors(object):
    """Replay each prepared test vector through an initiator/responder pair.

    For every vector, two ``NoiseConnection`` objects are configured from the
    vector's fields. Each entry of ``vector['messages']`` is then written by one
    side and read by the other, and the produced ciphertext and recovered payload
    are compared with the vector's expected values.
    """

    @pytest.fixture(params=_prepare_test_vectors(), ids=idfn)
    def vector(self, request):
        """Provide one decoded test vector per parametrized run."""
        yield request.param

    def _set_keypairs(self, vector, connection):
        """Load this connection's role-specific keys from ``vector``, skipping absent ones.

        The ``init``/``resp`` prefix is chosen from
        ``connection.noise_protocol.initiator``. The role must therefore already
        have been set with ``set_as_initiator()``/``set_as_responder()``.
        """
        role = 'init' if connection.noise_protocol.initiator else 'resp'
        setters = [
            (connection.set_keypair_from_private_bytes, Keypair.STATIC, role + '_static'),
            (connection.set_keypair_from_private_bytes, Keypair.EPHEMERAL, role + '_ephemeral'),
            # The peer's static key is loaded from public bytes only.
            (connection.set_keypair_from_public_bytes, Keypair.REMOTE_STATIC, role + '_remote_static')
        ]
        for fn, keypair, name in setters:
            if name in vector:
                fn(keypair, vector[name])

    @staticmethod
    def _verify_split(initiator, responder, vector):
        """Check both sides' post-handshake state against each other and ``vector``.

        Checks:

        * The handshake hash matches on both sides and, when the vector provides
          one, matches the vector.
        * The initiator's encrypt key equals the responder's decrypt key.
        * For two-way patterns, the reverse pair of keys matches too.
        * For one-way patterns, the initiator's decrypt cipher state and the
          responder's encrypt cipher state are both ``None``.
        """
        if 'handshake_hash' in vector:
            assert initiator.noise_protocol.handshake_hash == responder.noise_protocol.handshake_hash == vector['handshake_hash']

        assert initiator.noise_protocol.cipher_state_encrypt.k == responder.noise_protocol.cipher_state_decrypt.k
        if not initiator.noise_protocol.pattern.one_way:
            assert initiator.noise_protocol.cipher_state_decrypt.k == responder.noise_protocol.cipher_state_encrypt.k
        else:
            assert initiator.noise_protocol.cipher_state_decrypt is responder.noise_protocol.cipher_state_encrypt is None

    def test_vector(self, vector):
        """Replay ``vector`` and check ciphertexts, payloads, handshake hash and split keys."""
        initiator = NoiseConnection.from_name(vector['protocol_name'])
        responder = NoiseConnection.from_name(vector['protocol_name'])
        # PSKs are configured only when the vector provides them for both sides.
        if 'init_psks' in vector and 'resp_psks' in vector:
            initiator.set_psks(psks=vector['init_psks'])
            responder.set_psks(psks=vector['resp_psks'])

        initiator.set_prologue(vector['init_prologue'])
        initiator.set_as_initiator()
        self._set_keypairs(vector, initiator)

        responder.set_prologue(vector['resp_prologue'])
        responder.set_as_responder()
        self._set_keypairs(vector, responder)

        initiator.start_handshake()
        responder.start_handshake()

        # Messages alternate direction, starting with initiator -> responder.
        initiator_to_responder = True
        handshake_finished = False
        for message in vector['messages']:
            if not handshake_finished:
                if initiator_to_responder:
                    sender, receiver = initiator, responder
                else:
                    sender, receiver = responder, initiator

                sender_result = sender.write_message(message['payload'])
                assert sender_result == message['ciphertext']

                receiver_result = receiver.read_message(sender_result)
                assert receiver_result == message['payload']

                if sender.handshake_finished and receiver.handshake_finished:
                    handshake_finished = True
                    self._verify_split(initiator, responder, vector)
                else:
                    # Not finished yet: neither side may finish before the other.
                    assert sender.handshake_finished == receiver.handshake_finished
            else:
                # Transport phase: for one-way patterns the initiator always sends.
                if initiator.noise_protocol.pattern.one_way or initiator_to_responder:
                    sender, receiver = initiator, responder
                else:
                    sender, receiver = responder, initiator
                ciphertext = sender.encrypt(message['payload'])
                assert ciphertext == message['ciphertext']
                plaintext = receiver.decrypt(message['ciphertext'])
                assert plaintext == message['payload']
            initiator_to_responder = not initiator_to_responder

        # Without this check, a vector whose messages never complete the handshake
        # (or has none at all) would pass without the hash/split-key checks running.
        assert handshake_finished, 'handshake did not complete within the vector messages'