1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
|
import json
import logging
import os
import pytest
from noiseprotocol.connection import NoiseConnection, Keypair
logger = logging.getLogger(__name__)

# Vector files to load; paths are resolved relative to this module's directory.
vector_files = [
    'vectors/cacophony.txt',
    'vectors/snow-multipsk.txt',
]

# Field groups as named in the test vectors specification
# (https://github.com/noiseprotocol/noise_wiki/wiki/Test-vectors).
# They decide how each string read from the vector files is cast into bytes.
byte_field = 'protocol_name'
hexbyte_fields = (
    'init_prologue',
    'init_static',
    'init_ephemeral',
    'init_remote_static',
    'resp_static',
    'resp_prologue',
    'resp_ephemeral',
    'resp_remote_static',
    'handshake_hash',
)
list_fields = (
    'init_psks',
    'resp_psks',
)
dict_field = 'messages'
def _prepare_test_vectors():
    """Load every vector file and cast its string fields into bytes.

    Each path in ``vector_files`` is resolved relative to this module's
    directory and parsed as JSON; the parsed document is iterated as a
    sequence of vector dictionaries. Every vector is converted in place:

    * the ``byte_field`` value is encoded with ``str.encode()``;
    * ``hexbyte_fields`` values are decoded from hex;
    * ``list_fields`` values become lists of hex-decoded bytes;
    * the ``dict_field`` value (a list of dicts) has every dict value
      hex-decoded.

    Keys outside those groups are left untouched.

    Returns:
        list: the converted vectors of all files, in file order.
    """
    vectors = []
    base_dir = os.path.dirname(__file__)
    for path in vector_files:
        # Decode as UTF-8 explicitly so parsing does not depend on the locale.
        with open(os.path.join(base_dir, path), encoding='utf-8') as fd:
            # Log through the module logger: the root-level logging.info()
            # implicitly calls basicConfig() when no handler is configured.
            logger.info('Reading vectors from file %s', path)
            vectors_list = json.load(fd)

        for vector in vectors_list:
            # Iterate over a snapshot because values are replaced in place.
            for key, value in vector.copy().items():
                if key == byte_field:
                    vector[key] = value.encode()
                if key in hexbyte_fields:
                    vector[key] = bytes.fromhex(value)
                if key in list_fields:
                    vector[key] = [bytes.fromhex(k) for k in value]
                if key == dict_field:
                    vector[key] = [
                        {k: bytes.fromhex(v) for k, v in dictionary.items()}
                        for dictionary in value
                    ]
            vectors.append(vector)
    return vectors
def idfn(vector):
    """Return the vector's ``protocol_name`` value, used as its pytest parameter id."""
    protocol_name = vector['protocol_name']
    return protocol_name
@pytest.mark.filterwarnings('ignore: This implementation of ed448')
@pytest.mark.filterwarnings('ignore: One of ephemeral keypairs')
class TestVectors(object):
    """Replay each prepared test vector through an initiator/responder pair."""

    @pytest.fixture(params=_prepare_test_vectors(), ids=idfn)
    def vector(self, request):
        """Provide one prepared vector per parametrized run."""
        yield request.param

    def _set_keypairs(self, vector, connection):
        """Load into ``connection`` the keys the vector supplies for its role.

        The role prefix (``init`` or ``resp``) is taken from the connection's
        ``noise_protocol.initiator`` flag; keys absent from the vector are
        skipped.
        """
        prefix = 'init' if connection.noise_protocol.initiator else 'resp'
        load_private = connection.set_keypair_from_private_bytes
        load_public = connection.set_keypair_from_public_bytes
        plan = (
            (load_private, Keypair.STATIC, prefix + '_static'),
            (load_private, Keypair.EPHEMERAL, prefix + '_ephemeral'),
            (load_public, Keypair.REMOTE_STATIC, prefix + '_remote_static'),
        )
        for loader, slot, field in plan:
            if field not in vector:
                continue
            loader(slot, vector[field])

    def test_vector(self, vector):
        """Run all messages of one vector and check every produced output.

        Handshake messages go through ``write_message``/``read_message``;
        once both sides report the handshake finished, the remaining
        messages go through ``encrypt``/``decrypt``.
        """
        protocol_name = vector['protocol_name']
        initiator = NoiseConnection.from_name(protocol_name)
        responder = NoiseConnection.from_name(protocol_name)

        # PSKs are applied only when the vector carries them for both sides.
        if 'init_psks' in vector and 'resp_psks' in vector:
            initiator.set_psks(psks=vector['init_psks'])
            responder.set_psks(psks=vector['resp_psks'])

        initiator.set_prologue(vector['init_prologue'])
        initiator.set_as_initiator()
        self._set_keypairs(vector, initiator)

        responder.set_prologue(vector['resp_prologue'])
        responder.set_as_responder()
        self._set_keypairs(vector, responder)

        initiator.start_handshake()
        responder.start_handshake()

        handshake_finished = False
        for index, message in enumerate(vector['messages']):
            # Direction alternates with every message, starting with the initiator.
            initiator_sends = index % 2 == 0

            if handshake_finished:
                # Transport phase; for one-way patterns the initiator is always the sender.
                sender, receiver = (
                    (initiator, responder)
                    if initiator.noise_protocol.pattern.one_way or initiator_sends
                    else (responder, initiator)
                )
                encrypted = sender.encrypt(message['payload'])
                assert encrypted == message['ciphertext']
                decrypted = receiver.decrypt(message['ciphertext'])
                assert decrypted == message['payload']
                continue

            # Handshake phase.
            sender, receiver = (initiator, responder) if initiator_sends else (responder, initiator)

            written = sender.write_message(message['payload'])
            assert written == message['ciphertext']

            received = receiver.read_message(written)
            assert received == message['payload']

            if not (sender.handshake_finished and receiver.handshake_finished):
                # Handshake still running: neither side may finish before the other.
                assert sender.handshake_finished == receiver.handshake_finished
                continue

            # Both sides are done with the handshake.
            handshake_finished = True

            # Check the handshake hash when the vector provides one.
            if 'handshake_hash' in vector:
                assert (initiator.noise_protocol.handshake_hash
                        == responder.noise_protocol.handshake_hash
                        == vector['handshake_hash'])

            # Check that the split cipher state keys of both sides line up.
            init_protocol = initiator.noise_protocol
            resp_protocol = responder.noise_protocol
            assert init_protocol.cipher_state_encrypt.k == resp_protocol.cipher_state_decrypt.k
            if init_protocol.pattern.one_way:
                assert init_protocol.cipher_state_decrypt is resp_protocol.cipher_state_encrypt is None
            else:
                assert init_protocol.cipher_state_decrypt.k == resp_protocol.cipher_state_encrypt.k
|