1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
|
from __future__ import unicode_literals
import datetime
import decimal
import hashlib
import logging
from time import time
from django.conf import settings
from django.utils.encoding import force_bytes
from django.utils.timezone import utc
# Module logger for SQL debug output; CursorDebugWrapper below emits a
# DEBUG record per executed statement through it.
logger = logging.getLogger('django.db.backends')
class CursorWrapper(object):
    """
    Proxy around a DB-API cursor belonging to the connection wrapper ``db``.

    Unknown attributes are forwarded to the underlying cursor. The methods
    named in WRAP_ERROR_ATTRS, iteration, and the execute-style methods run
    inside ``db.wrap_database_errors``; the execute-style methods also call
    ``db.validate_no_broken_transaction()`` before touching the cursor.
    """

    # Cursor methods that are returned wrapped in db.wrap_database_errors when
    # looked up through __getattr__.
    WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])

    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    def __getattr__(self, attr):
        # Only reached for names not found on the wrapper itself.
        target = getattr(self.cursor, attr)
        if attr not in CursorWrapper.WRAP_ERROR_ATTRS:
            return target
        return self.db.wrap_database_errors(target)

    def __iter__(self):
        with self.db.wrap_database_errors:
            for row in self.cursor:
                yield row

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close instead of passing through to avoid backend-specific behavior
        # (#17671). Catch errors liberally because errors in cleanup code
        # aren't useful.
        try:
            self.close()
        except self.db.Database.Error:
            pass

    # The following methods cannot be implemented in __getattr__, because the
    # code must run when the method is invoked, not just when it is accessed.

    def callproc(self, procname, params=None):
        """Call a stored procedure; ``params`` is omitted entirely when None."""
        self.db.validate_no_broken_transaction()
        call_args = (procname,) if params is None else (procname, params)
        with self.db.wrap_database_errors:
            return self.cursor.callproc(*call_args)

    def execute(self, sql, params=None):
        """Execute ``sql``; ``params`` is omitted entirely when None."""
        self.db.validate_no_broken_transaction()
        call_args = (sql,) if params is None else (sql, params)
        with self.db.wrap_database_errors:
            return self.cursor.execute(*call_args)

    def executemany(self, sql, param_list):
        """Execute ``sql`` once per parameter set in ``param_list``."""
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times execute()/executemany() and records each call.

    Every call, including one that raises, appends a ``{'sql', 'time'}``
    entry to ``db.queries_log`` and emits a DEBUG record on ``logger``.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        started = time()
        try:
            return super(CursorDebugWrapper, self).execute(sql, params)
        finally:
            elapsed = time() - started
            # Record the backend's rendering of the query, not the raw template.
            sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            entry = {
                'sql': sql,
                'time': "%.3f" % elapsed,
            }
            self.db.queries_log.append(entry)
            logger.debug(
                '(%.3f) %s; args=%s', elapsed, sql, params,
                extra={'duration': elapsed, 'sql': sql, 'params': params}
            )

    def executemany(self, sql, param_list):
        started = time()
        try:
            return super(CursorDebugWrapper, self).executemany(sql, param_list)
        finally:
            elapsed = time() - started
            try:
                count = len(param_list)
            except TypeError:  # param_list could be an iterator
                count = '?'
            entry = {
                'sql': '%s times: %s' % (count, sql),
                'time': "%.3f" % elapsed,
            }
            self.db.queries_log.append(entry)
            logger.debug(
                '(%.3f) %s; args=%s', elapsed, sql, param_list,
                extra={'duration': elapsed, 'sql': sql, 'params': param_list}
            )
###############################################
# Converters from database (string) to Python #
###############################################
def typecast_date(s):
    """
    Convert a ``YYYY-MM-DD`` string from the database into ``datetime.date``.

    Returns None when ``s`` is empty or None.
    """
    if not s:
        return None
    parts = [int(piece) for piece in s.split('-')]
    return datetime.date(*parts)
def typecast_time(s):  # does NOT store time zone information
    """
    Convert an ``HH:MM:SS[.fraction]`` string into a naive ``datetime.time``.

    Returns None when ``s`` is empty or None. The fractional part is padded
    with trailing zeros and cut to six digits to form microseconds.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(':')
    fraction = '0'
    if '.' in seconds:  # seconds carry a fractional part
        seconds, fraction = seconds.split('.')
    return datetime.time(
        int(hour), int(minutes), int(seconds),
        int(fraction.ljust(6, '0')[:6])
    )
def typecast_timestamp(s):  # does NOT store time zone information
    """
    Convert a database timestamp string into ``datetime.datetime``.

    Accepts values such as "2005-07-29 15:48:00.590358-05" or
    "2005-07-29 09:56:00-05". Returns None for empty input, and a
    ``datetime.date`` (via typecast_date) when there is no time part.
    A trailing UTC offset is parsed off and discarded. The result is naive,
    or carries ``utc`` as tzinfo when ``settings.USE_TZ`` is true.

    NOTE(review): a non-zero offset is dropped, not applied, so such values
    end up labelled UTC when USE_TZ is on. Confirm callers only pass UTC
    timestamps.
    """
    if not s:
        return None
    if ' ' not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # The date has already been split off, so a '-' or '+' here can only
    # start an offset. Keep just the clock portion before it.
    if '-' in time_part:
        time_part = time_part.split('-', 1)[0]
    elif '+' in time_part:
        time_part = time_part.split('+', 1)[0]
    dates = date_part.split('-')
    times = time_part.split(':')
    seconds = times[2]
    fraction = '0'
    if '.' in seconds:  # seconds carry a fractional part
        seconds, fraction = seconds.split('.')
    tzinfo = utc if settings.USE_TZ else None
    return datetime.datetime(
        int(dates[0]), int(dates[1]), int(dates[2]),
        int(times[0]), int(times[1]), int(seconds),
        int(fraction.ljust(6, '0')[:6]), tzinfo
    )
def typecast_decimal(s):
    """Convert a database string to ``decimal.Decimal``; None and '' yield None."""
    return None if (s is None or s == '') else decimal.Decimal(s)
###############################################
# Converters from Python to database (string) #
###############################################
def rev_typecast_decimal(d):
    """Render a decimal value as its ``str()`` form for the database; None passes through."""
    return str(d) if d is not None else None
def split_identifier(identifier):
    """
    Split a SQL identifier into a two-element tuple of (namespace, name).

    The identifier may be a table, column, or sequence name, optionally
    prefixed by a namespace in the quoted form ``"NAMESPACE"."NAME"``.
    Double quotes are stripped from both ends of each part. If the
    identifier does not contain exactly one ``"."`` separator, the namespace
    is empty and the whole identifier is treated as the name.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = '', identifier
    return namespace.strip('"'), name.strip('"')
def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten a SQL identifier to a repeatable mangled version with the given
    length.

    If a quote-stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only.

    When no truncation is needed (``length`` is None or the name already
    fits), ``identifier`` is returned unchanged. Otherwise the name is cut
    and suffixed with a prefix of its MD5 hex digest, so the name portion of
    the result is ``length`` characters long. ``hash_len`` digest characters
    are used, capped at ``length``.
    """
    namespace, name = split_identifier(identifier)
    if length is None or len(name) <= length:
        return identifier
    # Cap the digest at the requested length. With length < hash_len the
    # slice bound ``length - hash_len`` would be negative, and name[:negative]
    # would keep nearly the whole name, producing a result far longer than
    # ``length``.
    hash_len = min(hash_len, length)
    digest = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len]
    prefix = '%s"."' % namespace if namespace else ''
    return '%s%s%s' % (prefix, name[:length - hash_len], digest)
def format_number(value, max_digits, decimal_places):
    """
    Format a number into a string with the requisite number of digits and
    decimal places.

    Returns None when ``value`` is None.

    For ``decimal.Decimal`` values, a copy of the current decimal context is
    used, with its precision set to ``max_digits`` when that is given:

    * With ``decimal_places``, the value is quantized to that many places.
      Per the decimal docs, this may signal InvalidOperation if the result
      needs more than ``max_digits`` digits.
    * Without ``decimal_places``, the value is re-created in the context with
      the Rounded signal trapped, so a value that does not fit raises
      ``decimal.Rounded``.

    Other numbers ignore ``max_digits``. They are rendered with
    ``decimal_places`` fixed places, or via ``'{:f}'`` when that is None.
    """
    if value is None:
        return None
    if not isinstance(value, decimal.Decimal):
        if decimal_places is None:
            return "{:f}".format(value)
        return "%.*f" % (decimal_places, value)
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1
        value = ctx.create_decimal(value)
    else:
        exponent = decimal.Decimal(".1") ** decimal_places
        value = value.quantize(exponent, context=ctx)
    return "{:f}".format(value)
def strip_quotes(table_name):
    """
    Remove one pair of surrounding double quotes from a table name.

    Makes quoted names safe for use in index names, sequence names, etc.
    For example '"USER"."TABLE"' (an Oracle naming scheme) becomes
    'USER"."TABLE'. Names that do not both start and end with a double quote
    are returned unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name
|