File: db_utils.py

package info (click to toggle)
mysql-workbench 6.3.8%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 113,932 kB
  • ctags: 87,814
  • sloc: ansic: 955,521; cpp: 427,465; python: 59,728; yacc: 59,129; xml: 54,204; sql: 7,091; objc: 965; makefile: 638; sh: 613; java: 237; perl: 30; ruby: 6; php: 1
file content (372 lines) | stat: -rw-r--r-- 12,801 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
# Copyright (c) 2010, 2015, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License as
# published by the Free Software Foundation; version 2 of the
# License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
# 02110-1301  USA

from grt import modules
from grt import DBLoginError

def escape_sql_string(s):
    """Return s with backslashes and single quotes backslash-escaped.

    The result is meant to be placed between single quotes in a SQL
    statement; the surrounding quotes are not added here.
    """
    # Backslashes are doubled first, otherwise the backslashes introduced
    # for the quotes in the second step would be doubled as well.
    backslashes_doubled = s.replace("\\", "\\\\")
    return backslashes_doubled.replace("'", "\\'")

def escape_sql_identifier(s):
    """Return s with every backtick doubled.

    The result is meant to be placed between backticks as a quoted
    identifier; the surrounding backticks are not added here.
    """
    # Splitting on the backtick and re-joining with two of them doubles
    # each occurrence.
    return "``".join(s.split("`"))


def strip_password(s):
    """Return s with the quoted arguments that follow PASSWORD masked.

    On every line that contains the (case sensitive) keyword PASSWORD,
    each ('...') argument after the first PASSWORD is replaced by
    ('XXXXXX'), so statements and server error messages can be logged
    or shown without exposing the password.

    Masking is done in two passes:
      1. Every well formed single quoted literal wrapped in parentheses
         is masked. Backslash escapes and doubled quotes are treated as
         part of the literal.
      2. The historical greedy pattern is applied afterwards. On its own
         it only ever reaches the last ('...') of a line, but it also
         covers malformed literals (for instance a password holding an
         unescaped quote) that the first pass does not recognize.
    """
    import re

    # ('...') holding one complete single quoted SQL literal.
    quoted_arg = re.compile(r"\('(?:[^'\\]|\\.|'')*'\)")

    def mask_line_tail(match):
        # match spans from the first PASSWORD up to the end of its line.
        return quoted_arg.sub("('XXXXXX')", match.group(0))

    # '.' does not match a newline, so each line is handled on its own
    # and text in front of the first PASSWORD is never touched.
    s = re.sub(r"PASSWORD.*", mask_line_tail, s)

    # Legacy pass, repeated until the text stops changing.
    while True:
        ss = re.sub(r"(.*PASSWORD.*\(')(.*)('\).*)", r"\1XXXXXX\3", s)
        if s == ss:
            break
        s = ss
    return s

    
def substring_to_delimiter(source, index, limit, force_limit = False):
    """
        Extracts from a string starting at the given index and ending
        once the first character in limit is found.

        source:      the string to scan
        index:       position in source where the scan starts
        limit:       string holding the delimiter character(s)
        force_limit: if True, will return None if the limit is not found
                     and the end of the string is reached.

        Returns the text between index and the delimiter, without the
        delimiter itself. The text is returned exactly as written in
        source: escape backslashes and doubled quotes are NOT decoded, so
        len(result) is the number of source characters consumed before
        the delimiter.
    """

    # When limit is quoting handles special quote embedding cases:
    # - Escaped quoting using \
    # - Quoting using the quote character twice: '', "" or ``
    # These rules are applied ONLY in that case; for any other limit the
    # first delimiter character ends the token right away.
    quoting = '\'"`'
    handle_embedded_quoting = limit in quoting

    pieces = []
    limit_found = False    # previous character was an unescaped delimiter
    pending_quote = None   # that delimiter, kept until it is known to be final
    escape_found = False   # previous character was an escaping backslash

    for char in source[index:]:
        if limit_found:
            if handle_embedded_quoting and char == pending_quote:
                # The quote character twice in a row is an embedded quote:
                # both characters belong to the token.
                pieces.append(pending_quote)
                pieces.append(char)
                limit_found = False
                continue
            # Anything else confirms the delimiter really closed the token
            break

        if handle_embedded_quoting and char == '\\' and not escape_found:
            escape_found = True
            pieces.append(char)
            continue

        if char in limit and not escape_found:
            limit_found = True
            pending_quote = char
            continue

        # Regular character, or a delimiter protected by a backslash
        escape_found = False
        pieces.append(char)

    # Limit may be mandatory or not, returns None if
    # it was not found
    if force_limit and not limit_found:
        return None

    return ''.join(pieces)

def parse_mysql_ids(source):
    """
        Extracts from a string an array with all the valid IDs found
        Expected format is a dot separated list of IDs where they could optionally be quoted
        by single, double or back quotes.

        If an invalid ID is found the process will stop and the IDs found
        up to that point are returned. Invalid means: a quoted ID without
        closing quote, an empty quoted ID, or a separator that does not
        follow an ID.

        Quoted IDs are returned without the surrounding quotes, with any
        embedded escaping left as written in source, e.g.
          parse_mysql_ids("`my db`.tbl") -> ["my db", "tbl"]
    """
    ids = []
    index = 0
    length = len(source)
    previous_token = False

    # Sweeps the string extracting all the IDs
    while index < length:
        char = source[index]

        if char in '"\'`':
            # In case of quoting found, the ID will be extracted until the closing quote is found
            token = substring_to_delimiter(source, index + 1, char, True)

            # None means the closing quote is missing, '' means an empty
            # quoted ID. Both are invalid; stopping here also guarantees
            # the loop can not spin forever without advancing index.
            if not token:
                break

            # Skips the ID plus its opening and closing quotes
            index += len(token) + 2

        elif char in '. ':
            # The separator is just skipped as long as previous was a valid token
            # i.e. 2 in a row is a mistake
            if not previous_token:
                break

            index += 1
            previous_token = False
            continue

        else:
            # If no quoting is found, next ID will be until the delimiters are found
            token = substring_to_delimiter(source, index, '. ')

            # Not expected as char is not a delimiter, but an empty token
            # would not advance index
            if not token:
                break

            index += len(token)

        ids.append(token)
        previous_token = True

    return ids

class MySQLError(Exception):
    """Error reported by the MySQL server or client library.

    msg:      error text; " (code N)" is appended to it to build the
              exception message
    code:     numeric error code, also kept in self.code
    location: value kept untouched in self.location
    """
    def __init__(self, msg, code, location):
        code_suffix = " (code %i)" % code
        Exception.__init__(self, msg + code_suffix)
        self.location = location
        self.code = code


class MySQLLoginError(MySQLError, DBLoginError):
    """MySQLError that is at the same time a grt DBLoginError.

    Takes the same arguments as MySQLError (msg, code, location).
    """
    def __init__(self, msg, code, location):
        # Each base is initialized explicitly: MySQLError gets all the
        # arguments, DBLoginError only the bare message.
        # NOTE(review): DBLoginError.__init__ runs last, so it presumably
        # decides the final exception text - confirm against grt.
        MySQLError.__init__(self, msg, code, location)
        DBLoginError.__init__(self, msg)


class QueryError(Exception):
  """Raised when a statement fails or when there is no open connection.

  msg:       description of what failed
  error:     error code; normally a number, anything that can not be
             converted with int() is treated as code 0 by the helpers
  errortext: optional error text, kept untouched in self.errortext
  """

  # Client error codes treated as "the connection is gone":
  # 2006 server has gone away, 2013 lost connection during query,
  # 2026 SSL connection error, 2055 lost connection (extended),
  # 2048 invalid connection handle
  not_connected_errors = (2006, 2013, 2026, 2055, 2048)

  def __init__(self, msg, error, errortext = None):
    self.msg = msg
    self.error = error
    self.errortext = errortext

  def __str__(self):
    return self.msg + ".\nSQL Error: " + str(self.error)

  def _error_code(self):
    # Numeric value of self.error, 0 when it can not be converted
    # (None or a non numeric string for instance)
    try:
      return int(self.error)
    except (TypeError, ValueError):
      return 0

  def is_connection_error(self):
    """True if the error code is one of not_connected_errors."""
    return self._error_code() in self.not_connected_errors

  def is_error_recoverable(self):
    """False for error 2006, True for anything else.

    The code is converted the same way is_connection_error does, so
    "2006" and 2006 are treated alike.
    """
    return self._error_code() != 2006 # Probably add more errors here


class ConnectionTunnel:
    """Tunnel opened through the DbMySQLQuery module for a connection.

    self.tunnel holds the value returned by openTunnel; a value greater
    than 0 is treated as an open tunnel. self.port is the value returned
    by getTunnelPort for it, or None when no tunnel was opened.
    The tunnel is closed when the object is destroyed.
    """
    def __init__(self, info):
        self.tunnel = modules.DbMySQLQuery.openTunnel(info)
        if self.tunnel > 0:
            self.port = modules.DbMySQLQuery.getTunnelPort(self.tunnel)
        else:
            self.port = None

    def __del__(self):
        # __del__ also runs when __init__ failed inside openTunnel, in
        # which case self.tunnel was never assigned.
        if getattr(self, "tunnel", 0) > 0:
            modules.DbMySQLQuery.closeTunnel(self.tunnel)


class MySQLResult:
    """Wrapper around a result handle of the DbMySQLQuery module.

    Every method forwards to the DbMySQLQuery function named in its
    body, passing the handle first. The handle is closed when the
    object is destroyed.
    """
    def __init__(self, result):
        # Handle given to all the DbMySQLQuery.result* calls below
        self.result = result

    def __del__(self):
        # A falsy handle has nothing to be closed
        if not self.result:
            return
        modules.DbMySQLQuery.closeResult(self.result)

    # ---- Row navigation

    def firstRow(self):
        """Forwards to DbMySQLQuery.resultFirstRow."""
        return modules.DbMySQLQuery.resultFirstRow(self.result)

    def nextRow(self):
        """Forwards to DbMySQLQuery.resultNextRow."""
        return modules.DbMySQLQuery.resultNextRow(self.result)

    # ---- Field access by name

    def stringByName(self, name):
        """Forwards to DbMySQLQuery.resultFieldStringValueByName."""
        return modules.DbMySQLQuery.resultFieldStringValueByName(self.result, name)

    def unicodeByName(self, name):
        """Like stringByName, but a str value is decoded as UTF-8."""
        value = modules.DbMySQLQuery.resultFieldStringValueByName(self.result, name)
        # Values of any other type are handed back untouched
        return value.decode("utf-8") if type(value) is str else value

    def intByName(self, name):
        """Forwards to DbMySQLQuery.resultFieldIntValueByName."""
        return modules.DbMySQLQuery.resultFieldIntValueByName(self.result, name)

    def floatByName(self, name):
        """Forwards to DbMySQLQuery.resultFieldDoubleValueByName."""
        return modules.DbMySQLQuery.resultFieldDoubleValueByName(self.result, name)

    # ---- Field access by index

    def stringByIndex(self, i):
        """Forwards to DbMySQLQuery.resultFieldStringValue."""
        return modules.DbMySQLQuery.resultFieldStringValue(self.result, i)

    def unicodeByIndex(self, i):
        """Like stringByIndex, but a str value is decoded as UTF-8."""
        value = modules.DbMySQLQuery.resultFieldStringValue(self.result, i)
        # Values of any other type are handed back untouched
        return value.decode("utf-8") if type(value) is str else value

    def floatByIndex(self, i):
        """Forwards to DbMySQLQuery.resultFieldDoubleValue."""
        return modules.DbMySQLQuery.resultFieldDoubleValue(self.result, i)

    def intByIndex(self, i):
        """Forwards to DbMySQLQuery.resultFieldIntValue."""
        return modules.DbMySQLQuery.resultFieldIntValue(self.result, i)

    # ---- Result metadata

    def numFields(self):
        """Forwards to DbMySQLQuery.resultNumFields."""
        return modules.DbMySQLQuery.resultNumFields(self.result)

    def numRows(self):
        """Forwards to DbMySQLQuery.resultNumRows."""
        return modules.DbMySQLQuery.resultNumRows(self.result)

    def fieldName(self, i):
        """Forwards to DbMySQLQuery.resultFieldName."""
        return modules.DbMySQLQuery.resultFieldName(self.result, i)

    def fieldType(self, i):
        """Forwards to DbMySQLQuery.resultFieldType."""
        return modules.DbMySQLQuery.resultFieldType(self.result, i)

        

class MySQLConnection:
    """
        Connection to a MySQL server, use as:
          info = grt.root.wb.rdbmsMgmt.storedConns[0]
          conn = MySQLConnection(info)
          conn.connect()
          result = conn.executeQuery("SHOW DATABASES")
          flag = result.firstRow()
          while flag:
              print result.stringByName("Database")
              flag = result.nextRow()

        status_cb, when given, is called as status_cb(code, error, info)
        after every operation: code 0 on success, -1 when the connection
        is closed or not established, the server error code otherwise.

        Statements and server error texts placed in a QueryError or sent
        to status_cb are passed through strip_password first.
    """
    def __init__(self, info, status_cb = None, password = None):
        assert type(status_cb) is not unicode
        self.connect_info = info
        self.connection = 0        # 0 means not connected
        self.server_down = False   # set by connect() on errors 2002/2003
        self.status_cb = status_cb
        self.password = password


    def __del__(self):
        # __del__ also runs when __init__ failed before self.connection
        # was assigned
        if getattr(self, "connection", 0):
            self.disconnect()

    def send_status(self, code, error = None):
        """Reports code and error to the status callback, if there is one."""
        if self.status_cb:
            self.status_cb(code, error, self.connect_info)

    def connect(self):
        """Opens the connection, does nothing if it is already open.

        Raises MySQLLoginError on error 1045 (access denied) and
        MySQLError on any other failure. self.server_down is set when
        the failure is 2002 or 2003 (can't connect to the server).
        """
        self.server_down = False
        if self.connection:
            return

        params = self.connect_info.parameterValues

        # OPT_READ_TIMEOUT is overridden only while the connection is
        # being opened (presumably seconds, i.e. 5 minutes - confirm).
        had_timeout = params.has_key('OPT_READ_TIMEOUT')
        old_timeout_value = params['OPT_READ_TIMEOUT'] if had_timeout else None
        params['OPT_READ_TIMEOUT'] = 5*60

        #self.thread = thread.get_ident()
        try:
            if self.password is not None:
                self.connection = modules.DbMySQLQuery.openConnectionP(self.connect_info, self.password)
            else:
                self.connection = modules.DbMySQLQuery.openConnection(self.connect_info)
        finally:
            # The connection parameters are left as they were found, even
            # if opening raised or the previous value was falsy
            if had_timeout:
                params['OPT_READ_TIMEOUT'] = old_timeout_value
            else:
                del params['OPT_READ_TIMEOUT']

        if self.connection < 0:
            self.connection = 0
            code = modules.DbMySQLQuery.lastErrorCode()
            message = modules.DbMySQLQuery.lastError()
            location = "%s@%s" % (params["userName"], params["hostName"])

            if code == 1045:
                raise MySQLLoginError(message, code, location)

            if code in (2003,2002):
                self.server_down = True
            raise MySQLError(message, code, location)

        self.send_status(0, "Connection created")

    def ping(self):
        """Runs SELECT 1; returns True, or raises QueryError on failure."""
        self.executeQuery("SELECT 1")
        return True


    def disconnect(self):
        """Closes the connection if it is open."""
        if self.connection:
            modules.DbMySQLQuery.closeConnection(self.connection)
            self.connection = 0
            self.send_status(-1, "Connection closed by client")

    @property
    def is_connected(self):
        return self.connection > 0


    def _encode_query(self, query):
        # unicode statements are handed over as UTF-8 encoded str
        return query.encode("utf-8") if type(query) is unicode else query

    def _raise_not_connected(self):
        # Common handling for statements issued without an open connection
        message = "Connection to MySQL server is currently not established"
        self.send_status(-1, message)
        raise QueryError(message, -1)

    def _raise_last_error(self, query):
        # Reports the last error of the connection and raises it as a
        # QueryError; passwords are masked in both statement and error
        code = modules.DbMySQLQuery.lastConnectionErrorCode(self.connection)
        error = strip_password(modules.DbMySQLQuery.lastConnectionError(self.connection))
        self.send_status(code, error)
        raise QueryError("Error executing '%s'\n%s" % (strip_password(query), error), code, error)


    def execute(self, query):
        """Executes a statement through DbMySQLQuery.execute.

        Returns True if the module returned 0, False for any other non
        negative value. Raises QueryError on failure or if not connected.
        """
        if not self.connection:
            self._raise_not_connected()

        #assert self.thread == thread.get_ident()
        result = modules.DbMySQLQuery.execute(self.connection, query)
        if result < 0:
            self._raise_last_error(query)

        self.send_status(0)
        return result == 0


    def executeQuery(self, query):
        """Executes a query and returns its result as a MySQLResult.

        Raises QueryError on failure or if not connected.
        """
        if not self.connection:
            self._raise_not_connected()

        #assert self.thread == thread.get_ident()
        result = modules.DbMySQLQuery.executeQuery(self.connection, self._encode_query(query))
        if result < 0:
            self._raise_last_error(query)

        self.send_status(0)
        return MySQLResult(result)

    def executeQueryMultiResult(self, query):
        """Executes a query and returns a list with one MySQLResult per result.

        An empty list returned by the module is treated as a failure.
        Raises QueryError on failure or if not connected.
        """
        if not self.connection:
            self._raise_not_connected()

        result = modules.DbMySQLQuery.executeQueryMultiResult(self.connection, self._encode_query(query))
        if len(result) == 0:
            self._raise_last_error(query)

        self.send_status(0)
        return [MySQLResult(handle) for handle in result]

    def updateCount(self):
        """Forwards to DbMySQLQuery.lastUpdateCount for this connection."""
        return modules.DbMySQLQuery.lastUpdateCount(self.connection)
    affectedRows = updateCount