File: _auth.py

package info (click to toggle)
python-pymysql 0.9.3-2%2Bdeb11u1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 816 kB
  • sloc: python: 6,029; makefile: 153; sh: 37
file content (265 lines) | stat: -rw-r--r-- 7,720 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""
Implements auth methods
"""
from ._compat import text_type, PY2
from .constants import CLIENT
from .err import OperationalError
from .util import byte2int, int2byte


try:
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization, hashes
    from cryptography.hazmat.primitives.asymmetric import padding
    _have_cryptography = True
except ImportError:
    _have_cryptography = False

from functools import partial
import hashlib
import io
import struct
import warnings


# When true, the auth helpers in this module print progress messages.
DEBUG = False
# Number of leading challenge bytes that scramble_native_password() hashes.
SCRAMBLE_LENGTH = 20
# Factory for SHA-1 hash objects: sha1_new(data) is hashlib.new('sha1', data).
sha1_new = partial(hashlib.new, 'sha1')


# mysql_native_password
# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41


def scramble_native_password(password, message):
    """Build the mysql_native_password response to a server challenge.

    The response is
    ``SHA1(message[:SCRAMBLE_LENGTH] + SHA1(SHA1(password)))`` XORed with
    ``SHA1(password)``.  A falsy password produces an empty response.
    """
    if not password:
        return b''

    password_hash = sha1_new(password).digest()
    double_hash = sha1_new(password_hash).digest()

    # Hash the truncated challenge followed by the double hash.
    mixer = sha1_new(message[:SCRAMBLE_LENGTH])
    mixer.update(double_hash)
    return _my_crypt(mixer.digest(), password_hash)


def _my_crypt(message1, message2):
    """Return message1 XORed byte-by-byte with message2, as bytes.

    The result has the length of message1.  message2 is indexed at every
    position of message1, so it must be at least as long.
    """
    xored = bytearray(message1)
    if PY2:
        # Indexing a Python 2 str gives characters; bytearray gives ints.
        message2 = bytearray(message2)

    for idx, value in enumerate(xored):
        xored[idx] = value ^ message2[idx]

    return bytes(xored)


# old_passwords support ported from libmysql/password.c
# https://dev.mysql.com/doc/internals/en/old-password-authentication.html

# Number of challenge bytes hashed, and maximum number of bytes produced,
# by scramble_old_password().
SCRAMBLE_LENGTH_323 = 8


class RandStruct_323(object):
    """Two-seed pseudo random generator used by the old password scramble.

    Both seeds are always kept reduced modulo ``max_value``.
    """

    def __init__(self, seed1, seed2):
        limit = 0x3FFFFFFF
        self.max_value = limit
        self.seed1 = seed1 % limit
        self.seed2 = seed2 % limit

    def my_rnd(self):
        """Advance both seeds and return ``seed1 / max_value`` as a float."""
        limit = self.max_value
        first = (self.seed1 * 3 + self.seed2) % limit
        # The second seed is derived from the already updated first seed.
        second = (first + self.seed2 + 33) % limit
        self.seed1, self.seed2 = first, second
        return float(first) / float(limit)


def scramble_old_password(password, message):
    """Build the pre-4.1 ("old password") response to a server challenge.

    The password and the first SCRAMBLE_LENGTH_323 bytes of the challenge
    are hashed with _hash_password_323(); the XOR of the two hashes seeds a
    RandStruct_323 generator.  One byte is generated per challenge byte (at
    most SCRAMBLE_LENGTH_323), then every byte is XORed with one further
    generated value.  A warning is emitted on every call.
    """
    warnings.warn("old password (for MySQL <4.1) is used.  Upgrade your password with newer auth method.\n"
                  "old password support will be removed in future PyMySQL version")
    pass_hi, pass_lo = struct.unpack(">LL", _hash_password_323(password))
    msg_hi, msg_lo = struct.unpack(
        ">LL", _hash_password_323(message[:SCRAMBLE_LENGTH_323])
    )
    rand_st = RandStruct_323(pass_hi ^ msg_hi, pass_lo ^ msg_lo)

    count = min(SCRAMBLE_LENGTH_323, len(message))
    # join() drains the generator completely, so these random values are
    # drawn before the masking value below.
    raw = b''.join(
        int2byte(int(rand_st.my_rnd() * 31) + 64) for _ in range(count)
    )
    extra = int2byte(int(rand_st.my_rnd() * 31))
    return b''.join(int2byte(byte2int(c) ^ byte2int(extra)) for c in raw)


def _hash_password_323(password):
    nr = 1345345333
    add = 7
    nr2 = 0x12345671

    # x in py3 is numbers, p27 is chars
    for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]:
        nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
        nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
        add = (add + c) & 0xFFFFFFFF

    r1 = nr & ((1 << 31) - 1)  # kill sign bits
    r2 = nr2 & ((1 << 31) - 1)
    return struct.pack(">LL", r1, r2)


# sha256_password


def _roundtrip(conn, send_data):
    """Write send_data as one packet on conn and return the reply packet.

    check_error() is called on the reply before it is returned.
    """
    conn.write_packet(send_data)
    reply = conn._read_packet()
    reply.check_error()
    return reply


def _xor_password(password, salt):
    """XOR password with the server scramble, repeating the scramble.

    Only the first SCRAMBLE_LENGTH bytes of ``salt`` take part in the XOR.
    Salts of SCRAMBLE_LENGTH bytes or fewer behave exactly as before.

    :param password: password bytes (callers append the NUL terminator).
    :param salt: scramble bytes; must not be empty when password is not.
    :return: bytes of the same length as password.
    """
    password_bytes = bytearray(password)
    # bytearray() makes indexing yield ints on Python 2 as well.
    # The salt is cut to SCRAMBLE_LENGTH because the auth functions store
    # the raw payload of an auth-switch packet in conn.salt.  Per the MySQL
    # protocol that payload is the 20-byte scramble followed by a NUL
    # terminator, and the server XORs over the 20 bytes only.  An extra
    # byte in the cycle corrupts every password longer than the scramble.
    salt = bytearray(salt)[:SCRAMBLE_LENGTH]
    salt_len = len(salt)
    for i, value in enumerate(password_bytes):
        password_bytes[i] = value ^ salt[i % salt_len]
    return bytes(password_bytes)


def sha2_rsa_encrypt(password, salt, public_key):
    """Obfuscate the password with salt and RSA-encrypt it.

    Shared by sha256_password and caching_sha2_password.  The password is
    NUL-terminated, XORed with the salt, then encrypted with the PEM
    encoded ``public_key`` using OAEP padding (SHA-1 for both MGF1 and
    the hash).  Raises RuntimeError when the cryptography package could
    not be imported.
    """
    if not _have_cryptography:
        raise RuntimeError("cryptography is required for sha256_password or caching_sha2_password")

    obfuscated = _xor_password(password + b'\0', salt)
    key = serialization.load_pem_public_key(public_key, default_backend())
    oaep = padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA1()),
        algorithm=hashes.SHA1(),
        label=None,
    )
    return key.encrypt(obfuscated, oaep)


def sha256_password_auth(conn, pkt):
    """Run the sha256_password exchange on conn and return the reply packet.

    On a secure connection the NUL-terminated password is sent as is.
    Otherwise the salt is taken from an auth-switch packet, the server's
    public key is requested if none is known, and the RSA-encrypted
    password is sent.  An empty password is sent as an empty packet.
    Raises OperationalError if a password is set but no public key could
    be obtained.
    """
    if conn._secure:
        if DEBUG:
            print("sha256: Sending plain password")
        return _roundtrip(conn, conn.password + b'\0')

    if pkt.is_auth_switch_request():
        conn.salt = pkt.read_all()
        if not conn.server_public_key:
            if conn.password:
                # Request server public key
                if DEBUG:
                    print("sha256: Requesting server public key")
                pkt = _roundtrip(conn, b'\1')

    if pkt.is_extra_auth_data():
        # The key follows the one-byte packet marker.
        conn.server_public_key = pkt._data[1:]
        if DEBUG:
            print("Received public key:\n", conn.server_public_key.decode('ascii'))

    if not conn.password:
        return _roundtrip(conn, b'')

    if not conn.server_public_key:
        raise OperationalError("Couldn't receive server's public key")

    encrypted = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
    return _roundtrip(conn, encrypted)


def scramble_caching_sha2(password, nonce):
    # (bytes, bytes) -> bytes
    """Compute the scramble sent on the caching_sha2_password fast path.

    XOR(SHA256(password), SHA256(SHA256(SHA256(password)) + nonce))

    A falsy password produces an empty scramble.
    """
    if not password:
        return b''

    first = hashlib.sha256(password).digest()
    second = hashlib.sha256(first).digest()
    mixed = hashlib.sha256(second + nonce).digest()
    if PY2:
        # Iterating a Python 2 str gives characters; bytearray gives ints.
        mixed = bytearray(mixed)

    scramble = bytearray(first)
    for idx, value in enumerate(mixed):
        scramble[idx] ^= value

    return bytes(scramble)


def caching_sha2_password_auth(conn, pkt):
    """Run the caching_sha2_password exchange on conn.

    ``pkt`` is the packet that triggered this step: either an auth-switch
    request, in which case the fast-path scramble is sent first, or an
    extra-auth-data packet carrying the fast-auth result.

    Returns the error-checked packet the server sent in reply to the last
    message of the exchange.  This holds on every path, including full
    authentication with an RSA-encrypted password, which used to fall off
    the end and return None.

    Raises OperationalError when the server sends a packet or result code
    this exchange does not expect.
    """
    # No password fast path
    if not conn.password:
        return _roundtrip(conn, b'')

    if pkt.is_auth_switch_request():
        # Try from fast auth
        if DEBUG:
            print("caching sha2: Trying fast path")
        conn.salt = pkt.read_all()
        scrambled = scramble_caching_sha2(conn.password, conn.salt)
        pkt = _roundtrip(conn, scrambled)
    # else: fast auth is tried in initial handshake

    if not pkt.is_extra_auth_data():
        raise OperationalError(
            "caching sha2: Unknown packet for fast auth: %s" % pkt._data[:1]
        )

    # magic numbers:
    # 2 - request public key
    # 3 - fast auth succeeded
    # 4 - need full auth

    # Skip the one-byte packet marker, then read the result code.
    pkt.advance(1)
    n = pkt.read_uint8()

    if n == 3:
        if DEBUG:
            print("caching sha2: succeeded by fast path.")
        pkt = conn._read_packet()
        pkt.check_error()  # pkt must be OK packet
        return pkt

    if n != 4:
        raise OperationalError("caching sha2: Unknwon result for fast auth: %s" % n)

    if DEBUG:
        print("caching sha2: Trying full auth...")

    if conn._secure:
        if DEBUG:
            print("caching sha2: Sending plain password via secure connection")
        return _roundtrip(conn, conn.password + b'\0')

    if not conn.server_public_key:
        pkt = _roundtrip(conn, b'\x02')  # Request public key
        if not pkt.is_extra_auth_data():
            raise OperationalError(
                "caching sha2: Unknown packet for public key: %s" % pkt._data[:1]
            )

        # The key follows the one-byte packet marker.
        conn.server_public_key = pkt._data[1:]
        if DEBUG:
            print(conn.server_public_key.decode('ascii'))

    data = sha2_rsa_encrypt(conn.password, conn.salt, conn.server_public_key)
    # Hand the server's reply back to the caller, as every other path does.
    return _roundtrip(conn, data)