1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
|
import json
import secrets
from itertools import combinations
from random import shuffle
import pytest
from bip32utils import BIP32Key
import shamir_mnemonic as shamir
from shamir_mnemonic import MnemonicError
MS = b"ABCDEFGHIJKLMNOP"
def test_basic_sharing_random():
secret = secrets.token_bytes(16)
mnemonics = shamir.generate_mnemonics(1, [(3, 5)], secret)[0]
assert shamir.combine_mnemonics(mnemonics[:3]) == shamir.combine_mnemonics(
mnemonics[2:]
)
def test_basic_sharing_fixed():
mnemonics = shamir.generate_mnemonics(1, [(3, 5)], MS)[0]
assert MS == shamir.combine_mnemonics(mnemonics[:3])
assert MS == shamir.combine_mnemonics(mnemonics[1:4])
with pytest.raises(MnemonicError):
shamir.combine_mnemonics(mnemonics[1:3])
def test_passphrase():
mnemonics = shamir.generate_mnemonics(1, [(3, 5)], MS, b"TREZOR")[0]
assert MS == shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR")
assert MS != shamir.combine_mnemonics(mnemonics[1:4])
def test_non_extendable():
mnemonics = shamir.generate_mnemonics(1, [(3, 5)], MS, extendable=False)[0]
assert MS == shamir.combine_mnemonics(mnemonics[1:4])
def test_iteration_exponent():
mnemonics = shamir.generate_mnemonics(
1, [(3, 5)], MS, b"TREZOR", iteration_exponent=1
)[0]
assert MS == shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR")
assert MS != shamir.combine_mnemonics(mnemonics[1:4])
mnemonics = shamir.generate_mnemonics(
1, [(3, 5)], MS, b"TREZOR", iteration_exponent=2
)[0]
assert MS == shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR")
assert MS != shamir.combine_mnemonics(mnemonics[1:4])
def test_group_sharing():
group_threshold = 2
group_sizes = (5, 3, 5, 1)
member_thresholds = (3, 2, 2, 1)
mnemonics = shamir.generate_mnemonics(
group_threshold, list(zip(member_thresholds, group_sizes)), MS
)
# Test all valid combinations of mnemonics.
for groups in combinations(zip(mnemonics, member_thresholds), group_threshold):
for group1_subset in combinations(groups[0][0], groups[0][1]):
for group2_subset in combinations(groups[1][0], groups[1][1]):
mnemonic_subset = list(group1_subset + group2_subset)
shuffle(mnemonic_subset)
assert MS == shamir.combine_mnemonics(mnemonic_subset)
# Minimal sets of mnemonics.
assert MS == shamir.combine_mnemonics(
[mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]
)
assert MS == shamir.combine_mnemonics(
[mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]
)
# One complete group and one incomplete group out of two groups required.
with pytest.raises(MnemonicError):
shamir.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]])
# One group of two required.
with pytest.raises(MnemonicError):
shamir.combine_mnemonics(mnemonics[0][1:4])
def test_group_sharing_threshold_1():
group_threshold = 1
group_sizes = (5, 3, 5, 1)
member_thresholds = (3, 2, 2, 1)
mnemonics = shamir.generate_mnemonics(
group_threshold, list(zip(member_thresholds, group_sizes)), MS
)
# Test all valid combinations of mnemonics.
for group, member_threshold in zip(mnemonics, member_thresholds):
for group_subset in combinations(group, member_threshold):
mnemonic_subset = list(group_subset)
shuffle(mnemonic_subset)
assert MS == shamir.combine_mnemonics(mnemonic_subset)
def test_all_groups_exist():
for group_threshold in (1, 2, 5):
mnemonics = shamir.generate_mnemonics(
group_threshold, [(3, 5), (1, 1), (2, 3), (2, 5), (3, 5)], MS
)
assert len(mnemonics) == 5
assert len(sum(mnemonics, [])) == 19
def test_invalid_sharing():
# Short master secret.
with pytest.raises(ValueError):
shamir.generate_mnemonics(1, [(2, 3)], MS[:14])
# Odd length master secret.
with pytest.raises(ValueError):
shamir.generate_mnemonics(1, [(2, 3)], MS + b"X")
# Group threshold exceeds number of groups.
with pytest.raises(ValueError):
shamir.generate_mnemonics(3, [(3, 5), (2, 5)], MS)
# Invalid group threshold.
with pytest.raises(ValueError):
shamir.generate_mnemonics(0, [(3, 5), (2, 5)], MS)
# Member threshold exceeds number of members.
with pytest.raises(ValueError):
shamir.generate_mnemonics(2, [(3, 2), (2, 5)], MS)
# Invalid member threshold.
with pytest.raises(ValueError):
shamir.generate_mnemonics(2, [(0, 2), (2, 5)], MS)
# Group with multiple members and member threshold 1.
with pytest.raises(ValueError):
shamir.generate_mnemonics(2, [(3, 5), (1, 3), (2, 5)], MS)
def test_vectors():
with open("vectors.json", "r") as f:
vectors = json.load(f)
for description, mnemonics, secret_hex, xprv in vectors:
if secret_hex:
secret = bytes.fromhex(secret_hex)
assert secret == shamir.combine_mnemonics(
mnemonics, b"TREZOR"
), 'Incorrect secret for test vector "{}".'.format(description)
assert (
BIP32Key.fromEntropy(secret).ExtendedKey() == xprv
), 'Incorrect xprv for test vector "{}".'.format(description)
else:
with pytest.raises(MnemonicError):
shamir.combine_mnemonics(mnemonics)
pytest.fail(
'Failed to raise exception for test vector "{}".'.format(
description
)
)
def test_split_ems():
encrypted_master_secret = shamir.EncryptedMasterSecret.from_master_secret(
MS, b"TREZOR", identifier=42, extendable=True, iteration_exponent=1
)
grouped_shares = shamir.split_ems(1, [(3, 5)], encrypted_master_secret)
mnemonics = [share.mnemonic() for share in grouped_shares[0]]
recovered = shamir.combine_mnemonics(mnemonics[:3], b"TREZOR")
assert recovered == MS
def test_recover_ems():
mnemonics = shamir.generate_mnemonics(1, [(3, 5)], MS, b"TREZOR")[0]
groups = shamir.decode_mnemonics(mnemonics[:3])
encrypted_master_secret = shamir.recover_ems(groups)
recovered = encrypted_master_secret.decrypt(b"TREZOR")
assert recovered == MS
|