1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380
|
import base64
import copy
import json
import os
from typing import Set, Dict, Any, List, Tuple
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x448 import X448PrivateKey
from doubleratchet import (
AuthenticationFailedException,
DoubleRatchet as DR,
DuplicateMessageException,
EncryptedMessage,
Header,
InconsistentSerializationException
)
from doubleratchet.recommended import (
aead_aes_hmac,
diffie_hellman_ratchet_curve25519 as dhr25519,
diffie_hellman_ratchet_curve448 as dhr448,
HashFunction,
kdf_hkdf
)
from .test_recommended_kdfs import generate_unique_random_data
__all__ = ["test_double_ratchet", "test_migrations"]

# pytest is optional here. When it is importable, every test coroutine in this module is marked with
# ``pytest.mark.asyncio`` via the module-level ``pytestmark``; otherwise nothing is marked.
try:
    import pytest
except ImportError:
    pass
else:
    pytestmark = pytest.mark.asyncio
class RootChainKDF(kdf_hkdf.KDF):
    """
    Root chain KDF used throughout the tests: a ``kdf_hkdf.KDF`` configured with SHA-256 and a fixed,
    test-specific info string.
    """

    @staticmethod
    def _get_info() -> bytes:
        info_text = "test_double_ratchet Root Chain KDF info"
        return info_text.encode("ASCII")

    @staticmethod
    def _get_hash_function() -> HashFunction:
        return HashFunction.SHA_256
class MessageChainKDF(kdf_hkdf.KDF):
    """
    Message chain KDF used throughout the tests: a ``kdf_hkdf.KDF`` configured with SHA-512/256 and a
    fixed, test-specific info string.
    """

    @staticmethod
    def _get_info() -> bytes:
        info_text = "test_double_ratchet Message Chain KDF info"
        return info_text.encode("ASCII")

    @staticmethod
    def _get_hash_function() -> HashFunction:
        return HashFunction.SHA_512_256
class AEAD(aead_aes_hmac.AEAD):
    """
    AEAD used throughout the tests: an ``aead_aes_hmac.AEAD`` configured with SHA-512 and a fixed,
    test-specific info string.
    """

    @staticmethod
    def _get_info() -> bytes:
        info_text = "test_double_ratchet AEAD info"
        return info_text.encode("ASCII")

    @staticmethod
    def _get_hash_function() -> HashFunction:
        return HashFunction.SHA_512
class DoubleRatchet(DR):
    """
    Double Ratchet subclass used by the tests. Its only addition is the rule for combining the caller's
    associated data with the message header.
    """

    @staticmethod
    def _build_associated_data(associated_data: bytes, header: Header) -> bytes:
        # Layout: caller-supplied associated data, then the header's ratchet public key, then the sending
        # chain length and the previous sending chain length, each as an 8-byte big-endian integer.
        n_bytes = header.sending_chain_length.to_bytes(8, "big")
        pn_bytes = header.previous_sending_chain_length.to_bytes(8, "big")
        return associated_data + header.ratchet_pub + n_bytes + pn_bytes
# Double Ratchet configuration shared by test_double_ratchet; it is splatted into the DoubleRatchet
# constructors/factories as keyword arguments (``**drc``). Uses the Curve448 Diffie-Hellman ratchet.
drc: Dict[str, Any] = dict(
    diffie_hellman_ratchet_class=dhr448.DiffieHellmanRatchet,
    root_chain_kdf=RootChainKDF,
    message_chain_kdf=MessageChainKDF,
    message_chain_constant="test_double_ratchet Message Chain constant".encode("ASCII"),
    dos_protection_threshold=10,
    max_num_skipped_message_keys=15,
    aead=AEAD,
)
async def _assert_round_trip(
    sender: DoubleRatchet,
    receiver: DoubleRatchet,
    message: bytes,
    ad: bytes
) -> None:
    """
    Encrypt a message with one ratchet and assert that the other ratchet decrypts it back unchanged.

    Args:
        sender: The ratchet that encrypts ``message``.
        receiver: The ratchet that decrypts the resulting ciphertext.
        message: The plaintext to send.
        ad: The associated data used for both encryption and decryption.
    """

    assert await receiver.decrypt_message(await sender.encrypt_message(message, ad), ad) == message


async def test_double_ratchet() -> None:
    """
    Test the Double Ratchet implementation.

    Each iteration uses a fresh key pair for Bob and random shared secret, message and associated data. It
    covers parameter validation, the initial message exchange, regular and skipped messages (with and
    without Diffie-Hellman ratchet steps), the bound on stored skipped message keys, duplicate message
    handling and (de)serialization.
    """

    shared_secret_set: Set[bytes] = set()
    message_set: Set[bytes] = set()
    ad_set: Set[bytes] = set()

    # Raise the iteration count (e.g. to 200) for a more thorough but slower randomized run.
    for _ in range(10):
        bob_ratchet_priv = X448PrivateKey.generate()
        bob_ratchet_pub = bob_ratchet_priv.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw
        )

        shared_secret = generate_unique_random_data(32, 32 + 1, shared_secret_set)
        message = generate_unique_random_data(0, 2 ** 16, message_set)
        ad = generate_unique_random_data(0, 2 ** 16, ad_set)

        # Test that passing a shared secret which doesn't consist of 32 bytes raises an exception:
        try:
            await DoubleRatchet.encrypt_initial_message(
                shared_secret=b"\x00" * 64,
                recipient_ratchet_pub=bob_ratchet_pub,
                message=message,
                associated_data=ad,
                **drc
            )
        except ValueError as e:
            assert "shared secret" in str(e)
            assert "32 bytes" in str(e)
        else:
            # Explicit raise instead of ``assert False`` so the check survives ``python -O``.
            raise AssertionError("A 64-byte shared secret was accepted.")

        # Test that passing a DoS protection threshold higher than the maximum number of skipped message
        # keys raises an exception:
        drc_copy = copy.copy(drc)
        drc_copy["dos_protection_threshold"] = 20
        try:
            await DoubleRatchet.encrypt_initial_message(
                shared_secret=shared_secret,
                recipient_ratchet_pub=bob_ratchet_pub,
                message=message,
                associated_data=ad,
                **drc_copy
            )
        except ValueError as e:
            assert "dos_protection_threshold" in str(e)
            assert "bigger than" in str(e)
            assert "max_num_skipped_message_keys" in str(e)
        else:
            raise AssertionError("A DoS protection threshold above max_num_skipped_message_keys was accepted.")

        # Encrypt an initial message from Alice to Bob
        alice_dr, encrypted_message = await DoubleRatchet.encrypt_initial_message(
            shared_secret=shared_secret,
            recipient_ratchet_pub=bob_ratchet_pub,
            message=message,
            associated_data=ad,
            **drc
        )

        bob_dr, plaintext = await DoubleRatchet.decrypt_initial_message(
            shared_secret=shared_secret,
            own_ratchet_priv=bob_ratchet_priv.private_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PrivateFormat.Raw,
                encryption_algorithm=serialization.NoEncryption()
            ),
            message=encrypted_message,
            associated_data=ad,
            **drc
        )

        assert alice_dr.sending_chain_length == 1
        assert alice_dr.receiving_chain_length is None
        assert bob_dr.sending_chain_length == 0
        assert bob_dr.receiving_chain_length == 1
        assert plaintext == message

        # Send a message back and forth
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        await _assert_round_trip(bob_dr, alice_dr, message, ad)

        # Make sure that each ciphertext is different even though the message is always the same:
        encrypted_message_1 = await alice_dr.encrypt_message(message, ad)
        assert await bob_dr.decrypt_message(encrypted_message_1, ad) == message
        encrypted_message_2 = await alice_dr.encrypt_message(message, ad)
        assert await bob_dr.decrypt_message(encrypted_message_2, ad) == message
        assert encrypted_message_1.ciphertext != encrypted_message_2.ciphertext

        # Test the first case of skipped messages (without a Diffie-Hellman ratchet step):
        skipped_message_1 = await alice_dr.encrypt_message(message, ad)
        skipped_message_2 = await alice_dr.encrypt_message(message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        assert await bob_dr.decrypt_message(skipped_message_2, ad) == message
        assert await bob_dr.decrypt_message(skipped_message_1, ad) == message

        # Test the second case of skipped messages (with Diffie-Hellman ratchet steps):
        skipped_message_1 = await alice_dr.encrypt_message(message, ad)
        skipped_message_2 = await alice_dr.encrypt_message(message, ad)
        await _assert_round_trip(bob_dr, alice_dr, message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        await _assert_round_trip(bob_dr, alice_dr, message, ad)
        skipped_message_3 = await alice_dr.encrypt_message(message, ad)
        skipped_message_4 = await alice_dr.encrypt_message(message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        assert await bob_dr.decrypt_message(skipped_message_4, ad) == message
        assert await bob_dr.decrypt_message(skipped_message_3, ad) == message
        assert await bob_dr.decrypt_message(skipped_message_2, ad) == message
        assert await bob_dr.decrypt_message(skipped_message_1, ad) == message

        # Test that only the last 15 skipped message keys (max_num_skipped_message_keys in drc) are kept
        # around. Messages are skipped in batches of at most 8, presumably to stay below the DoS protection
        # threshold of 10 that drc configures.
        skipped_message = await alice_dr.encrypt_message(message, ad)  # Skipped messages: 1
        for _ in range(7):
            await alice_dr.encrypt_message(message, ad)
        # Skipped messages: 8
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        for _ in range(7):
            await alice_dr.encrypt_message(message, ad)
        # Skipped messages: 15
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        assert await bob_dr.decrypt_message(skipped_message, ad) == message

        skipped_message = await alice_dr.encrypt_message(message, ad)  # Skipped messages: 1
        for _ in range(7):
            await alice_dr.encrypt_message(message, ad)
        # Skipped messages: 8
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        for _ in range(8):
            await alice_dr.encrypt_message(message, ad)
        # Skipped messages: 16, so the key for the oldest one must have been dropped.
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        try:
            await bob_dr.decrypt_message(skipped_message, ad)
        except DuplicateMessageException:
            pass
        else:
            raise AssertionError("The key of a skipped message beyond the limit was still available.")

        # Test decrypting a message twice, before a Diffie-Hellman ratchet step. This should throw a
        # DuplicateMessageException and leave the ratchet intact:
        encrypted_message = await alice_dr.encrypt_message(message, ad)
        assert await bob_dr.decrypt_message(encrypted_message, ad) == message
        try:
            await bob_dr.decrypt_message(encrypted_message, ad)
        except DuplicateMessageException:
            pass
        else:
            raise AssertionError("Decrypting a message twice did not raise DuplicateMessageException.")

        # Test decrypting a message twice, after a Diffie-Hellman ratchet step. The Double Ratchet does not
        # detect this, so instead of a DuplicateMessageException a generic AuthenticationFailedException
        # should be raised:
        await _assert_round_trip(bob_dr, alice_dr, message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        await _assert_round_trip(bob_dr, alice_dr, message, ad)
        encrypted_message = await alice_dr.encrypt_message(message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        await _assert_round_trip(bob_dr, alice_dr, message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        assert await bob_dr.decrypt_message(encrypted_message, ad) == message
        try:
            await bob_dr.decrypt_message(encrypted_message, ad)
        except AuthenticationFailedException:
            pass
        else:
            raise AssertionError("Decrypting a message twice did not raise AuthenticationFailedException.")

        # Even after this failure, the ratchets should still work as before:
        await _assert_round_trip(bob_dr, alice_dr, message, ad)
        await _assert_round_trip(bob_dr, alice_dr, message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)

        # Test that (de)serialization doesn't damage the instances:
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        alice_dr = DoubleRatchet.from_json(alice_dr.json, **drc)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        bob_dr = DoubleRatchet.from_json(bob_dr.json, **drc)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        skipped_message = await alice_dr.encrypt_message(message, ad)
        await _assert_round_trip(alice_dr, bob_dr, message, ad)
        bob_dr = DoubleRatchet.from_json(bob_dr.json, **drc)
        assert await bob_dr.decrypt_message(skipped_message, ad) == message

        # Test that (de)serialization can be used to decrypt the same message twice:
        encrypted_message = await alice_dr.encrypt_message(message, ad)
        bob_dr_serialized = bob_dr.json
        assert await bob_dr.decrypt_message(encrypted_message, ad) == message
        bob_dr = DoubleRatchet.from_json(bob_dr_serialized, **drc)
        assert await bob_dr.decrypt_message(encrypted_message, ad) == message
# Absolute path of the "migration_data" directory that sits next to this test module.
MIGRATION_DATA_DIR = os.path.join(
    os.path.split(os.path.abspath(__file__))[0],
    "migration_data"
)
def _load_migration_data(file_name: str) -> Any:
    """
    Load and parse a JSON file from ``MIGRATION_DATA_DIR``.

    Args:
        file_name: Name of the file inside ``MIGRATION_DATA_DIR``.

    Returns:
        The parsed JSON content.
    """

    with open(os.path.join(MIGRATION_DATA_DIR, file_name), encoding="utf-8") as file:
        return json.load(file)


def _parse_skipped_messages(serialized: Any) -> List[Tuple[EncryptedMessage, bytes]]:
    """
    Turn serialized skipped messages from the migration data into encrypted messages and their expected
    plaintexts.

    Args:
        serialized: An iterable of ``(message, plaintext)`` pairs. ``message`` is a mapping with a
            ``"header"`` (holding a base64 ``"ratchet_pub"`` and the integers ``"pn"`` and ``"n"``) and a
            base64 ``"ciphertext"``; ``plaintext`` is base64-encoded.

    Returns:
        A list of ``(encrypted message, plaintext)`` tuples, in input order.
    """

    skipped_messages: List[Tuple[EncryptedMessage, bytes]] = []
    for skipped_message, plaintext in serialized:
        header = skipped_message["header"]
        skipped_messages.append((
            EncryptedMessage(
                header=Header(
                    ratchet_pub=base64.b64decode(header["ratchet_pub"].encode("ASCII")),
                    previous_sending_chain_length=header["pn"],
                    sending_chain_length=header["n"]
                ),
                ciphertext=base64.b64decode(skipped_message["ciphertext"].encode("ASCII"))
            ),
            base64.b64decode(plaintext.encode("ASCII"))
        ))
    return skipped_messages


async def test_migrations() -> None:
    """
    Test serialization format migrations.

    Loads pre-stable serialized ratchets and skipped messages from ``MIGRATION_DATA_DIR``. Checks that the
    uninitialized ratchet is rejected, and that the migrated Alice and Bob ratchets can decrypt each other's
    recorded skipped messages.
    """

    # Spelled out instead of derived from ``drc``: the recorded migration data is fixed, so its
    # configuration must not change when the shared test configuration does. Uses the Curve25519
    # Diffie-Hellman ratchet.
    double_ratchet_configuration: Dict[str, Any] = {
        "diffie_hellman_ratchet_class": dhr25519.DiffieHellmanRatchet,
        "root_chain_kdf": RootChainKDF,
        "message_chain_kdf": MessageChainKDF,
        "message_chain_constant": "test_double_ratchet Message Chain constant".encode("ASCII"),
        "dos_protection_threshold": 10,
        "max_num_skipped_message_keys": 15,
        "aead": AEAD
    }

    associated_data = "test_double_ratchet associated data".encode("ASCII")

    alice_dr_serialized = _load_migration_data("dr-alice-pre-stable.json")
    bob_dr_serialized = _load_migration_data("dr-bob-pre-stable.json")
    uninitialized_dr_serialized = _load_migration_data("uninitialized-dr-pre-stable.json")
    alice_skipped_messages = _parse_skipped_messages(
        _load_migration_data("alice-skipped-messages-pre-stable.json")
    )
    bob_skipped_messages = _parse_skipped_messages(
        _load_migration_data("bob-skipped-messages-pre-stable.json")
    )

    # Verify that the uninitialized ratchet data can't be migrated
    try:
        DoubleRatchet.from_json(uninitialized_dr_serialized, **double_ratchet_configuration)
    except InconsistentSerializationException:
        pass
    else:
        # Explicit raise instead of ``assert False`` so the check survives ``python -O``.
        raise AssertionError("Uninitialized pre-stable ratchet data was migrated without error.")

    # Migrate the two valid ratchets
    alice_dr = DoubleRatchet.from_json(alice_dr_serialized, **double_ratchet_configuration)
    bob_dr = DoubleRatchet.from_json(bob_dr_serialized, **double_ratchet_configuration)

    # Verify that skipped messages can be correctly decrypted using the restored instances
    for encrypted_message, plaintext in alice_skipped_messages:
        assert await bob_dr.decrypt_message(encrypted_message, associated_data) == plaintext

    for encrypted_message, plaintext in bob_skipped_messages:
        assert await alice_dr.decrypt_message(encrypted_message, associated_data) == plaintext
|