File: _pysrp.py

package info (click to toggle)
python-proton-core 0.4.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 540 kB
  • sloc: python: 3,574; makefile: 15
file content (160 lines) | stat: -rw-r--r-- 5,092 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
Copyright (c) 2023 Proton AG

This file is part of Proton.

Proton is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
# N    A large safe prime (N = 2q+1, where q is prime)
#      All arithmetic is done modulo N.
# g    A generator modulo N
# k    Multiplier parameter (k = H(N, g) in SRP-6a, k = 3 for legacy SRP-6)
# s    User's salt
# I    Username
# p    Cleartext Password
# H()  One-way hash function
# ^    (Modular) Exponentiation
# u    Random scrambling parameter
# a,b  Secret ephemeral values
# A,B  Public ephemeral values
# x    Private key (derived from p and s)
# v    Password verifier
from .pmhash import pmhash
from .util import (PM_VERSION, SRP_LEN_BYTES, SALT_LEN_BYTES, bytes_to_long, custom_hash,
                   get_random_of_length, hash_password, long_to_bytes)


def get_ng(n_bin, g_hex):
    """Decode the SRP group parameters into integers.

    Args:
        n_bin: Serialized modulus, converted with ``bytes_to_long``.
        g_hex: Generator written in base 16 (anything ``int(..., 16)``
            accepts).

    Returns:
        A ``(N, g)`` tuple of integers.
    """
    modulus = bytes_to_long(n_bin)
    generator = int(g_hex, 16)
    return modulus, generator


def hash_k(hash_class, g, modulus, width):
    """Compute the multiplier parameter k by hashing g followed by the modulus.

    Both integers are serialized to exactly ``width`` bytes in little-endian
    order before being fed to a fresh ``hash_class()`` instance.

    Args:
        hash_class: Callable returning a hash object exposing ``update`` and
            ``digest``.
        g: Generator, as an integer.
        modulus: Group modulus N, as an integer.
        width: Number of bytes each integer is serialized to.

    Returns:
        The digest converted to an integer with ``bytes_to_long``.
    """
    hasher = hash_class()
    # Order matters: g first, then the modulus.
    for value in (g, modulus):
        hasher.update(value.to_bytes(width, 'little'))
    return bytes_to_long(hasher.digest())


def calculate_x(hash_class, salt, password, modulus, version):
    """Derive the private key x from the password and salt.

    The modulus is serialized to ``SRP_LEN_BYTES`` bytes and handed, together
    with the password, salt and ``version``, to ``hash_password``; the
    resulting bytes are converted to an integer.

    Args:
        hash_class: Hash constructor forwarded to ``hash_password``.
        salt: User's salt, forwarded unchanged.
        password: Password, forwarded unchanged.
        modulus: Group modulus N, as an integer.
        version: Password-hashing scheme version, forwarded unchanged.

    Returns:
        x as an integer.
    """
    modulus_bytes = long_to_bytes(modulus, SRP_LEN_BYTES)
    hashed = hash_password(hash_class, password, salt, modulus_bytes, version)
    return bytes_to_long(hashed)


def calculate_client_proof(hash_class, A, B, K):
    """Compute the client proof M = H(A, B, K).

    A and B are integers and are serialized to ``SRP_LEN_BYTES`` bytes each;
    K is fed to the hash as-is.

    Returns:
        The digest produced by ``hash_class``.
    """
    digest = hash_class()
    # Public ephemeral values go in first, client's before server's.
    for public_value in (A, B):
        digest.update(long_to_bytes(public_value, SRP_LEN_BYTES))
    digest.update(K)
    return digest.digest()


def calculate_server_proof(hash_class, A, M, K):
    """Compute the proof the server is expected to return: H(A, M, K).

    A is an integer serialized to ``SRP_LEN_BYTES`` bytes; M (the client
    proof) and K are fed to the hash as-is.

    Returns:
        The digest produced by ``hash_class``.
    """
    hasher = hash_class()
    chunks = (long_to_bytes(A, SRP_LEN_BYTES), M, K)
    for chunk in chunks:
        hasher.update(chunk)
    return hasher.digest()


class User:
    """Client side of an SRP authentication exchange.

    An instance holds the group parameters (``N``, ``g``), the multiplier
    ``k``, the client's ephemeral pair (``a``, ``A``) and, once a server
    challenge has been processed, the derived values (``u``, ``x``, ``v``,
    ``S``, ``K``, ``M``) together with the proof the server is expected to
    send back.
    """

    def __init__(self, password, n_bin, g_hex=b"2", bytes_a=None, bytes_A=None):  # noqa
        """Set up the group parameters and the client's ephemeral values.

        Args:
            password: Cleartext password; must be a non-empty ``str``.
            n_bin: Serialized modulus N (decoded by ``get_ng``).
            g_hex: Generator in base 16.
            bytes_a: Optional 32-byte serialized secret ephemeral value. A
                random one is generated when omitted.
            bytes_A: Optional serialized public ephemeral value. When
                omitted it is computed as ``g^a mod N``.

        Raises:
            ValueError: If ``bytes_a`` is given but is not 32 bytes long, or
                if ``password`` is not a non-empty ``str``.
        """
        if bytes_a and len(bytes_a) != 32:
            raise ValueError("32 bytes required for bytes_a")

        if not isinstance(password, str) or not password:
            raise ValueError("Invalid password")

        self.N, self.g = get_ng(n_bin, g_hex)
        self.hash_class = pmhash
        self.k = hash_k(
            self.hash_class, self.g,
            self.N, SRP_LEN_BYTES
        )

        self.p = password.encode()

        # Secret ephemeral value: caller-supplied or freshly generated.
        if bytes_a:
            self.a = bytes_to_long(bytes_a)
        else:
            self.a = get_random_of_length(32)

        # Public ephemeral value: caller-supplied or derived from a.
        if bytes_A:
            self.A = bytes_to_long(bytes_A)
        else:
            self.A = pow(self.g, self.a, self.N)

        # Everything below is filled in by process_challenge / compute_v.
        self.v = None
        self.M = None
        self.K = None
        self.expected_server_proof = None
        self._authenticated = False
        self.bytes_s = None
        self.S = None
        self.B = None
        self.u = None
        self.x = None

    def authenticated(self):
        """Return True once ``verify_session`` has accepted a server proof."""
        return self._authenticated

    def get_ephemeral_secret(self):
        """Return the secret ephemeral value ``a`` as ``SRP_LEN_BYTES`` bytes."""
        return long_to_bytes(self.a, SRP_LEN_BYTES)

    def get_session_key(self):
        """Return the session key K, or None if not authenticated."""
        return self.K if self._authenticated else None

    def get_challenge(self):
        """Return the public ephemeral value ``A`` as ``SRP_LEN_BYTES`` bytes."""
        return long_to_bytes(self.A, SRP_LEN_BYTES)

    def process_challenge(
        self, bytes_s, bytes_server_challenge, version=PM_VERSION
    ):
        """Process the server's salt and challenge and compute the client proof.

        Args:
            bytes_s: User's salt as sent by the server.
            bytes_server_challenge: Serialized public ephemeral value B.
            version: Password-hashing scheme version passed to
                ``calculate_x``.

        Returns:
            The client proof M, or None if an SRP-6a safety check is
            violated (``B mod N == 0`` or ``u == 0``).
        """
        self.bytes_s = bytes_s
        self.B = bytes_to_long(bytes_server_challenge)

        # SRP-6a safety check
        if (self.B % self.N) == 0:
            return None

        self.u = custom_hash(self.hash_class, self.A, self.B)

        # SRP-6a safety check
        if self.u == 0:
            return None

        self.x = calculate_x(self.hash_class, self.bytes_s, self.p, self.N, version)

        self.v = pow(self.g, self.x, self.N)

        # S = (B - k*v) ^ (a + u*x) mod N. The base may be negative; Python's
        # three-argument pow() still yields a result in [0, N).
        self.S = pow(
            (self.B - self.k * self.v), (self.a + self.u * self.x), self.N
        )

        # The session key is the serialized S itself (no extra hashing here).
        self.K = long_to_bytes(self.S, SRP_LEN_BYTES)
        self.M = calculate_client_proof(self.hash_class, self.A, self.B, self.K)  # noqa
        self.expected_server_proof = calculate_server_proof(
            self.hash_class, self.A, self.M, self.K
        )

        return self.M

    def verify_session(self, server_proof):
        """Mark the session as authenticated if ``server_proof`` is correct.

        Nothing happens (and the session stays unauthenticated) when no
        expected proof has been computed yet, i.e. ``process_challenge`` was
        never called or returned None.

        Args:
            server_proof: Proof returned by the server.
        """
        import hmac  # local import keeps this block self-contained

        expected = self.expected_server_proof
        # Without a successfully processed challenge there is nothing to
        # compare against. A plain ``==`` would accept a missing proof here
        # because ``None == None`` is true.
        if expected is None or server_proof is None:
            return

        try:
            # Constant-time comparison of the two proofs.
            matches = hmac.compare_digest(expected, server_proof)
        except TypeError:
            # compare_digest only accepts bytes-like objects or ASCII str of
            # the same kind; keep plain equality semantics for anything else.
            matches = expected == server_proof

        if matches:
            self._authenticated = True

    def compute_v(self, bytes_s=None, version=PM_VERSION):
        """Compute the password verifier for a (possibly fresh) salt.

        Args:
            bytes_s: Salt to use. When None, a random salt of
                ``SALT_LEN_BYTES`` bytes is generated.
            version: Password-hashing scheme version passed to
                ``calculate_x``.

        Returns:
            A ``(salt, verifier)`` tuple, where the verifier is
            ``g^x mod N`` serialized to ``SRP_LEN_BYTES`` bytes.
        """
        if bytes_s is None:
            bytes_s = long_to_bytes(
                get_random_of_length(SALT_LEN_BYTES), SALT_LEN_BYTES
            )
        self.bytes_s = bytes_s
        self.x = calculate_x(self.hash_class, self.bytes_s, self.p, self.N, version)

        verifier = pow(self.g, self.x, self.N)
        return self.bytes_s, long_to_bytes(verifier, SRP_LEN_BYTES)