
|
from dissononce.extras.meta.protocol.protocol import NoiseProtocol
from dissononce.extras.meta.protocol.factory import NoiseProtocolFactory
from dissononce.dh.private import PrivateKey
from dissononce.extras.dh.dangerous.dh_nogen import NoGenDH
from tests.structs.vector import Vector, VectorVars, VectorMessage
import json
import os
import binascii
class TestVectors(object):
    """Replay the JSON Noise test vectors found in ``DIR_VECTORS`` against dissononce.

    Every regular file in ``DIR_VECTORS`` is parsed as JSON holding a top-level
    ``vectors`` list. Each entry whose ``protocol_name`` is accepted by
    ``NoiseProtocolFactory`` becomes one parametrized run of ``test_noise_protocol``.
    """

    # Directory scanned for vector files (every regular file in it is parsed as JSON).
    DIR_VECTORS = os.path.join(os.path.dirname(__file__), 'vectors')

    # Prefixes selecting the initiator ('init_<prop>') or responder ('resp_<prop>') field.
    VECTOR_INIT = 'init'
    VECTOR_RESP = 'resp'
    # Per-side property names, combined with one of the prefixes above by _get_vector_prop.
    VECTOR_PROLOGUE = 'prologue'
    VECTOR_STATIC = 'static'
    VECTOR_EPHEMERAL = 'ephemeral'
    VECTOR_REMOTE_STATIC = 'remote_static'
    VECTOR_PSKS = 'psks'
    # Unprefixed, vector-wide property names.
    VECTOR_HANDSHAKE_HASH = 'handshake_hash'
    VECTOR_MESSAGES = 'messages'
    # Keys inside each entry of the 'messages' list.
    VECTOR_MESSAGE_PAYLOAD = 'payload'
    VECTOR_MESSAGE_CIPHERTEXT = 'ciphertext'

    def pytest_generate_tests(self, metafunc):
        """Parametrize ``(noiseprotocol, vector)`` with every supported vector on disk.

        Vectors whose ``protocol_name`` makes ``NoiseProtocolFactory.get_noise_protocol``
        raise ``ValueError`` are skipped (presumably protocols dissononce does not
        support). Errors while decoding a supported vector are NOT swallowed, so a
        corrupt vector file fails loudly instead of silently shrinking the suite.

        :param metafunc: pytest's per-test-function parametrization handle
        """
        # The hook runs for every test in the class; only parametrize the ones
        # that actually ask for these arguments.
        if 'noiseprotocol' not in metafunc.fixturenames or 'vector' not in metafunc.fixturenames:
            return

        factory = NoiseProtocolFactory()
        relevant_vectors = []
        # sorted(): os.listdir order is arbitrary, and pytest-xdist requires every
        # worker to collect the same tests in the same order.
        for filename in sorted(os.listdir(self.DIR_VECTORS)):
            path = os.path.join(self.DIR_VECTORS, filename)
            if not os.path.isfile(path):
                continue
            for protocol_vector in self._read_vectors_file(path)['vectors']:
                try:
                    noiseprotocol = factory.get_noise_protocol(protocol_vector['protocol_name'])
                except ValueError:
                    # Protocol name rejected by the factory: not a vector we can run.
                    continue
                relevant_vectors.append((noiseprotocol, self._deserialize_vector(protocol_vector)))
        metafunc.parametrize(('noiseprotocol', 'vector'), relevant_vectors)

    def _read_vectors_file(self, path):
        """Load and return the parsed JSON content of one vector file.

        :param path: path of the vector file
        :type path: str
        :return: the decoded JSON document
        :rtype: dict
        """
        with open(path, 'r') as f:
            return json.load(f)

    def _get_vector_prop(self, vectordict, initiator, prop, default=None):
        """Look up a property of a JSON vector, optionally prefixed by side.

        :param vectordict: one entry of a vector file's 'vectors' list
        :type vectordict: dict
        :param initiator: True reads 'init_<prop>', False reads 'resp_<prop>',
            None reads the unprefixed '<prop>'
        :type initiator: bool | None
        :param prop: property name without side prefix
        :type prop: str
        :param default: returned when the key is absent or its value is falsy
        :return: the stored value, or ``default``
        """
        if initiator is None:
            key = prop
        else:
            prefix = self.VECTOR_INIT if initiator else self.VECTOR_RESP
            key = '%s_%s' % (prefix, prop)
        # Falsy values (e.g. "" or []) deliberately fall back to the default as well.
        return vectordict.get(key) or default

    def _deserialize_vectorvars(self, vectordict, initiator):
        """Decode one side's ('init_*' or 'resp_*') fields of a vector into a VectorVars.

        :param vectordict: one entry of a vector file's 'vectors' list
        :type vectordict: dict
        :param initiator: True for the initiator's fields, False for the responder's
        :type initiator: bool
        :rtype: VectorVars
        :raises binascii.Error: if a hex field is malformed
        """
        prologue = self._get_vector_prop(vectordict, initiator, self.VECTOR_PROLOGUE)
        static = self._get_vector_prop(vectordict, initiator, self.VECTOR_STATIC)
        ephemeral = self._get_vector_prop(vectordict, initiator, self.VECTOR_EPHEMERAL)
        remote_static = self._get_vector_prop(vectordict, initiator, self.VECTOR_REMOTE_STATIC)
        psks = self._get_vector_prop(vectordict, initiator, self.VECTOR_PSKS, default=[])
        return VectorVars(
            # Noise always runs MixHash(prologue); a missing prologue is the empty
            # byte string, not None.
            prologue=binascii.unhexlify(prologue) if prologue else b'',
            s=PrivateKey(binascii.unhexlify(static)) if static else None,
            e=PrivateKey(binascii.unhexlify(ephemeral)) if ephemeral else None,
            # Held in a PrivateKey wrapper; test_noise_protocol only reads its .data
            # to build the remote public key.
            rs=PrivateKey(binascii.unhexlify(remote_static)) if remote_static else None,
            psks=tuple(binascii.unhexlify(psk) for psk in psks),
        )

    def _deserialize_vector(self, vectordict):
        """Convert one JSON vector into a Vector with every hex field decoded to bytes.

        :param vectordict: one entry of a vector file's 'vectors' list
        :type vectordict: dict
        :return: the decoded vector
        :rtype: Vector
        :raises binascii.Error: if a hex field is malformed
        :raises KeyError: if a message lacks 'payload' or 'ciphertext'
        """
        handshake_hash = self._get_vector_prop(vectordict, None, self.VECTOR_HANDSHAKE_HASH)
        messages = self._get_vector_prop(vectordict, None, self.VECTOR_MESSAGES, default=[])
        return Vector(
            init_vectorvars=self._deserialize_vectorvars(vectordict, True),
            resp_vectorvars=self._deserialize_vectorvars(vectordict, False),
            handshake_hash=binascii.unhexlify(handshake_hash) if handshake_hash else None,
            messages=[
                VectorMessage(
                    binascii.unhexlify(message[self.VECTOR_MESSAGE_PAYLOAD]),
                    binascii.unhexlify(message[self.VECTOR_MESSAGE_CIPHERTEXT]),
                )
                for message in messages
            ],
        )

    def test_noise_protocol(self, noiseprotocol, vector):
        """Replay one vector through an initiator/responder pair and compare every message.

        Handshake messages alternate initiator -> responder (even index) and
        responder -> initiator (odd index). Every handshake message, including the one
        that completes the handshake, is checked against the vector. The remaining
        messages are checked as transport messages with the resulting cipher states.

        :param noiseprotocol: protocol the vector was generated for
        :type noiseprotocol: NoiseProtocol
        :param vector: expected keys, prologue, psks, handshake hash and messages
        :type vector: Vector
        """
        init_vars = vector.init_vectorvars
        resp_vars = vector.resp_vectorvars

        # Each side's DH is wrapped in NoGenDH together with that side's ephemeral key
        # from the vector (presumably so "generated" ephemerals match the vector).
        init_dh = NoGenDH(noiseprotocol.dh, init_vars.e)
        resp_dh = NoGenDH(noiseprotocol.dh, resp_vars.e)
        init_protocol = NoiseProtocol(noiseprotocol.pattern, init_dh, noiseprotocol.cipher, noiseprotocol.hash)
        resp_protocol = NoiseProtocol(noiseprotocol.pattern, resp_dh, noiseprotocol.cipher, noiseprotocol.hash)
        init_handshakestate = init_protocol.create_handshakestate()
        resp_handshakestate = resp_protocol.create_handshakestate()

        init_handshakestate.initialize(
            handshake_pattern=noiseprotocol.pattern,
            initiator=True,
            prologue=init_vars.prologue,
            # Only hand over a static keypair when the vector defines one for this side.
            s=init_dh.generate_keypair(init_vars.s) if init_vars.s is not None else None,
            rs=noiseprotocol.dh.create_public(init_vars.rs.data) if init_vars.rs is not None else None,
            psks=init_vars.psks,
        )
        resp_handshakestate.initialize(
            handshake_pattern=noiseprotocol.pattern,
            initiator=False,
            prologue=resp_vars.prologue,
            s=resp_dh.generate_keypair(resp_vars.s) if resp_vars.s is not None else None,
            rs=noiseprotocol.dh.create_public(resp_vars.rs.data) if resp_vars.rs is not None else None,
            psks=resp_vars.psks,
        )

        init_cipherstates = None
        resp_cipherstates = None
        # If the handshake never completes, no message is treated as a transport message.
        transport_messages_offset = len(vector.messages)
        for i, message in enumerate(vector.messages):
            message_buffer = bytearray()
            payload_buffer = bytearray()
            if i % 2 == 0:
                # Initiator writes, responder reads.
                if init_cipherstates is None:
                    init_cipherstates = init_handshakestate.write_message(message.payload, message_buffer)
                if resp_cipherstates is None:
                    resp_cipherstates = resp_handshakestate.read_message(bytes(message_buffer), payload_buffer)
            else:
                # Responder writes, initiator reads.
                if resp_cipherstates is None:
                    resp_cipherstates = resp_handshakestate.write_message(message.payload, message_buffer)
                if init_cipherstates is None:
                    init_cipherstates = init_handshakestate.read_message(bytes(message_buffer), payload_buffer)

            # Verify before checking for completion so the final handshake message is
            # compared too.
            assert message.ciphertext == message_buffer
            assert message.payload == payload_buffer

            if init_cipherstates is not None and resp_cipherstates is not None:
                transport_messages_offset = i + 1
                break

        if vector.handshake_hash:
            assert init_handshakestate.symmetricstate.get_handshake_hash() == vector.handshake_hash
            assert resp_handshakestate.symmetricstate.get_handshake_hash() == vector.handshake_hash

        transport_messages = vector.messages[transport_messages_offset:]
        for i, message in enumerate(transport_messages, transport_messages_offset):
            if init_protocol.oneway or i % 2 == 0:
                # One-way patterns only ever send initiator -> responder.
                assert message.ciphertext == init_cipherstates[0].encrypt_with_ad(b'', message.payload)
                assert message.payload == resp_cipherstates[0].decrypt_with_ad(b'', message.ciphertext)
            else:
                assert message.ciphertext == resp_cipherstates[1].encrypt_with_ad(b'', message.payload)
                assert message.payload == init_cipherstates[1].decrypt_with_ad(b'', message.ciphertext)
|