1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
|
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A lightweight wrapper around MySQLdb."""
import copy
import MySQLdb.constants
import MySQLdb.converters
import MySQLdb.cursors
import itertools
import logging
class Connection(object):
    """A lightweight wrapper around MySQLdb DB-API connections.

    The main value we provide is wrapping rows in a dict/object so that
    columns can be accessed by name. Typical usage::

        db = database.Connection("localhost", "mydatabase")
        for article in db.query("SELECT * FROM articles"):
            print(article.title)

    Cursors are hidden by the implementation, but other than that, the methods
    are very similar to the DB-API.

    We explicitly set the timezone to UTC and the character encoding to
    UTF-8 on all connections to avoid time zone and encoding errors.
    """

    def __init__(self, host, database, user=None, password=None):
        """Configure the connection parameters and try to connect right away.

        ``host`` is either a path to a MySQL socket file (recognised by
        containing a ``/``) or a ``host`` / ``host:port`` string; the port
        defaults to 3306. A failure to connect here is logged, not raised:
        ``_db`` stays ``None`` and the query methods reconnect on next use.
        """
        self.host = host
        self.database = database
        args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8",
                    db=database, init_command='SET time_zone = "+0:00"',
                    sql_mode="TRADITIONAL")
        if user is not None:
            args["user"] = user
        if password is not None:
            args["passwd"] = password

        # We accept a path to a MySQL socket file or a host(:port) string
        if "/" in host:
            args["unix_socket"] = host
        else:
            # NOTE(review): only set for TCP hosts and never read by this
            # class; kept so existing callers touching it do not break.
            self.socket = None
            pair = host.split(":")
            if len(pair) == 2:
                args["host"] = pair[0]
                args["port"] = int(pair[1])
            else:
                args["host"] = host
                args["port"] = 3306

        self._db = None
        self._db_args = args
        try:
            self.reconnect()
        except Exception:
            # Best-effort initial connect (we reconnect lazily later), but do
            # not swallow KeyboardInterrupt / SystemExit like a bare except.
            logging.error("Cannot connect to MySQL on %s", self.host,
                          exc_info=True)

    def __del__(self):
        self.close()

    def close(self):
        """Closes this database connection; a no-op if it is not open."""
        # getattr: __del__ may run on an instance whose __init__ failed early.
        db = getattr(self, "_db", None)
        if db is not None:
            # Drop our reference before closing so that even if close()
            # raises, the next call reconnects instead of reusing a dead handle.
            self._db = None
            db.close()

    def reconnect(self):
        """Closes the existing database connection and re-opens it."""
        self.close()
        self._db = MySQLdb.connect(**self._db_args)
        self._db.autocommit(True)

    def iter(self, query, *parameters):
        """Returns an iterator of Row objects for the given query.

        Uses a MySQLdb ``SSCursor`` (server-side cursor), so rows are fetched
        as the caller iterates instead of being buffered up front. The cursor
        is closed when the generator finishes or is discarded.
        """
        if self._db is None:
            self.reconnect()
        cursor = MySQLdb.cursors.SSCursor(self._db)
        try:
            self._execute(cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            for row in cursor:
                yield Row(zip(column_names, row))
        finally:
            cursor.close()

    def query(self, query, *parameters):
        """Returns a list of Row objects for the given query and parameters."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            column_names = [d[0] for d in cursor.description]
            # Plain zip, as in iter(); itertools.izip is gone on Python 3.
            return [Row(zip(column_names, row)) for row in cursor]
        finally:
            cursor.close()

    def get(self, query, *parameters):
        """Returns the single row returned for the given query.

        Returns None when the query yields no rows and raises Exception
        when it yields more than one.
        """
        rows = self.query(query, *parameters)
        if not rows:
            return None
        elif len(rows) > 1:
            raise Exception("Multiple rows returned for Database.get() query")
        else:
            return rows[0]

    def execute(self, query, *parameters):
        """Executes the given query, returning the lastrowid from the query."""
        cursor = self._cursor()
        try:
            self._execute(cursor, query, parameters)
            return cursor.lastrowid
        finally:
            cursor.close()

    def executemany(self, query, parameters):
        """Executes the given query against all the given param sequences.

        We return the lastrowid from the query.
        """
        cursor = self._cursor()
        try:
            # Go through _execute like every other query path so that an
            # OperationalError also drops the connection here.
            self._execute(cursor, query, parameters, many=True)
            return cursor.lastrowid
        finally:
            cursor.close()

    def _cursor(self):
        """Returns a new buffered cursor, reconnecting first if needed."""
        if self._db is None:
            self.reconnect()
        return self._db.cursor()

    def _execute(self, cursor, query, parameters, many=False):
        """Runs ``query`` on ``cursor`` and returns the cursor's result.

        With ``many=True``, ``parameters`` is a sequence of parameter
        sequences handed to ``cursor.executemany``; otherwise it goes to
        ``cursor.execute``. On OperationalError the connection is closed,
        so the next call reconnects, and the error is re-raised.
        """
        run = cursor.executemany if many else cursor.execute
        try:
            return run(query, parameters)
        except OperationalError:
            logging.error("Error connecting to MySQL on %s", self.host)
            self.close()
            raise
class Row(dict):
    """Dictionary whose keys can also be read as attributes (``row.column``).

    Missing keys surface as AttributeError, so ``hasattr``/``getattr`` with
    a default behave as they would on an ordinary object.
    """

    def __getattr__(self, attr):
        # Python calls __getattr__ only after normal attribute lookup fails,
        # so genuine dict methods (keys, items, ...) take precedence over
        # columns that happen to share their names.
        try:
            value = self[attr]
        except KeyError:
            raise AttributeError(attr)
        return value
# Fix the access conversions to properly recognize unicode/binary.
#
# Work on a private deep copy of MySQLdb's default converter table so the
# library's own table is left untouched. For each string/blob field type,
# a (FLAG.BINARY, str) entry is pushed to the front of that type's converter
# list, so columns carrying the BINARY flag go through ``str`` ahead of the
# default converters. Connection.__init__ passes this table as ``conv=``.
# NOTE(review): if the driver hands these converters raw bytes under
# Python 3, ``str`` would yield their repr ("b'...'") and ``bytes`` would
# be needed instead — confirm the target runtime.
FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
FLAG = MySQLdb.constants.FLAG
CONVERSIONS = copy.deepcopy(MySQLdb.converters.conversions)

field_types = [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]
# VARCHAR may not be defined by every MySQLdb version, hence the guard.
if hasattr(FIELD_TYPE, "VARCHAR"):
    field_types.append(FIELD_TYPE.VARCHAR)
for field_type in field_types:
    CONVERSIONS[field_type].insert(0, (FLAG.BINARY, str))

# Alias some common MySQL exceptions so this module (e.g.
# Connection._execute) and its users can refer to them by short names.
IntegrityError = MySQLdb.IntegrityError
OperationalError = MySQLdb.OperationalError
|