File: testserver.py

package info (click to toggle)
python-proton-core 0.4.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 540 kB
  • sloc: python: 3,574; makefile: 15
file content (89 lines) | stat: -rw-r--r-- 3,002 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
Copyright (c) 2023 Proton AG

This file is part of Proton.

Proton is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
from proton.session.srp.pmhash import pmhash
from proton.session.srp.util import (bytes_to_long, custom_hash, get_random_of_length,
                             SRP_LEN_BYTES, long_to_bytes)


class TestServer:
    """Minimal server side of the SRP exchange, kept entirely in memory.

    Judging by the module name this is a test helper: it holds a single
    user's verifier, publishes the server public value ``B`` and checks the
    proof sent back by a client.

    There is no ``__init__``; call :meth:`setup` before anything else.
    Integers are converted to and from bytes with the ``bytes_to_long`` /
    ``long_to_bytes`` helpers imported from ``proton.session.srp.util``,
    always with a width of ``SRP_LEN_BYTES``.
    """

    # The class name matches pytest's default ``Test*`` collection pattern.
    # It is a helper, not a test case, so opt out of collection explicitly.
    __test__ = False

    def setup(self, username, modulus, verifier):
        """Configure the server for one user and draw a fresh ephemeral value.

        :param username: user name; stored as ``username.encode()``.
        :param modulus: SRP modulus N in the form ``bytes_to_long`` accepts
            (presumably ``bytes`` -- confirm against callers).
        :param verifier: the user's verifier v, same form as ``modulus``.
        """
        self.hash_class = pmhash
        self.generator = 2
        self._authenticated = False

        self.user = username.encode()
        self.modulus = bytes_to_long(modulus)
        self.verifier = bytes_to_long(verifier)

        # Server private ephemeral b and public value B = (k*v + g^b) mod N.
        self.b = get_random_of_length(32)
        self.B = (
            self.calculate_k() * self.verifier + pow(
                self.generator, self.b, self.modulus
            )
        ) % self.modulus

        # Filled in by process_challenge().
        self.secret = None
        self.A = None
        self.u = None
        self.key = None

    def calculate_server_proof(self, client_proof):
        """Return ``H(A | client_proof | secret)`` as a digest.

        :param client_proof: the proof received from the client; fed to the
            hash unchanged.
        """
        h = self.hash_class()
        h.update(long_to_bytes(self.A, SRP_LEN_BYTES))
        h.update(client_proof)
        h.update(long_to_bytes(self.secret, SRP_LEN_BYTES))
        return h.digest()

    def calculate_client_proof(self):
        """Return ``H(A | B | secret)``, the proof the client is expected to send."""
        h = self.hash_class()
        h.update(long_to_bytes(self.A, SRP_LEN_BYTES))
        h.update(long_to_bytes(self.B, SRP_LEN_BYTES))
        h.update(long_to_bytes(self.secret, SRP_LEN_BYTES))
        return h.digest()

    def calculate_k(self):
        """Return the multiplier ``k = H(g | N)`` as an integer.

        The generator is serialised little-endian, padded to
        ``SRP_LEN_BYTES``, before being hashed together with the modulus.
        """
        h = self.hash_class()
        h.update(self.generator.to_bytes(SRP_LEN_BYTES, 'little'))
        h.update(long_to_bytes(self.modulus, SRP_LEN_BYTES))
        return bytes_to_long(h.digest())

    def get_challenge(self):
        """Return the server public value ``B`` serialised to bytes."""
        return long_to_bytes(self.B, SRP_LEN_BYTES)

    def get_session_key(self):
        """Return the shared secret serialised to bytes.

        :return: the secret, or ``None`` when no secret has been derived yet
            (no challenge processed, or the challenge was rejected before the
            secret was computed).

        Note that the value is returned whether or not the client proof was
        accepted; check :meth:`get_authenticated` for that.
        """
        if self.secret is None:
            return None
        return long_to_bytes(self.secret, SRP_LEN_BYTES)

    def get_authenticated(self):
        """Return ``True`` once a client proof has been accepted."""
        return self._authenticated

    def process_challenge(self, client_challenge, client_proof):
        """Check the client's proof and, if valid, answer with the server proof.

        :param client_challenge: client public value A in the form
            ``bytes_to_long`` accepts.
        :param client_proof: proof computed by the client; compared against
            :meth:`calculate_client_proof`.
        :return: the server proof digest on success, ``False`` otherwise.
        """
        self.A = bytes_to_long(client_challenge)

        # SRP safety check: when A is a multiple of N the expression
        # (A * v^u)^b mod N collapses to 0 for *every* verifier, so a client
        # could forge a valid proof without knowing the password.
        if self.A % self.modulus == 0:
            return False

        self.u = custom_hash(self.hash_class, self.A, self.B)
        # Shared secret S = (A * v^u)^b mod N.
        self.secret = pow(
            (
                self.A * pow(self.verifier, self.u, self.modulus)
            ),
            self.b, self.modulus
        )

        if client_proof != self.calculate_client_proof():
            return False

        self._authenticated = True
        return self.calculate_server_proof(client_proof)