File: connection.py

package info (click to toggle)
python-vertica 0.2.3-1
  • links: PTS
  • area: contrib
  • in suites: jessie, jessie-kfreebsd
  • size: 280 kB
  • ctags: 232
  • sloc: python: 999; makefile: 3
file content (236 lines) | stat: -rw-r--r-- 8,071 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
from __future__ import absolute_import

import os
import select
import socket
import ssl

from struct import unpack

import vertica_python.errors as errors
import vertica_python.vertica.messages as messages

from vertica_python.vertica.messages.message import BackendMessage

from vertica_python.vertica.cursor  import Cursor

# To support vertica_python 0.1.9 interface
class OldResults(object):
    """Minimal result container that exposes fetched rows as ``.rows``.

    Instances are built from the complete list of rows of a finished query
    and hold a reference to that object (it is not copied).
    """

    def __init__(self, rows):
        # Stored by reference under the attribute name callers read back.
        self.rows = rows

class Connection(object):
    """A single client connection to a Vertica server.

    The connection is configured through an ``options`` mapping.  Keys read
    by this class are ``host``, ``port`` (default 5433), ``user``,
    ``database``, ``password``, ``ssl``, ``read_timeout`` (default 600, passed
    to ``select.select`` and therefore in seconds), ``row_style`` (default
    ``'hash'``), ``search_path`` and ``role``.

    Two optional instance attributes alter behaviour when set by the caller:
    ``debug`` (truthy: print every message sent and received) and
    ``notice_handler`` (callable invoked with each ``NoticeResponse``).
    """

    def __init__(self, options=None):
        """Store the options and immediately open and initialise the session.

        :param options: mapping of connection settings; entries whose value is
            ``None`` are ignored so they fall back to the defaults.
        """
        self.reset_values()

        # ``None`` instead of a shared ``{}`` default; entries set to None are
        # dropped so "missing" and "explicitly None" behave the same way.
        supplied = options if options is not None else {}
        self.options = {key: value for key, value in supplied.items() if value is not None}

        self.options.setdefault('port', 5433)
        self.options.setdefault('read_timeout', 600)

        self.row_style = self.options.get('row_style', 'hash')
        self.boot_connection()

    #
    # To support vertica_python 0.1.9 interface
    #
    def query(self, query, handler=None):
        """Run ``query`` on a fresh ``'dict'`` cursor (legacy 0.1.9 interface).

        :param query: SQL text handed to ``Cursor.execute``.
        :param handler: optional row handler passed to the cursor.
        :returns: ``None`` when a handler is given, otherwise an
            ``OldResults`` wrapping ``cursor.fetchall()``.
        """
        if handler:
            cur = Cursor(self, 'dict', handler)
            cur.execute(query)
            return None

        cur = Cursor(self, 'dict')
        cur.execute(query)
        return OldResults(cur.fetchall())

    #
    # dbApi methods
    #

    def close(self):
        """Send ``Terminate`` to the server (if connected) and drop the socket.

        Calling this on a connection that has no socket only resets the
        session state; it never opens a new connection.

        :raises errors.ConnectionError: if sending ``Terminate`` fails.
        """
        try:
            # Going through write() without a socket would dial the server
            # again just to say goodbye, so only notify a live peer.
            if self.socket is not None:
                self.write(messages.Terminate())
        finally:
            self.close_socket()

    def commit(self):
        """Execute ``commit`` on a new cursor.

        :raises errors.Error: if the connection is closed.
        """
        if self.closed():
            raise errors.Error('Connection is closed')

        cur = self.cursor()
        cur.execute('commit')

    def rollback(self):
        """Execute ``rollback`` on a new cursor.

        :raises errors.Error: if the connection is closed.
        """
        if self.closed():
            raise errors.Error('Connection is closed')

        cur = self.cursor()
        cur.execute('rollback')

    def cursor(self, cursor_type=None, row_handler=None):
        """Return a new ``Cursor`` bound to this connection.

        :param cursor_type: forwarded to ``Cursor`` as ``cursor_type``.
        :param row_handler: forwarded to ``Cursor`` as ``row_handler``.
        :raises errors.Error: if the connection is closed.
        """
        if self.closed():
            raise errors.Error('Connection is closed')
        return Cursor(self, cursor_type=cursor_type, row_handler=row_handler)

    #
    # Internal
    #

    def reset_values(self):
        """Forget all per-session state, including the socket reference.

        The socket object itself is not closed here; see ``close_socket``.
        """
        self.parameters = {}
        self.session_id = None
        self.backend_pid = None
        self.backend_key = None
        self.transaction_status = None
        self.socket = None

    def _socket(self):
        """Return the current socket, connecting first if there is none.

        When ``options['ssl'] is True`` an ``SslRequest`` is sent on the plain
        socket and the socket is wrapped with ``ssl.wrap_socket`` if the
        server answers with ``S``.

        :raises errors.SSLNotSupported: if SSL was requested and the server
            answered with anything other than ``S``.
        """
        if self.socket is not None:
            return self.socket

        raw_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            raw_socket.connect((self.options['host'], self.options['port']))

            if self.options.get('ssl', False) is True:
                raw_socket.sendall(messages.SslRequest().to_bytes())
                response = raw_socket.recv(1)
                # recv() yields bytes; b'S' is also the plain str on Python 2.
                if response != b'S':
                    raise errors.SSLNotSupported("SSL requested but not supported by server")
                # May want to add certs to this.
                # NOTE(review): ssl.wrap_socket with default arguments does no
                # certificate validation and is gone in recent Python 3
                # releases -- confirm the supported interpreter range.
                raw_socket = ssl.wrap_socket(raw_socket)
        except Exception:
            # Do not leak the descriptor when connecting or negotiating fails.
            raw_socket.close()
            raise

        self.socket = raw_socket
        return raw_socket

    def ssl(self):
        """Return True if a socket exists and it is an ``ssl.SSLSocket``."""
        return self.socket is not None and isinstance(self.socket, ssl.SSLSocket)

    def opened(self):
        """Return True once a socket exists and the startup handshake has
        supplied both ``backend_pid`` and ``transaction_status``."""
        return self.socket is not None and self.backend_pid is not None and self.transaction_status is not None

    def closed(self):
        """Return the negation of ``opened()``."""
        return not self.opened()

    def write(self, message):
        """Serialise ``message`` with ``to_bytes()`` and send it to the server.

        Connects first if there is no socket yet.

        :param message: object with a callable ``to_bytes`` attribute.
        :raises TypeError: if ``message`` has no callable ``to_bytes``.
        :raises errors.ConnectionError: if connecting or sending fails; the
            socket is closed and session state reset before raising.
        """
        if not callable(getattr(message, 'to_bytes', None)):
            raise TypeError("invalid message: ({0})".format(message))

        if getattr(self, 'debug', False):
            print("=> {0}".format(message))
        try:
            self._socket().sendall(message.to_bytes())
        except Exception as e:
            self.close_socket()
            # str(e) rather than e.message: the latter is empty for
            # multi-argument socket errors and absent on Python 3.
            raise errors.ConnectionError(str(e))

    def close_socket(self):
        """Close the socket if one exists and reset all session state.

        Never opens a connection: with no socket this only resets state.
        """
        try:
            if self.socket is not None:
                self.socket.close()
        finally:
            self.reset_values()

    def reset_connection(self):
        """Close the current session and boot a new one."""
        self.close()
        self.boot_connection()

    def boot_connection(self):
        """Run the startup handshake, then apply session-level options."""
        self.startup_connection()
        self.initialize_connection()

    def read_message(self):
        """Wait for, read and decode one backend message.

        Waits up to ``options['read_timeout']`` for the socket to become
        readable, then reads a 1-byte type, a 4-byte big-endian size (which
        counts itself) and ``size - 4`` payload bytes.

        :returns: the object built by ``BackendMessage.factory``.
        :raises errors.ConnectionError: on any failure, including a timeout
            or a size below 4; the socket is closed before raising.
        """
        try:
            ready = select.select([self._socket()], [], [], self.options['read_timeout'])
            if not ready[0]:
                self.close()
                raise errors.TimedOutError("Connection timed out")

            message_type = self.read_bytes(1)
            size = unpack('!I', self.read_bytes(4))[0]

            if size < 4:
                raise errors.MessageError("Bad message size: {0}".format(size))
            message = BackendMessage.factory(message_type, self.read_bytes(size - 4))
            if getattr(self, 'debug', False):
                print("<= {0}".format(message))
            return message
        except Exception as e:
            # Every failure (timeouts and malformed messages included) is
            # surfaced as ConnectionError, which is what callers catch.
            self.close_socket()
            raise errors.ConnectionError(str(e))

    def process_message(self, message):
        """Apply a connection-level backend message to this connection.

        ``NoticeResponse`` is passed to ``self.notice_handler`` when that
        attribute is set; ``BackendKeyData``, ``ParameterStatus`` and
        ``ReadyForQuery`` update the matching instance attributes.

        :raises errors.ConnectionError: for an ``ErrorResponse``.
        :raises errors.MessageError: for any other message type.
        """
        if isinstance(message, messages.ErrorResponse):
            raise errors.ConnectionError(message.error_message())
        elif isinstance(message, messages.NoticeResponse):
            if getattr(self, 'notice_handler', None) is not None:
                self.notice_handler(message)
        elif isinstance(message, messages.BackendKeyData):
            self.backend_pid = message.pid
            self.backend_key = message.key
        elif isinstance(message, messages.ParameterStatus):
            self.parameters[message.name] = message.value
        elif isinstance(message, messages.ReadyForQuery):
            self.transaction_status = message.transaction_status
        else:
            raise errors.MessageError("Unhandled message: {0}".format(message))

    def __str__(self):
        """Describe the connection state, omitting the ``password`` option."""
        safe_options = {key: value for key, value in self.options.items() if key != 'password'}
        s1 = "<Vertica.Connection:{0} parameters={1} backend_pid={2}, ".format(id(self), self.parameters, self.backend_pid)
        s2 = "backend_key={0}, transaction_status={1}, socket={2}, options={3}, row_style={4}>".format(self.backend_key, self.transaction_status, self.socket, safe_options, self.row_style)
        return s1 + s2

    def read_bytes(self, n):
        """Read exactly ``n`` bytes from the socket.

        :param n: number of bytes to read; ``0`` yields an empty byte string.
        :returns: the bytes read, joined into one byte string.
        :raises errors.ConnectionError: if ``recv`` returns no data before
            ``n`` bytes have arrived.
        """
        sock = self._socket()
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                raise errors.ConnectionError("Connection closed by Vertica")
            chunks.append(chunk)
            remaining -= len(chunk)
        # Join once instead of re-copying the buffer on every partial read.
        return b''.join(chunks)

    def startup_connection(self):
        """Send ``Startup`` and process replies until ``ReadyForQuery``.

        An ``Authentication`` message whose code is not ``OK`` is answered
        with a ``Password`` message; every other message goes through
        ``process_message``.
        """
        self.write(messages.Startup(self.options['user'], self.options['database']))
        while True:
            message = self.read_message()

            if isinstance(message, messages.Authentication):
                # Password message isn't right format ("incomplete message from client")
                if message.code != messages.Authentication.OK:
                    self.write(messages.Password(self.options['password'], message.code, dict(user=self.options['user'], salt=getattr(message, 'salt', None))))
            else:
                self.process_message(message)

            if isinstance(message, messages.ReadyForQuery):
                break

    def initialize_connection(self):
        """Apply the ``search_path`` and ``role`` options, when present.

        NOTE(review): both values are interpolated into SQL unescaped, so
        they must come from trusted configuration.
        """
        if self.options.get('search_path') is not None:
            self.query("SET SEARCH_PATH TO {0}".format(self.options['search_path']))
        if self.options.get('role') is not None:
            self.query("SET ROLE {0}".format(self.options['role']))
        # The 'interruptable' option is not implemented here: session_id is
        # never looked up and stays None.