# SQL expression constructs (EXPLAIN, array and JSON helpers) with SQLAlchemy compiler hooks.
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import (
_literal_as_text,
ClauseElement,
ColumnElement,
Executable,
FunctionElement
)
from sqlalchemy.sql.functions import GenericFunction
from sqlalchemy_utils.functions.orm import quote
class explain(Executable, ClauseElement):
    """
    EXPLAIN construct wrapping another statement.

    http://www.postgresql.org/docs/devel/static/sql-explain.html

    The constructor only records its arguments; the PostgreSQL rendering is
    done by the ``pg_explain`` compiler hook.

    :param stmt: statement to explain, coerced with ``_literal_as_text``.
    :param analyze: rendered as ``ANALYZE true`` when truthy.
    :param verbose: rendered as ``VERBOSE true`` when truthy.
    :param costs: rendered as ``COSTS false`` when falsy.
    :param buffers: rendered as ``BUFFERS true`` when truthy.
    :param timing: rendered as ``TIMING false`` when falsy.
    :param format: rendered as ``FORMAT <value>`` unless equal to ``'text'``.
    """

    def __init__(
        self,
        stmt,
        analyze=False,
        verbose=False,
        costs=True,
        buffers=False,
        timing=True,
        format='text'
    ):
        # The wrapped statement is coerced first, as in the original order.
        self.statement = _literal_as_text(stmt)

        # Flags that add an option when switched on.
        self.analyze, self.verbose, self.buffers = analyze, verbose, buffers

        # Flags that add an option when switched off.
        self.costs, self.timing = costs, timing

        # Output format name, kept verbatim.
        self.format = format
class explain_analyze(explain):
    """
    ``explain`` variant that always sets ``analyze=True``.

    :param stmt: statement to explain.
    :param kwargs: remaining ``explain`` options; supplying ``analyze`` here
        collides with the fixed keyword and raises ``TypeError``.
    """

    def __init__(self, stmt, **kwargs):
        parent = super(explain_analyze, self)
        parent.__init__(stmt, analyze=True, **kwargs)
@compiles(explain, 'postgresql')
def pg_explain(element, compiler, **kw):
    """
    Render an ``explain`` element as a PostgreSQL ``EXPLAIN`` statement.

    Options are emitted in a parenthesised, comma-separated list, and only
    when they differ from the ``explain`` constructor defaults. With no
    options the output is simply ``EXPLAIN <statement>``.

    :param element: the ``explain`` instance being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compile-time keyword arguments, forwarded to the wrapped
        statement.
    :returns: the SQL string.
    """
    options = []
    if element.analyze:
        options.append('ANALYZE true')
    if not element.timing:
        options.append('TIMING false')
    if element.buffers:
        options.append('BUFFERS true')
    if element.format != 'text':
        # NOTE(review): the format value is interpolated verbatim; callers
        # must not pass untrusted input here.
        options.append('FORMAT %s' % element.format)
    if element.verbose:
        options.append('VERBOSE true')
    if not element.costs:
        options.append('COSTS false')

    text = "EXPLAIN "
    if options:
        text += '(%s) ' % ', '.join(options)

    # Forward **kw so compile flags given to the outer statement also apply
    # to the wrapped statement instead of being silently dropped.
    return text + compiler.process(element.statement, **kw)
class array_get(FunctionElement):
    """
    Function element for picking one item out of an array expression.

    Rendering is handled by the ``compile_array_get`` compiler hook.
    """

    name = 'array_get'
@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """
    Render ``array_get(expr, index)`` as the subscript ``(expr)[index + 1]``.

    The literal index is incremented by one before rendering (presumably to
    map a zero-based index onto one-based SQL array subscripts).

    :param element: the ``array_get`` instance being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compile-time keyword arguments, forwarded to the array
        expression.
    :returns: the SQL string.
    :raises Exception: if the element does not have exactly two arguments,
        or the second argument does not carry an integer ``value``.
    """
    args = list(element.clauses)
    if len(args) != 2:
        raise Exception(
            "Function 'array_get' expects two arguments (%d given)." %
            len(args)
        )

    # Only a literal whose ``value`` is an int is accepted as the index.
    index = getattr(args[1], 'value', None)
    if not isinstance(index, int):
        raise Exception(
            "Second argument should be an integer."
        )

    # The index is inlined directly; no need to wrap it in a text() clause
    # only to stringify it again. **kw is forwarded so outer compile flags
    # apply to the array expression as well.
    return '(%s)[%s]' % (
        compiler.process(args[0], **kw),
        index + 1
    )
class row_to_json(GenericFunction):
    """``row_to_json`` SQL function whose result type is ``postgresql.JSON``."""

    type = postgresql.JSON
    name = 'row_to_json'
@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    """
    Render ``row_to_json`` as ``<name>(<arguments>)`` for PostgreSQL.

    :param element: the ``row_to_json`` instance being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compile-time keyword arguments, forwarded to the arguments.
    :returns: the SQL string.
    """
    # Forward **kw so outer compile flags also apply to the arguments.
    arguments = compiler.process(element.clauses, **kw)
    return "%s(%s)" % (element.name, arguments)
class json_array_length(GenericFunction):
    """``json_array_length`` SQL function whose result type is ``sa.Integer``."""

    type = sa.Integer
    name = 'json_array_length'
@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    """
    Render ``json_array_length`` as ``<name>(<arguments>)`` for PostgreSQL.

    :param element: the ``json_array_length`` instance being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compile-time keyword arguments, forwarded to the arguments.
    :returns: the SQL string.
    """
    # Forward **kw so outer compile flags also apply to the arguments.
    arguments = compiler.process(element.clauses, **kw)
    return "%s(%s)" % (element.name, arguments)
class array_agg(GenericFunction):
    """
    ``array_agg`` function typed as a PostgreSQL ARRAY of the argument's type.

    :param arg: expression to aggregate; its ``type`` attribute becomes the
        item type of the resulting ``postgresql.ARRAY``.
    :param default: optional fallback stored as ``self.default``; when it is
        not ``None`` the ``compile_array_agg`` hook wraps the call in
        ``coalesce``.
    :param kw: forwarded unchanged to ``GenericFunction.__init__``.
    """

    name = 'array_agg'
    type = postgresql.ARRAY

    def __init__(self, arg, default=None, **kw):
        self.default = default
        # The instance-level type is set before delegating to the base
        # initialiser.
        self.type = postgresql.ARRAY(arg.type)
        GenericFunction.__init__(self, arg, **kw)
@compiles(array_agg, 'postgresql')
def compile_array_agg(element, compiler, **kw):
    """
    Render ``array_agg`` for PostgreSQL, optionally with a fallback value.

    Without a default the output is ``<name>(<arguments>)``. With a default
    it is ``coalesce(<name>(<arguments>), CAST(ARRAY[...] AS <type>))``.

    :param element: the ``array_agg`` instance being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compile-time keyword arguments, forwarded to every
        sub-expression.
    :returns: the SQL string.
    """
    compiled = "%s(%s)" % (
        element.name,
        compiler.process(element.clauses, **kw)
    )
    if element.default is None:
        return compiled

    # Render the fallback with the SAME compiler. The previous approach
    # compiled a separate statement and took its string, so bind parameters
    # created for a non-empty default were never registered with the outer
    # statement. It also pushed already-rendered SQL back through text().
    fallback = compiler.process(
        sa.cast(postgresql.array(element.default), element.type),
        **kw
    )
    return 'coalesce(%s, %s)' % (compiled, fallback)
class Asterisk(ColumnElement):
    """
    Column element standing for every column of a selectable.

    :param selectable: object whose ``name`` is rendered by the
        ``compile_asterisk`` hook as ``<name>.*``.
    """

    def __init__(self, selectable):
        self.selectable = selectable
@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    """
    Render an ``Asterisk`` as ``<quoted selectable name>.*``.

    :param element: the ``Asterisk`` instance being compiled.
    :param compiler: the active SQL compiler; its dialect is passed to
        ``quote``.
    :returns: the SQL string.
    """
    selectable_name = element.selectable.name
    quoted_name = quote(compiler.dialect, selectable_name)
    return '%s.*' % quoted_name
|