#!/usr/bin/env python
# -*- coding: ISO-8859-15 -*-
#
# Copyright (C) 2005-2007 David Guerizec <david@guerizec.net>
#
# Last modified: 2006 Sep 17, 01:30:08 by david
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
import os
import MySQLdb
from sshproxy.config import get_config
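

def Q(item):
    """Escape a value for use inside a single-quoted MySQL string literal.

    None becomes the empty string. Backslashes are escaped before single
    quotes, so a value ending in a backslash cannot close the literal early.
    """
    if item is None:
        return ''
    return str(item).replace('\\', '\\\\').replace("'", "\\'")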
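

# Illustration only, not used by sshproxy: why a MySQL string-literal
# escaper such as Q() needs to escape backslashes before single quotes.
# Both helper names below are hypothetical. _quote_only_quotes() escapes
# quotes alone, so a value ending in a backslash escapes the closing quote
# of the literal:
#   "'%s'" % _quote_only_quotes("C:\\")  ->  'C:\'   (unterminated)
#   "'%s'" % _quote_both("C:\\")         ->  'C:\\'  (backslash kept)
def _quote_only_quotes(item):
    """Escape single quotes only (insufficient for MySQL)."""
    return str(item).replace("'", "\\'")


def _quote_both(item):
    """Escape backslashes first, then single quotes."""
    # Doubling backslashes first means the backslash that is added in
    # front of each quote is not itself doubled afterwards.
    return str(item).replace('\\', '\\\\').replace("'", "\\'")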


class MySQLDB(object):
    """
    Mixin that shares MySQL connections between its instances, so that only
    one connection is opened per distinct user/host/port/database.
    It implements the open_db method, which should be called from __reginit__.
    """
    # Connections shared by every instance, keyed by a mysql:// URL.
    __db = {}
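
    def open_db(self):
        """Attach self.db to the shared connection for this handler's config."""
        cfg = get_config('%s.mysql' % self._db_handler)
        conid = 'mysql://%s@%s:%s/%s' % (cfg['user'], cfg['host'],
                                          cfg['port'], cfg['db'])
        if conid not in self.__db:
            try:
                MySQLDB.__db[conid] = MySQLdb.connect(host=cfg['host'],
                                                      port=int(cfg['port']),
                                                      db=cfg['db'],
                                                      user=cfg['user'],
                                                      passwd=cfg['password'])
            except Exception:
                # The setup wizard may run before the database exists;
                # tolerate a failed connection only in that case.
                if not os.environ.get('SSHPROXY_WIZARD', None):
                    raise
        # None if the connection failed while running the wizard.
        self.db = self.__db.get(conid)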
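
    def sql_get(self, query):
        """Run query and return its first row.

        Returns None if there is no row, the bare value if the row has a
        single column, and the whole tuple otherwise.
        """
        sql = self.db.cursor()
        sql.execute(query)
        result = sql.fetchone()
        sql.close()
        if not result:
            return None
        if len(result) == 1:
            return result[0]
        return result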
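
    def sql_list(self, query):
        """Run query and yield its rows one at a time."""
        sql = self.db.cursor()
        sql.execute(query)
        try:
            for result in sql.fetchall():
                yield result
        finally:
            # Also runs when the caller stops iterating early.
            sql.close()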

    def sql_add(self, query):
        """Run an INSERT query and return the AUTO_INCREMENT id it generated."""
        sql = self.db.cursor()
        sql.execute(query)
        sql.close()
        return self.sql_get('select last_insert_id()')
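
    def sql_update(self, query):
        """Run a query that returns no rows (UPDATE, DELETE, ...)."""
        sql = self.db.cursor()
        sql.execute(query)
        sql.close()

    # Deleting rows needs exactly the same code as updating them.
    sql_del = sql_update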
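
    def sql_set(self, table, **fields):
        """Insert or overwrite a row with MySQL's REPLACE ... SET statement.

        Each keyword argument becomes a column assignment; values are
        escaped with Q(), but table and column names are used verbatim.
        """
        query = """replace %s set %s"""
        q = []
        for field, value in fields.items():
            q.append("%s='%s'" % (field, Q(value)))
        sql = self.db.cursor()
        sql.execute(query % (table, ', '.join(q)))
        sql.close()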
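

# Illustration only, not used by sshproxy: a database-free sketch of the
# sharing scheme in MySQLDB.open_db(). Connections live in a class-level
# dict keyed by a URL-like id, so every instance targeting the same
# user/host/port/db reuses one connection. _SharedConnections and its
# connect argument are hypothetical stand-ins for MySQLDB and
# MySQLdb.connect.
class _SharedConnections(object):
    """Hypothetical stand-in for MySQLDB: one connection per URL-like id."""
    _pool = {}  # class-level, so shared by every instance

    def __init__(self, connect):
        self._connect = connect  # stand-in for MySQLdb.connect

    def open(self, user, host, port, db):
        conid = 'mysql://%s@%s:%s/%s' % (user, host, port, db)
        if conid not in self._pool:
            # First request for this id: open and remember a connection.
            self._pool[conid] = self._connect(user=user, host=host,
                                              port=port, db=db)
        return self._pool[conid]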
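

# Illustration only, not used by sshproxy: why a generator such as
# sql_list() should close its cursor in a finally clause. When a caller
# abandons the generator, Python calls its close() method, which raises
# GeneratorExit at the paused yield; code placed after the loop is then
# skipped, while a finally clause still runs. _FakeCursor and _rows() are
# hypothetical stand-ins for a MySQLdb cursor and sql_list().
class _FakeCursor(object):
    def __init__(self, rows):
        self.rows = rows
        self.closed = False

    def fetchall(self):
        return list(self.rows)

    def close(self):
        self.closed = True


def _rows(cursor):
    try:
        for row in cursor.fetchall():
            yield row
    finally:
        cursor.close()  # runs on exhaustion and on early close()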