1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
|
"""
Classes that represent database functions.
"""
from django.db.models import Func, Transform, Value, fields
class Cast(Func):
    """
    Coerce an expression to a new field type.

    The SQL type name is substituted into the template as ``db_type``. It
    defaults to ``output_field.db_type(connection)`` unless the caller (or a
    backend-specific method below) supplies its own value.
    """
    function = 'CAST'
    template = '%(function)s(%(expressions)s AS %(db_type)s)'
    # Type names used for ``db_type`` when compiling for MySQL, keyed by the
    # exact class of the output field (subclasses are not matched).
    # NOTE(review): 'signed' for FloatField looks like an integer cast on
    # MySQL -- confirm against the MySQL CAST() documentation.
    mysql_types = {
        fields.CharField: 'char',
        fields.IntegerField: 'signed integer',
        fields.FloatField: 'signed',
    }

    def __init__(self, expression, output_field):
        """
        expression: the expression to cast
        output_field: the field instance describing the target type
        """
        super(Cast, self).__init__(expression, output_field=output_field)

    def as_sql(self, compiler, connection, **extra_context):
        # Only look up the field's column type when no override was given.
        if 'db_type' not in extra_context:
            extra_context['db_type'] = self._output_field.db_type(connection)
        return super(Cast, self).as_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection):
        extra_context = {}
        output_field_class = type(self._output_field)
        # Fall through to the default db_type lookup for unmapped classes.
        if output_field_class in self.mysql_types:
            extra_context['db_type'] = self.mysql_types[output_field_class]
        return self.as_sql(compiler, connection, **extra_context)

    def as_postgresql(self, compiler, connection):
        # CAST would be valid too, but the :: shortcut syntax is more readable.
        # The source expression is wrapped in parentheses because '::' binds
        # more tightly than most operators: without them a compound
        # expression such as ``a + b`` would only have ``b`` cast.
        return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s')
class Coalesce(Func):
    """
    Return the first non-null expression, scanning from left to right.
    """
    function = 'COALESCE'

    def __init__(self, *expressions, **extra):
        """Raise ValueError unless at least two expressions are given."""
        if len(expressions) < 2:
            raise ValueError('Coalesce must take at least two expressions')
        super(Coalesce, self).__init__(*expressions, **extra)

    def as_oracle(self, compiler, connection):
        # TextField (NCLOB) and CharField (NVARCHAR) can't be mixed, so every
        # argument is converted to NCLOB when an NCLOB result is expected.
        if self.output_field.get_internal_type() != 'TextField':
            return self.as_sql(compiler, connection)

        class ToNCLOB(Func):
            function = 'TO_NCLOB'

        wrapped_sources = [
            ToNCLOB(source) for source in self.get_source_expressions()
        ]
        # Compile a copy so this expression keeps its original arguments.
        converted = self.copy()
        converted.set_source_expressions(wrapped_sources)
        return super(Coalesce, converted).as_sql(compiler, connection)
class ConcatPair(Func):
    """
    Concatenate exactly two arguments.

    `Concat` is built from nested instances of this helper because not all
    backend databases support more than two arguments.
    """
    function = 'CONCAT'

    def __init__(self, left, right, **extra):
        super(ConcatPair, self).__init__(left, right, **extra)

    def coalesce(self):
        """
        Return a copy in which each argument is wrapped as
        ``Coalesce(argument, Value(''))``.
        """
        # A null on either side makes the whole expression null, hence the
        # empty-string fallback for every argument.
        null_safe = self.copy()
        null_safe.set_source_expressions([
            Coalesce(source, Value(''))
            for source in null_safe.get_source_expressions()
        ])
        return null_safe

    def as_sqlite(self, compiler, connection):
        # Join the coalesced arguments with the || operator; no function
        # name is emitted.
        null_safe = self.coalesce()
        return super(ConcatPair, null_safe).as_sql(
            compiler, connection, template='%(expressions)s', arg_joiner=' || '
        )

    def as_mysql(self, compiler, connection):
        # Use CONCAT_WS with an empty separator so that NULLs are ignored.
        return super(ConcatPair, self).as_sql(
            compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)"
        )
class Concat(Func):
    """
    Concatenate text fields together.

    On backends where a single null argument makes the whole expression
    null, each argument is wrapped in a coalesce function so the result is
    always non-null.
    """
    function = None
    template = "%(expressions)s"

    def __init__(self, *expressions, **extra):
        """Raise ValueError unless at least two expressions are given."""
        if len(expressions) < 2:
            raise ValueError('Concat must take at least two expressions')
        super(Concat, self).__init__(self._paired(expressions), **extra)

    def _paired(self, expressions):
        """
        Nest the expressions into right-leaning ConcatPair instances:

            [a, b, c, d] -> ConcatPair(a, ConcatPair(b, ConcatPair(c, d)))
        """
        head, rest = expressions[0], expressions[1:]
        if len(rest) == 1:
            # Exactly two expressions left: this is the innermost pair.
            return ConcatPair(head, rest[0])
        return ConcatPair(head, self._paired(rest))
class Greatest(Func):
    """
    Return the maximum of the given expressions.

    The result when an expression is null is database-specific:
    - Postgres returns the maximum not-null expression.
    - MySQL, Oracle, and SQLite return null if any expression is null.
    """
    function = 'GREATEST'

    def __init__(self, *expressions, **extra):
        """Raise ValueError unless at least two expressions are given."""
        expression_count = len(expressions)
        if expression_count < 2:
            raise ValueError('Greatest must take at least two expressions')
        super(Greatest, self).__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection):
        """Use the MAX function on SQLite."""
        # Delegates to the parent's as_sqlite() with the function name swapped.
        parent = super(Greatest, self)
        return parent.as_sqlite(compiler, connection, function='MAX')
class Least(Func):
    """
    Return the minimum of the given expressions.

    The result when an expression is null is database-specific:
    - Postgres returns the minimum not-null expression.
    - MySQL, Oracle, and SQLite return null if any expression is null.
    """
    function = 'LEAST'

    def __init__(self, *expressions, **extra):
        """Raise ValueError unless at least two expressions are given."""
        expression_count = len(expressions)
        if expression_count < 2:
            raise ValueError('Least must take at least two expressions')
        super(Least, self).__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection):
        """Use the MIN function on SQLite."""
        # Delegates to the parent's as_sqlite() with the function name swapped.
        parent = super(Least, self)
        return parent.as_sqlite(compiler, connection, function='MIN')
class Length(Transform):
    """Returns the number of characters in the expression"""
    function = 'LENGTH'
    lookup_name = 'length'

    def __init__(self, expression, **extra):
        """
        expression: the expression whose length is computed
        extra: may carry an ``output_field``; an IntegerField is used when
            the key is absent
        """
        # The default field is built before the lookup, so it is created
        # even when the caller supplies an output_field.
        extra.setdefault('output_field', fields.IntegerField())
        super(Length, self).__init__(expression, **extra)

    def as_mysql(self, compiler, connection):
        # Compile with CHAR_LENGTH in place of LENGTH on MySQL.
        parent = super(Length, self)
        return parent.as_sql(compiler, connection, function='CHAR_LENGTH')
class Lower(Transform):
    """Transform that applies the SQL LOWER function to its expression."""
    lookup_name = 'lower'
    function = 'LOWER'
class Now(Func):
    """
    Expression that renders as CURRENT_TIMESTAMP (STATEMENT_TIMESTAMP() on
    PostgreSQL).
    """
    template = 'CURRENT_TIMESTAMP'

    def __init__(self, output_field=None, **extra):
        """
        output_field: the result field; a DateTimeField is created when None
        """
        extra['output_field'] = (
            fields.DateTimeField() if output_field is None else output_field
        )
        super(Now, self).__init__(**extra)

    def as_postgresql(self, compiler, connection):
        # Postgres' CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". We use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        statement_template = 'STATEMENT_TIMESTAMP()'
        return self.as_sql(compiler, connection, template=statement_template)
class Substr(Func):
    """Expression that extracts a substring using SUBSTRING (or SUBSTR)."""
    function = 'SUBSTRING'

    def __init__(self, expression, pos, length=None, **extra):
        """
        expression: the name of a field, or an expression returning a string
        pos: an integer > 0, or an expression returning an integer
        length: an optional number of characters to return

        Raise ValueError when a plain (non-expression) ``pos`` is below 1.
        """
        # Anything exposing resolve_expression is passed through untouched;
        # plain values are validated (pos only) and wrapped in Value().
        if hasattr(pos, 'resolve_expression'):
            start = pos
        else:
            if pos < 1:
                raise ValueError("'pos' must be greater than 0")
            start = Value(pos)

        arguments = [expression, start]
        if length is not None:
            arguments.append(
                length if hasattr(length, 'resolve_expression') else Value(length)
            )
        super(Substr, self).__init__(*arguments, **extra)

    def as_sqlite(self, compiler, connection):
        # SQLite is compiled with SUBSTR instead of SUBSTRING.
        parent = super(Substr, self)
        return parent.as_sql(compiler, connection, function='SUBSTR')

    def as_oracle(self, compiler, connection):
        # Oracle is compiled with SUBSTR instead of SUBSTRING.
        parent = super(Substr, self)
        return parent.as_sql(compiler, connection, function='SUBSTR')
class Upper(Transform):
    """Transform that applies the SQL UPPER function to its expression."""
    lookup_name = 'upper'
    function = 'UPPER'
|