File: recovery.py

package info (click to toggle)
python-shamir-mnemonic 0.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 244 kB
  • sloc: python: 1,173; makefile: 34
file content (109 lines) | stat: -rw-r--r-- 3,905 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from collections import defaultdict
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Optional, Tuple

from .constants import GROUP_PREFIX_LENGTH_WORDS
from .shamir import ShareGroup, recover_ems
from .share import Share, ShareCommonParameters
from .utils import MnemonicError

UNDETERMINED = -1


class RecoveryState:
    """Object for keeping track of running Shamir recovery."""

    def __init__(self) -> None:
        self.last_share: Optional[Share] = None
        self.groups: Dict[int, ShareGroup] = defaultdict(ShareGroup)
        self.parameters: Optional[ShareCommonParameters] = None

    def group_prefix(self, group_index: int) -> str:
        """Return three starting words of a given group."""
        if not self.last_share:
            raise RuntimeError("Add at least one share first")

        fake_share = replace(self.last_share, group_index=group_index)
        return " ".join(fake_share.words()[:GROUP_PREFIX_LENGTH_WORDS])

    def group_status(self, group_index: int) -> Tuple[int, int]:
        """Return completion status of given group.

        Result consists of the number of shares already entered, and the threshold
        for recovering the group.
        """
        group = self.groups[group_index]
        if not group:
            return 0, UNDETERMINED

        return len(group), group.member_threshold()

    def group_is_complete(self, group_index: int) -> bool:
        """Check whether a given group is already complete."""
        return self.groups[group_index].is_complete()

    def groups_complete(self) -> int:
        """Return the number of groups that are already complete."""
        if self.parameters is None:
            return 0

        return sum(
            self.group_is_complete(i) for i in range(self.parameters.group_count)
        )

    def is_complete(self) -> bool:
        """Check whether the recovery set is complete.

        That is, at least M groups must be complete, where M is the global threshold.
        """
        if self.parameters is None:
            return False
        return self.groups_complete() >= self.parameters.group_threshold

    def matches(self, share: Share) -> bool:
        """Check whether the provided share matches the current set, i.e., has the same
        common parameters.
        """
        if self.parameters is None:
            return True
        return share.common_parameters() == self.parameters

    def add_share(self, share: Share) -> bool:
        """Add a share to the recovery set."""
        if not self.matches(share):
            raise MnemonicError(
                "This mnemonic is not part of the current set. Please try again."
            )
        self.groups[share.group_index].add(share)
        self.last_share = share
        if self.parameters is None:
            self.parameters = share.common_parameters()
        return True

    def __contains__(self, obj: Any) -> bool:
        if not isinstance(obj, Share):
            return False

        if not self.matches(obj):
            return False

        if not self.groups:
            return False

        return obj in self.groups[obj.group_index]

    def recover(self, passphrase: bytes) -> bytes:
        """Recover the master secret, given a passphrase."""
        # Select a subset of shares which meets the thresholds.
        reduced_groups: Dict[int, ShareGroup] = {}
        for group_index, group in self.groups.items():
            if group.is_complete():
                reduced_groups[group_index] = group.get_minimal_group()

            # some groups have been added so parameters must be known
            assert self.parameters is not None
            if len(reduced_groups) >= self.parameters.group_threshold:
                break

        encrypted_master_secret = recover_ems(reduced_groups)
        return encrypted_master_secret.decrypt(passphrase)