1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
|
from sqlalchemy import sql, util, schema, topological
"""Utility functions that build upon SQL and Schema constructs."""
class TableCollection(object):
    """A list-like collection of tables that can be ordered by foreign keys.

    Tables are kept in insertion order in ``self.tables``.  ``sort()`` returns
    them in the order produced by ``topological.QueueDependencySorter`` from
    (referenced table, referencing table) pairs derived from the tables'
    foreign keys.  Only pairs whose referenced table is in this collection are
    used, and foreign keys flagged ``use_alter`` are ignored.
    """

    def __init__(self, tables=None):
        # NOTE: a non-empty list argument is used as-is (not copied), so add()
        # also appends to the caller's list; None or an empty list gets a
        # fresh list.
        self.tables = tables or []

    def __len__(self):
        return len(self.tables)

    def __getitem__(self, i):
        return self.tables[i]

    def __iter__(self):
        return iter(self.tables)

    def __contains__(self, obj):
        return obj in self.tables

    def __add__(self, obj):
        """Return a plain list: this collection's tables followed by ``obj``'s items."""
        return self.tables + list(obj)

    def add(self, table):
        """Append ``table`` and discard any cached sort order."""
        self.tables.append(table)
        # Only add() invalidates the cache; appending to self.tables directly
        # leaves a stale self._sorted behind.
        if hasattr(self, '_sorted'):
            del self._sorted

    def sort(self, reverse=False):
        """Return the tables in dependency order.

        The order is computed on first use and cached in ``self._sorted``
        until ``add()`` is called.  With ``reverse=True`` a reversed copy is
        returned; otherwise the cached list itself is returned, so callers
        should not mutate it.
        """
        try:
            ordered = self._sorted
        except AttributeError:
            ordered = self._sorted = self._do_sort()
        if reverse:
            return ordered[::-1]
        return ordered

    def _do_sort(self):
        """Compute the dependency order of ``self.tables`` (see ``sort``)."""
        tuples = []

        class TVisitor(schema.SchemaVisitor):
            def visit_foreign_key(_self, fkey):
                # use_alter foreign keys do not contribute an ordering
                # constraint here.
                if fkey.use_alter:
                    return
                parent_table = fkey.column.table
                # Dependencies on tables outside this collection are ignored.
                if parent_table in self:
                    child_table = fkey.parent.table
                    tuples.append((parent_table, child_table))

        vis = TVisitor()
        for table in self.tables:
            vis.traverse(table)
        sorter = topological.QueueDependencySorter(tuples, self.tables)
        return self._flatten_tree(sorter.sort())

    @staticmethod
    def _flatten_tree(head):
        """Return the ``item`` of every node in the tree rooted at ``head``, in pre-order.

        Each node is expected to expose ``item`` and an ordered ``children``
        sequence.  Returns ``[]`` when ``head`` is None.  Implemented with an
        explicit stack so that deep trees (e.g. long foreign-key chains)
        cannot exceed the interpreter's recursion limit.
        """
        sequence = []
        stack = [] if head is None else [head]
        while stack:
            node = stack.pop()
            sequence.append(node.item)
            # Push children right-to-left so the first child is popped next,
            # reproducing a recursive left-to-right pre-order walk.
            stack.extend(reversed(list(node.children)))
        return sequence
class TableFinder(TableCollection, sql.NoColumnVisitor):
    """Locate all Tables within a clause.

    Every table visited is appended to ``self.tables``.  When
    ``include_aliases`` is true, visited aliases are appended as well.  When
    ``check_columns`` is true, each visited column's ``.table`` is traversed
    too.  The walk itself is driven by the inherited ``traverse`` method.
    """

    def __init__(self, table, check_columns=False, include_aliases=False):
        TableCollection.__init__(self)
        self.check_columns = check_columns
        self.include_aliases = include_aliases
        if table is None:
            return
        self.traverse(table)

    def visit_alias(self, alias):
        if not self.include_aliases:
            return
        self.tables.append(alias)

    def visit_table(self, table):
        self.tables.append(table)

    def visit_column(self, column):
        if not self.check_columns:
            return
        self.traverse(column.table)
class ColumnFinder(sql.ClauseVisitor):
    """Collect every column visited during a clause traversal.

    Columns accumulate in ``self.columns`` (a ``util.Set``).  Iterating over
    the finder iterates over that set.
    """

    def __init__(self):
        self.columns = util.Set()

    def __iter__(self):
        return iter(self.columns)

    def visit_column(self, c):
        self.columns.add(c)
class ColumnsInClause(sql.ClauseVisitor):
    """Given a selectable, visit clauses and determine whether any columns
    from the clause are in the selectable.

    ``self.result`` starts False.  It becomes True once a visited column is
    the very object (identity, not equality) that the selectable's ``.c``
    collection holds under that column's key.  It is never reset to False.
    """

    def __init__(self, selectable):
        self.selectable = selectable
        self.result = False

    def visit_column(self, column):
        same_key_column = self.selectable.c.get(column.key)
        if same_key_column is column:
            self.result = True
class AbstractClauseProcessor(sql.NoColumnVisitor):
    """Traverse a clause and attempt to convert the contents of container
    elements to a converted element.

    Subclasses define the conversion by overriding ``convert_element``.  It
    returns a replacement element, or ``None`` to leave the original in place.
    """

    def convert_element(self, elem):
        """Define the *conversion* method for this ``AbstractClauseProcessor``.

        Return the element that should replace ``elem``, or ``None`` to keep
        ``elem`` unchanged.
        """
        raise NotImplementedError()

    def copy_and_process(self, list_):
        """Copy the container elements in the given list to a new list and
        process the new list.

        Each element is copied via its ``copy_container()`` method.  The list
        passed in is not modified; the processed copy is returned.
        """
        copies = [element.copy_container() for element in list_]
        self.process_list(copies)
        return copies

    def process_list(self, list_):
        """Process all elements of the given list in-place.

        An element for which ``convert_element`` returns a replacement is
        swapped out.  Any other element is handed to ``traverse`` (presumably
        so that the ``visit_*`` methods below process its nested parts).
        """
        for position, element in enumerate(list_):
            replacement = self.convert_element(element)
            if replacement is None:
                self.traverse(element)
            else:
                list_[position] = replacement

    def visit_grouping(self, grouping):
        replacement = self.convert_element(grouping.elem)
        if replacement is not None:
            grouping.elem = replacement

    def visit_clauselist(self, clist):
        for position, clause in enumerate(clist.clauses):
            replacement = self.convert_element(clause)
            if replacement is not None:
                clist.clauses[position] = replacement

    def visit_unary(self, unary):
        replacement = self.convert_element(unary.element)
        if replacement is not None:
            unary.element = replacement

    def visit_binary(self, binary):
        # The left side is converted (and assigned) before the right side.
        new_left = self.convert_element(binary.left)
        if new_left is not None:
            binary.left = new_left
        new_right = self.convert_element(binary.right)
        if new_right is not None:
            binary.right = new_right

    # TODO: visit_select().
class ClauseAdapter(AbstractClauseProcessor):
    """Given a clause (as in a WHERE criterion), locate columns which are
    embedded within a given selectable, and change those columns to be those
    of the selectable.

    E.g.::

        table1 = Table('sometable', metadata,
            Column('col1', Integer),
            Column('col2', Integer)
            )
        table2 = Table('someothertable', metadata,
            Column('col1', Integer),
            Column('col2', Integer)
            )

        condition = table1.c.col1 == table2.c.col1

    and make an alias of table1::

        s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

        s.c.col1 == table2.c.col1
    """

    def __init__(self, selectable, include=None, exclude=None, equivalents=None):
        # selectable:  the target whose columns replace matching columns.
        # include:     if not None, only columns found in it are adapted.
        # exclude:     if not None, columns found in it are never adapted.
        # equivalents: if not None, a mapping (presumably a dict) of column ->
        #              iterable of equivalent columns, tried in order when the
        #              column itself has no counterpart in selectable.
        self.selectable = selectable
        self.include = include
        self.exclude = exclude
        self.equivalents = equivalents

    def convert_element(self, col):
        """Return the selectable's column corresponding to ``col``, or None.

        Non-column elements and columns filtered out by ``include``/``exclude``
        yield None.  Otherwise ``col`` is looked up in the selectable first.
        Failing that, each of its equivalents is looked up in order, and the
        first match found is returned.
        """
        if not isinstance(col, sql.ColumnElement):
            return None
        if self.include is not None and col not in self.include:
            return None
        if self.exclude is not None and col in self.exclude:
            return None

        newcol = self.selectable.corresponding_column(
            col, raiseerr=False, require_embedded=True, keys_ok=False)
        if newcol is not None:
            return newcol

        if self.equivalents is not None and col in self.equivalents:
            for equiv in self.equivalents[col]:
                newcol = self.selectable.corresponding_column(
                    equiv, raiseerr=False, require_embedded=True, keys_ok=False)
                # Compare against None explicitly: truth-testing a SQL
                # expression object is unreliable (and raises TypeError in
                # newer SQLAlchemy versions).
                if newcol is not None:
                    return newcol
        return None
|