File: test_connect.py

package info (click to toggle)
python-clickhouse-driver 0.2.5-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,516 kB
  • sloc: python: 10,950; pascal: 42; makefile: 31; sh: 3
file content (351 lines) | stat: -rw-r--r-- 11,967 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# coding: utf-8
import socket
from io import BytesIO
from unittest.mock import patch

from clickhouse_driver import errors
from clickhouse_driver.client import Client
from clickhouse_driver.protocol import ClientPacketTypes, ServerPacketTypes
from clickhouse_driver.bufferedreader import BufferedReader
from clickhouse_driver.writer import write_binary_str
from tests.testcase import BaseTestCase
from unittest import TestCase


class PacketsTestCase(BaseTestCase):
    """Check the human-readable names produced by packet-type ``to_str``."""

    def test_packets_to_str(self):
        # Each row: (packet-type class, numeric packet code, expected name).
        # Codes outside the known range must map to 'Unknown packet'.
        expectations = (
            (ClientPacketTypes, 2, 'Data'),
            (ClientPacketTypes, 6, 'Unknown packet'),
            (ClientPacketTypes, 42, 'Unknown packet'),
            (ServerPacketTypes, 4, 'Pong'),
            (ServerPacketTypes, 15, 'Unknown packet'),
            (ServerPacketTypes, 42, 'Unknown packet'),
        )
        for packet_types, code, name in expectations:
            self.assertEqual(packet_types.to_str(code), name)


class ConnectTestCase(BaseTestCase):
    """Connection lifecycle tests run against the configured test server.

    Covers handshake errors, name-resolution and timeout failures, recovery
    from broken sockets around ping, unexpected/unknown server packets,
    ``alt_hosts`` failover, round-robin connections and the client context
    manager. Most tests inject failures by patching socket or connection
    internals with ``unittest.mock.patch``.
    """

    def unexpected_packet_message(self, expected: str, got: str) -> str:
        """Build the text of an expected ``UnexpectedPacketFromServerError``.

        :param expected: packet name(s) the client was waiting for.
        :param got: packet name the client actually received.
        :return: the error message, including this test's host and port.
        """
        return (
            'Code: 102. Unexpected packet from server {}:{} '
            '(expected {}, got {})'
        ).format(self.host, self.port, expected, got)

    def test_exception_on_hello_packet(self):
        # Connecting as a non-existent user: the server rejects the
        # handshake and the client surfaces it as ServerException.
        with self.created_client(user='wrong_user') as client:
            with self.assertRaises(errors.ServerException) as e:
                client.execute('SHOW TABLES')

        # Simple exception formatting checks
        exc = e.exception
        self.assertIn('Code:', str(exc))
        self.assertIn('Stack trace:', str(exc))

    def test_network_error(self):
        # A name-resolution failure (getaddrinfo raising socket.error) must be
        # reported as errors.NetworkError.
        client = Client('bad-address')

        with patch('socket.getaddrinfo') as mocked_getaddrinfo:
            mocked_getaddrinfo.side_effect = socket.error(
                -2, 'Name or service not known'
            )

            with self.assertRaises(errors.NetworkError):
                client.execute('SHOW TABLES')

    def test_timeout_error(self):
        # socket.timeout raised by connect() must become SocketTimeoutError
        # (code 209). The message carries the timeout's text, if any,
        # followed by "(host:port)".
        with patch('socket.socket') as ms:
            # Bare timeout class: no message text.
            ms.return_value.connect.side_effect = socket.timeout

            with self.assertRaises(errors.SocketTimeoutError) as e:
                self.client.execute('SHOW TABLES')
            self.assertEqual(
                str(e.exception),
                'Code: 209. ({}:{})'.format(self.host, self.port)
            )

            # Timeout instance with (errno, text): the text is included.
            ms.return_value.connect.side_effect = socket.timeout(42, 'Test')

            with self.assertRaises(errors.SocketTimeoutError) as e:
                self.client.execute('SHOW TABLES')
            self.assertEqual(
                str(e.exception),
                'Code: 209. Test ({}:{})'.format(self.host, self.port)
            )

    def test_transport_not_connection_on_disconnect(self):
        # Ping reports the connection as dead, and shutting down the old
        # socket fails with "Transport endpoint is not connected" (errno 107).
        # The query must still succeed, i.e. that shutdown error is tolerated.
        # Create connection.
        self.client.execute('SELECT 1')

        connection = self.client.connection

        with patch.object(connection, 'ping') as mocked_ping:
            mocked_ping.return_value = False

            with patch.object(connection, 'socket') as mocked_socket:
                mocked_socket.shutdown.side_effect = socket.error(
                    107, 'Transport endpoint is not connected'
                )

                # New socket should be created.
                rv = self.client.execute('SELECT 1')
                self.assertEqual(rv, [(1, )])

                # Close newly created socket.
                connection.socket.close()

    def test_socket_error_on_ping(self):
        # Flushing the output stream fails with "Broken pipe" (errno 32);
        # the client is still expected to return the query result.
        self.client.execute('SELECT 1')

        with patch.object(self.client.connection, 'fout') as mocked_fout:
            mocked_fout.flush.side_effect = socket.error(32, 'Broken pipe')

            rv = self.client.execute('SELECT 1')
            self.assertEqual(rv, [(1, )])

    def test_ping_got_unexpected_package(self):
        self.client.execute('SELECT 1')

        with patch.object(self.client.connection, 'fin') as mocked_fin:
            # Emulate Exception packet on ping.
            mocked_fin.read_one.return_value = 2

            error = errors.UnexpectedPacketFromServerError
            with self.assertRaises(error) as e:
                self.client.execute('SELECT 1')

            self.assertEqual(
                str(e.exception),
                self.unexpected_packet_message('Pong', 'Exception')
            )

    def test_eof_on_receive_packet(self):
        self.client.execute('SELECT 1')

        with patch.object(self.client.connection, 'fin') as mocked_fin:
            # Answer the ping with Pong (packet type 4), then hit EOF on the
            # next read while receiving the query's response.
            mocked_fin.read_one.side_effect = [4, EOFError]

            with self.assertRaises(EOFError):
                self.client.execute('SELECT 1')

    def test_eof_error_on_ping(self):
        # Only the first read_one() call fails with EOFError; every later
        # call is forwarded to the real reader. The query is still expected
        # to return its result.
        self.client.execute('SELECT 1')

        self.raised = False
        # Bound method of the real reader, captured before patching `fin`.
        read_one = self.client.connection.fin.read_one

        def side_effect(*args, **kwargs):
            if not self.raised:
                self.raised = True
                raise EOFError('Unexpected EOF while reading bytes')

            else:
                return read_one(*args, **kwargs)

        with patch.object(self.client.connection, 'fin') as mocked_fin:
            mocked_fin.read_one.side_effect = side_effect

            rv = self.client.execute('SELECT 1')
            self.assertEqual(rv, [(1, )])

    def test_alt_hosts(self):
        # The primary host cannot be resolved; alt_hosts points at the real
        # server, which the client must fall back to.
        client = Client(
            'wrong_host', 1234, self.database, self.user, self.password,
            alt_hosts='{}:{}'.format(self.host, self.port)
        )

        self.n_calls = 0
        getaddrinfo = socket.getaddrinfo

        def side_getaddrinfo(host, *args, **kwargs):
            # Fail (and count) lookups of the bad host; resolve others
            # normally.
            if host == 'wrong_host':
                self.n_calls += 1
                raise socket.error(-2, 'Name or service not known')
            return getaddrinfo(host, *args, **kwargs)

        with patch('socket.getaddrinfo') as mocked_getaddrinfo:
            mocked_getaddrinfo.side_effect = side_getaddrinfo

            rv = client.execute('SELECT 1')
            self.assertEqual(rv, [(1,)])

            client.disconnect()

            rv = client.execute('SELECT 1')
            self.assertEqual(rv, [(1,)])
            # Last host must be remembered and getaddrinfo must call exactly
            # once with host == 'wrong_host'.
            self.assertEqual(self.n_calls, 1)

        # NOTE(review): skipped if an assertion above fails; consider
        # try/finally or addCleanup so the connection is always closed.
        client.disconnect()

    def test_remember_current_database(self):
        # A USE statement (with irregular whitespace and a trailing ';') must
        # still be in effect after disconnect() and the implicit reconnect.
        with self.created_client() as client:
            client.execute('   USE     system   ; ')
            client.disconnect()

            rv = client.execute('SELECT currentDatabase()')
            self.assertEqual(rv, [('system', )])

    def test_context_manager(self):
        # The connection is open inside the with-block and closed on exit.
        with self.created_client() as c:
            c.execute('SELECT 1')
            self.assertTrue(c.connection.connected)
        self.assertFalse(c.connection.connected)

    def test_unknown_packet(self):
        # Every read_varint() call returns 42, which is not a known server
        # packet type. force_connect is mocked out, presumably so that no
        # reconnect or ping runs against the patched reader.
        self.client.execute('SELECT 1')

        with patch('clickhouse_driver.connection.read_varint') as read_mock, \
                patch.object(self.client.connection, 'force_connect'):
            read_mock.return_value = 42

            with self.assertRaises(errors.UnknownPacketFromServerError) as e:
                self.client.execute('SELECT 1')

            self.assertEqual(
                str(e.exception),
                'Code: 100. Unknown packet 42 from server {}:{}'.format(
                    self.host, self.port
                )
            )

    def test_unknown_packet_on_connect(self):
        # With read_varint() always returning 42, the server's reply to the
        # handshake is neither Hello nor Exception.
        with patch('clickhouse_driver.connection.read_varint') as read_mock:
            read_mock.return_value = 42

            error = errors.UnexpectedPacketFromServerError
            with self.assertRaises(error) as e:
                self.client.execute('SELECT 1')

            msg = self.unexpected_packet_message(
                'Hello or Exception', 'Unknown packet'
            )
            self.assertEqual(str(e.exception), msg)

    def test_partially_consumed_query(self):
        # Starting a second execute_iter() while the first result is still
        # unconsumed is rejected. A subsequent plain execute() still works.
        self.client.execute_iter('SELECT 1')

        error = errors.PartiallyConsumedQueryError
        with self.assertRaises(error) as e:
            self.client.execute_iter('SELECT 1')

        self.assertEqual(
            str(e.exception),
            'Simultaneous queries on single connection detected'
        )
        rv = self.client.execute('SELECT 1')
        self.assertEqual(rv, [(1, )])

    def test_read_all_packets_on_execute_iter(self):
        # Fully consuming one execute_iter() result must leave the connection
        # ready for the next query (no PartiallyConsumedQueryError).
        list(self.client.execute_iter('SELECT 1'))
        list(self.client.execute_iter('SELECT 1'))

    def test_round_robin(self):
        # alt_hosts adds a second connection to the same server. The asserts
        # check the following:
        #   - client.connection (the one in use) and connections[0]
        #     (presumably the other one) connect lazily, one per execute();
        #   - disconnect() closes both.
        kwargs = {
            'round_robin': True,
            'alt_hosts': '{}:{}'.format(self.host, self.port)
        }
        with self.created_client(**kwargs) as client:
            self.assertFalse(client.connection.connected)
            self.assertFalse(list(client.connections)[0].connected)

            client.execute('SELECT 1')

            self.assertTrue(client.connection.connected)
            self.assertFalse(list(client.connections)[0].connected)

            client.execute('SELECT 1')

            self.assertTrue(client.connection.connected)
            self.assertTrue(list(client.connections)[0].connected)

            client.disconnect()

            self.assertFalse(client.connection.connected)
            self.assertFalse(list(client.connections)[0].connected)


class FakeBufferedReader(BufferedReader):
    """In-memory ``BufferedReader`` fed from a list of predefined chunks.

    Every refill loads the next chunk from *inputs* into the buffer. Once
    the chunks run out, or a chunk is empty, the refill yields zero bytes
    and ``EOFError`` is raised.
    """

    def __init__(self, inputs, bufsize=128):
        super(FakeBufferedReader, self).__init__(bufsize)
        self._inputs = inputs
        # Index of the next chunk to hand out.
        self._counter = 0

    def read_into_buffer(self):
        """Refill hook: copy the next chunk into ``self.buffer``.

        Raises:
            EOFError: if the refill produced no bytes.
        """
        position = self._counter
        if position < len(self._inputs):
            chunk = self._inputs[position]
            self._counter = position + 1
        else:
            chunk = b''

        size = len(chunk)
        self.current_buffer_size = size
        self.buffer[:size] = chunk

        if not size:
            raise EOFError('Unexpected EOF while reading bytes')


class TestBufferedReader(TestCase):
    """Exercise ``BufferedReader`` reads across buffer-refill boundaries.

    ``FakeBufferedReader`` serves exactly one predefined chunk per refill,
    so each test controls where the reader has to fetch more data.
    """

    def test_corner_case_read(self):
        # A single read that spans the boundary between two refills.
        rdr = FakeBufferedReader([
            b'\x00' * 10,
            b'\xff' * 10,
        ])

        self.assertEqual(rdr.read(5), b'\x00' * 5)
        self.assertEqual(rdr.read(10), b'\x00' * 5 + b'\xff' * 5)
        self.assertEqual(rdr.read(5), b'\xff' * 5)

        self.assertRaises(EOFError, rdr.read, 10)

    def test_corner_case_read_to_end_of_buffer(self):
        # Reads that end exactly at a refill boundary.
        rdr = FakeBufferedReader([
            b'\x00' * 10,
            b'\xff' * 10,
        ])

        self.assertEqual(rdr.read(5), b'\x00' * 5)
        self.assertEqual(rdr.read(5), b'\x00' * 5)
        self.assertEqual(rdr.read(5), b'\xff' * 5)
        self.assertEqual(rdr.read(5), b'\xff' * 5)

        self.assertRaises(EOFError, rdr.read, 10)

    def test_corner_case_exact_buffer(self):
        # Each chunk fills the buffer exactly (bufsize == chunk size).
        rdr = FakeBufferedReader([
            b'\x00' * 10,
            b'\xff' * 10,
        ], bufsize=10)

        self.assertEqual(rdr.read(5), b'\x00' * 5)
        self.assertEqual(rdr.read(10), b'\x00' * 5 + b'\xff' * 5)
        self.assertEqual(rdr.read(5), b'\xff' * 5)

        # All input consumed: the next read must hit EOF, as in the tests
        # above.
        self.assertRaises(EOFError, rdr.read, 10)

    def test_read_strings(self):
        """``read_strings`` must reassemble strings cut at any offsets."""
        strings = [
            u'Yoyodat' * 10,
            u'Peter Maffay' * 10,
        ]

        buf = BytesIO()
        for name in strings:
            write_binary_str(name, buf)
        buf = buf.getvalue()

        # With no encoding argument, the result is compared against the
        # UTF-8 bytes of each string, as a tuple.
        ref_values = tuple(x.encode('utf-8') for x in strings)

        # Cut the serialized stream into three non-empty chunks at every
        # possible pair of offsets (1 <= split < split_2 < len(buf)), so each
        # refill boundary lands at every position in the data.
        total = len(buf)
        for split in range(1, total - 1):
            for split_2 in range(split + 1, total):
                chunks = [buf[:split], buf[split:split_2], buf[split_2:]]
                rdr = FakeBufferedReader(chunks, bufsize=4096)
                read_values = rdr.read_strings(2)
                # Compare values directly (not their repr) so a failure shows
                # a real diff; msg identifies the failing split.
                self.assertEqual(
                    read_values, ref_values,
                    msg='split={}, split_2={}'.format(split, split_2)
                )