File: cipher.py

package info (click to toggle)
python-shamir-mnemonic 0.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 244 kB
  • sloc: python: 1,173; makefile: 34
file content (73 lines) | stat: -rw-r--r-- 2,100 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import hashlib

from .constants import (
    BASE_ITERATION_COUNT,
    CUSTOMIZATION_STRING_ORIG,
    ID_LENGTH_BITS,
    ROUND_COUNT,
)
from .utils import bits_to_bytes


def _xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes:
    """Compute the Feistel round function for round ``i`` on right half ``r``.

    The output is ``len(r)`` bytes of PBKDF2-HMAC-SHA256:

    - Password: the single byte ``i`` followed by ``passphrase``.
      ``bytes([i])`` requires ``0 <= i <= 255``.
    - Salt: ``salt + r``.
    - Iteration count: ``(BASE_ITERATION_COUNT << e) // ROUND_COUNT``, so the
      per-round cost doubles with each increment of ``e``.
    """
    key = bytes([i]) + passphrase
    round_salt = salt + r
    iterations = (BASE_ITERATION_COUNT << e) // ROUND_COUNT
    return hashlib.pbkdf2_hmac("sha256", key, round_salt, iterations, dklen=len(r))


def _get_salt(identifier: int, extendable: bool) -> bytes:
    if extendable:
        return bytes()
    identifier_len = bits_to_bytes(ID_LENGTH_BITS)
    return CUSTOMIZATION_STRING_ORIG + identifier.to_bytes(identifier_len, "big")


def encrypt(
    master_secret: bytes,
    passphrase: bytes,
    iteration_exponent: int,
    identifier: int,
    extendable: bool,
) -> bytes:
    """Encrypt ``master_secret`` with a ``ROUND_COUNT``-round Feistel network.

    Args:
        master_secret: Plaintext to encrypt. Its length in bytes must be even;
            it is split into equal left and right halves.
        passphrase: Mixed into every round function as part of the PBKDF2
            password.
        iteration_exponent: Exponent ``e`` forwarded to the round function,
            which scales its PBKDF2 iteration count by ``2**e``.
        identifier: Forwarded to ``_get_salt``; used only when
            ``extendable`` is false.
        extendable: Selects the empty salt (true) or the
            customization-string-plus-identifier salt (false).

    Returns:
        The ciphertext, the same length as ``master_secret``.

    Raises:
        ValueError: If ``master_secret`` has an odd number of bytes.
    """
    half, remainder = divmod(len(master_secret), 2)
    if remainder:
        raise ValueError(
            "The length of the master secret in bytes must be an even number."
        )

    left = master_secret[:half]
    right = master_secret[half:]
    salt = _get_salt(identifier, extendable)

    for round_index in range(ROUND_COUNT):
        keystream = _round_function(
            round_index, passphrase, iteration_exponent, salt, right
        )
        left, right = right, _xor(left, keystream)

    # Emit the halves swapped. ``decrypt`` can then run the identical round
    # structure with the round indices in reverse order.
    return right + left


def decrypt(
    encrypted_master_secret: bytes,
    passphrase: bytes,
    iteration_exponent: int,
    identifier: int,
    extendable: bool,
) -> bytes:
    """Invert ``encrypt`` by running its Feistel rounds in reverse index order.

    The loop body matches ``encrypt``'s exactly; only the round indices run
    from ``ROUND_COUNT - 1`` down to ``0``.

    Args:
        encrypted_master_secret: Ciphertext to decrypt. Its length in bytes
            must be even.
        passphrase: Must match the passphrase used for encryption to recover
            the original secret.
        iteration_exponent: Exponent ``e`` forwarded to the round function.
        identifier: Forwarded to ``_get_salt``; used only when
            ``extendable`` is false.
        extendable: Selects which salt ``_get_salt`` returns.

    Returns:
        The decrypted bytes, the same length as the input.

    Raises:
        ValueError: If ``encrypted_master_secret`` has an odd number of bytes.
    """
    half, remainder = divmod(len(encrypted_master_secret), 2)
    if remainder:
        raise ValueError(
            "The length of the encrypted master secret in bytes must be an even number."
        )

    left = encrypted_master_secret[:half]
    right = encrypted_master_secret[half:]
    salt = _get_salt(identifier, extendable)

    # Walk the round indices from last to first to undo encrypt().
    for round_index in range(ROUND_COUNT - 1, -1, -1):
        keystream = _round_function(
            round_index, passphrase, iteration_exponent, salt, right
        )
        left, right = right, _xor(left, keystream)

    # Undo the final half swap performed at the end of the rounds.
    return right + left