0001from . import dbconnection
0002from . import sqlbuilder
0003from .compat import string_type
0004
0005
# Public API of this module: ``from ... import *`` exports only SelectResults.
__all__ = ['SelectResults']
0007
0008
class SelectResults(object):
    """Result set of a ``SELECT`` over the rows of ``sourceClass``.

    The constructor only prepares the query description (WHERE clause,
    table list and options in ``self.ops``).  Rows are fetched through the
    connection when the object is iterated (``__iter__`` / ``lazyIter``),
    counted, or otherwise accumulated.

    Refining methods (``orderBy``, ``filter``, ``distinct``, slicing, ...)
    do not modify ``self``.  They return a new instance built by ``clone``
    or ``newClause``.
    """

    IterationClass = dbconnection.Iteration

    def __init__(self, sourceClass, clause, clauseTables=None,
                 **ops):
        """Build a result set for ``sourceClass`` restricted by ``clause``.

        :param sourceClass: the SQLObject class whose rows are selected.
        :param clause: ``None`` or the string ``'all'`` for "no filter"; an
            ``SQLExpression``; or any other value, which is wrapped in
            ``sqlbuilder.SQLConstant`` and used verbatim.
        :param clauseTables: extra table names added to the FROM list even
            if ``clause`` does not reference them.
        :param ops: query options read elsewhere in this class:
            ``orderBy``, ``connection``, ``limit``, ``start``, ``end``,
            ``distinct``, ``reversed``, ``join``, ``lazyColumns``,
            ``forUpdate``.
        """
        self.sourceClass = sourceClass
        if clause is None or (isinstance(clause, string_type)
                              and clause == 'all'):
            clause = sqlbuilder.SQLTrueClause
        if not isinstance(clause, sqlbuilder.SQLExpression):
            clause = sqlbuilder.SQLConstant(clause)
        self.clause = clause
        self.ops = ops

        # An explicitly passed ``orderBy`` (even None) wins.  Only a missing
        # one falls back to the class default.
        if ops.get('orderBy', sqlbuilder.NoDefault) is sqlbuilder.NoDefault:
            ops['orderBy'] = sourceClass.sqlmeta.defaultOrder
        orderBy = ops['orderBy']
        if isinstance(orderBy, (tuple, list)):
            orderBy = [self._mungeOrderBy(item) for item in orderBy]
        else:
            orderBy = self._mungeOrderBy(orderBy)
        # 'orderBy' keeps the user's spelling (re-munged on every clone);
        # 'dbOrderBy' is what queryForSelect hands to sqlbuilder.
        ops['dbOrderBy'] = orderBy

        # connection=None means "use the class connection" (_getConnection).
        if 'connection' in ops and ops['connection'] is None:
            del ops['connection']

        # 'limit' is sugar for start=0, end=limit.
        if ops.get('limit', None):
            assert not ops.get('start', None) and not ops.get('end', None), \
                "'limit' cannot be used with 'start' or 'end'"
            ops["start"] = 0
            ops["end"] = ops.pop("limit")

        tablesSet = sqlbuilder.tablesUsedSet(self.clause,
                                             self._getConnection().dbName)
        if clauseTables:
            for table in clauseTables:
                tablesSet.add(table)
        self.clauseTables = clauseTables

        self.tables = list(tablesSet) + [sourceClass.sqlmeta.table]

    def queryForSelect(self):
        """Return the ``sqlbuilder.Select`` object for this result set.

        The select list is the ``id`` column followed by every column in
        ``sourceClass.sqlmeta.columnList``.  The remaining arguments come
        from ``self.ops``, with the defaults shown below.
        """
        q = self.sourceClass.q
        columns = [q.id] + [getattr(q, column.name)
                            for column in self.sourceClass.sqlmeta.columnList]
        ops = self.ops
        return sqlbuilder.Select(
            columns,
            where=self.clause,
            join=ops.get('join', sqlbuilder.NoDefault),
            distinct=ops.get('distinct', False),
            lazyColumns=ops.get('lazyColumns', False),
            start=ops.get('start', 0),
            end=ops.get('end', None),
            orderBy=ops.get('dbOrderBy', sqlbuilder.NoDefault),
            reversed=ops.get('reversed', False),
            staticTables=self.tables,
            forUpdate=ops.get('forUpdate', False))

    def __repr__(self):
        """Return ``<ClassName at HEXID>``.

        No query is run, unlike ``__str__``.
        """
        return "<%s at %x>" % (self.__class__.__name__, id(self))

    def _getConnection(self):
        """Return the ``connection`` option if set and truthy.

        Otherwise return ``sourceClass._connection``.
        """
        return self.ops.get('connection') or self.sourceClass._connection

    def __str__(self):
        """Return the SQL text for this select, as rendered by the connection."""
        conn = self._getConnection()
        return conn.queryForSelect(self)

    def _mungeOrderBy(self, orderBy):
        """Translate a single ``orderBy`` item into an SQL expression.

        Strings are handled as follows:

        - A leading ``-`` requests descending order (wrapped in
          ``sqlbuilder.DESC``).
        - A key of ``sqlmeta.columns`` maps to that column's ``q``
          attribute.
        - Any other string is treated as raw SQL (``SQLConstant``).

        Non-string values are returned unchanged.
        """
        if not isinstance(orderBy, string_type):
            return orderBy
        desc = orderBy.startswith('-')
        if desc:
            orderBy = orderBy[1:]
        columns = self.sourceClass.sqlmeta.columns
        if orderBy in columns:
            expr = getattr(self.sourceClass.q, columns[orderBy].name)
        else:
            expr = sqlbuilder.SQLConstant(orderBy)
        if desc:
            return sqlbuilder.DESC(expr)
        return expr

    def clone(self, **newOps):
        """Return a new instance with ``self.ops`` updated by ``newOps``.

        The source class, clause and ``clauseTables`` are kept.
        """
        ops = self.ops.copy()
        ops.update(newOps)
        return self.__class__(self.sourceClass, self.clause,
                              self.clauseTables, **ops)

    def orderBy(self, orderBy):
        """Return a copy ordered by ``orderBy`` (see ``_mungeOrderBy``)."""
        return self.clone(orderBy=orderBy)

    def connection(self, conn):
        """Return a copy that runs its queries through ``conn``."""
        return self.clone(connection=conn)

    def limit(self, limit):
        """Return ``self[:limit]``.

        A falsy ``limit`` returns ``self`` unchanged.
        """
        return self[:limit]

    def lazyColumns(self, value):
        """Return a copy with the ``lazyColumns`` option set to ``value``."""
        return self.clone(lazyColumns=value)

    def reversed(self):
        """Return a copy with the ``reversed`` option toggled."""
        return self.clone(reversed=not self.ops.get('reversed', False))

    def distinct(self):
        """Return a copy with the ``distinct`` option enabled."""
        return self.clone(distinct=True)

    def newClause(self, new_clause):
        """Return a copy with the same options but ``new_clause`` as WHERE."""
        return self.__class__(self.sourceClass, new_clause,
                              self.clauseTables, **self.ops)

    def filter(self, filter_clause):
        """Return a copy whose clause is ``AND(current clause, filter_clause)``.

        ``None`` is a no-op and returns ``self``.
        """
        if filter_clause is None:
            return self
        clause = self.clause
        # NOTE(review): __init__ always wraps the clause in an SQL object, so
        # this branch is presumably never taken; kept as a safeguard.
        if isinstance(clause, string_type):
            clause = sqlbuilder.SQLConstant('(%s)' % clause)
        return self.newClause(sqlbuilder.AND(clause, filter_clause))

    def __getitem__(self, value):
        """Index or slice the result set.

        Slices with non-negative bounds return a new ``SelectResults``.
        The new ``start``/``end`` are relative to, and clamped by, the
        current window, and the window never becomes ``end < start``.

        Negative bounds cannot be expressed with ``start``/``end`` offsets.
        For those, all rows are fetched and plain list slicing/indexing is
        applied, so the result is a list (slice) or a single row (index).

        A non-negative integer index fetches only that row.  It raises
        ``IndexError`` if the row is outside the current window or past the
        last row.
        """
        if isinstance(value, slice):
            assert not value.step, "Slices do not support steps"
            if not value.start and not value.stop:
                # NOTE(review): also true for [:0] / [0:0], which therefore
                # return the whole set rather than nothing.
                return self

            if (value.start and value.start < 0) or \
                    (value.stop and value.stop < 0):
                # Python list slicing handles every combination of
                # None/negative/positive bounds, including [-3:0] == [].
                return list(self)[value.start:value.stop]

            baseStart = self.ops.get('start', 0)
            baseEnd = self.ops.get('end', None)
            if value.start:
                start = baseStart + value.start
                if value.stop is not None:
                    if value.stop < value.start:
                        # Empty slice.
                        end = start
                    else:
                        end = baseStart + value.stop
                    if baseEnd is not None and baseEnd < end:
                        # Never extend past the current window.
                        end = baseEnd
                else:
                    end = baseEnd
            else:
                start = baseStart
                end = baseStart + value.stop
                if baseEnd is not None and baseEnd < end:
                    end = baseEnd
            if end is not None and end < start:
                # Slicing beyond the current window: keep it empty instead
                # of handing the database an inverted start/end pair.
                end = start
            return self.clone(start=start, end=end)

        if value < 0:
            return list(self)[value]
        # Reuse the slice logic so an existing end bound is respected; an
        # empty window makes the [0] below raise IndexError.
        return list(self[value:value + 1])[0]

    def __iter__(self):
        """Iterate over all matching rows.

        All rows are fetched into a list up front.  ``lazyIter`` streams
        them instead.
        """
        return iter(list(self.lazyIter()))

    def lazyIter(self):
        """
        Returns an iterator that will lazily pull rows out of the
        database and return SQLObject instances
        """
        conn = self._getConnection()
        return conn.iterSelect(self)

    def accumulate(self, *expressions):
        """Run the aggregate ``expressions`` over this result set.

        The query goes through ``conn.accumulateSelect`` on the current
        connection.  Non-``SQLExpression`` arguments (e.g. SQL strings) are
        wrapped in ``SQLConstant`` first.  Returns whatever
        ``accumulateSelect`` returns.
        """
        conn = self._getConnection()
        exprs = [expr if isinstance(expr, sqlbuilder.SQLExpression)
                 else sqlbuilder.SQLConstant(expr)
                 for expr in expressions]
        return conn.accumulateSelect(self, *exprs)

    def count(self):
        """Return the number of matching rows.

        Uses ``COUNT(*)``, or ``COUNT(DISTINCT <id>)`` when ``distinct`` is
        set.  Sliced result sets are rejected by the assertions below.  The
        start/end adjustment after them only takes effect when assertions
        are disabled (``python -O``).
        """
        assert not self.ops.get('start') and not self.ops.get('end'), \
            "start/end/limit have no meaning with 'count'"
        assert not (self.ops.get('distinct') and
                    (self.ops.get('start') or self.ops.get('end'))), \
            "distinct-counting of sliced objects is not supported"
        if self.ops.get('distinct'):
            # COUNT(*) would count duplicate rows; count distinct ids.
            count = self.accumulate(
                'COUNT(DISTINCT %s)' % self._getConnection().sqlrepr(
                    self.sourceClass.q.id))
        else:
            count = self.accumulate('COUNT(*)')
        if self.ops.get('start'):
            count -= self.ops['start']
        if self.ops.get('end'):
            count = min(self.ops['end'] - self.ops.get('start', 0), count)
        return count

    def accumulateMany(self, *attributes):
        """ Making the expressions for count/sum/min/max/avg
        of a given select result attributes.

        Each positional argument is a pair ``(func_name, attribute)``.
        ``attribute`` can be a column name string (like 'a_column'), used
        verbatim, or a dot-q attribute (like Table.q.aColumn), rendered via
        the connection's ``sqlrepr``.  When the ``distinct`` option is set,
        each call becomes ``FUNC(DISTINCT attribute)``.
        """
        conn = self._getConnection()
        distinct = 'DISTINCT ' if self.ops.get('distinct') else ''
        expressions = []
        for func_name, attribute in attributes:
            # string_type (not str) so Python 2 unicode column names are
            # used verbatim instead of being quoted by sqlrepr.
            if not isinstance(attribute, string_type):
                attribute = conn.sqlrepr(attribute)
            expressions.append('%s(%s%s)' % (func_name, distinct, attribute))
        return self.accumulate(*expressions)

    def accumulateOne(self, func_name, attribute):
        """ Making the sum/min/max/avg of a given select result attribute.
        `attribute` can be a column name (like 'a_column')
        or a dot-q attribute (like Table.q.aColumn)
        """
        return self.accumulateMany((func_name, attribute))

    def sum(self, attribute):
        """Return ``SUM(attribute)`` over the result set."""
        return self.accumulateOne("SUM", attribute)

    def min(self, attribute):
        """Return ``MIN(attribute)`` over the result set."""
        return self.accumulateOne("MIN", attribute)

    def avg(self, attribute):
        """Return ``AVG(attribute)`` over the result set."""
        return self.accumulateOne("AVG", attribute)

    def max(self, attribute):
        """Return ``MAX(attribute)`` over the result set."""
        return self.accumulateOne("MAX", attribute)

    def getOne(self, default=sqlbuilder.NoDefault):
        """
        If a query is expected to only return a single value,
        using ``.getOne()`` will return just that value.

        If no results are found, ``SQLObjectNotFound`` will be
        raised, unless you pass in a default value (like
        ``.getOne(None)``).

        If more than one result is returned,
        ``SQLObjectIntegrityError`` will be raised.
        """
        # Function-level import; presumably avoids a circular import with
        # the main module -- confirm before moving it to the top.
        from . import main
        results = list(self)
        if not results:
            if default is sqlbuilder.NoDefault:
                raise main.SQLObjectNotFound(
                    "No results matched the query for %s"
                    % self.sourceClass.__name__)
            return default
        if len(results) > 1:
            raise main.SQLObjectIntegrityError(
                "More than one result returned from query: %s"
                % results)
        return results[0]

    @property
    def throughTo(self):
        """Return a helper for traversing relations of these rows.

        Attribute access on the helper calls ``_throughTo``.  For example,
        ``results.throughTo.someJoin`` selects the related rows of every
        row in this result set.
        """
        class _throughTo_getter(object):
            def __init__(self, inst):
                self.sresult = inst

            def __getattr__(self, attr):
                return self.sresult._throughTo(attr)
        return _throughTo_getter(self)

    def _throughTo(self, attr):
        """Select the rows of another class related through ``attr``.

        ``attr`` is first looked up as a foreign-key column (``attr`` itself
        if it ends in 'ID', else ``attr + 'ID'``).  Failing that, it is
        matched against the ``joinMethodName`` of ``sqlmeta.joins``.

        :raises AttributeError: if ``attr`` matches neither.
        """
        otherClass = None
        orderBy = sqlbuilder.NoDefault

        refName = attr if attr.endswith('ID') else attr + 'ID'
        ref = self.sourceClass.sqlmeta.columns.get(refName, None)
        if ref and ref.foreignKey:
            otherClass, clause = self._throughToFK(ref)
        else:
            joins = [j for j in self.sourceClass.sqlmeta.joins
                     if j.joinMethodName == attr]
            if joins:
                join = joins[0]
                orderBy = join.orderBy
                # Related joins go through an intermediate table, which is
                # described by an 'otherColumn'; multiple joins do not.
                if hasattr(join, 'otherColumn'):
                    otherClass, clause = self._throughToRelatedJoin(join)
                else:
                    otherClass, clause = self._throughToMultipleJoin(join)

        if not otherClass:
            raise AttributeError(
                "throughTo argument (got %s) should be "
                "name of foreignKey or SQL*Join in %s" % (attr,
                                                          self.sourceClass))

        return otherClass.select(clause,
                                 orderBy=orderBy,
                                 connection=self._getConnection())

    def _throughToFK(self, col):
        """Return ``(otherClass, clause)`` for foreign-key column ``col``.

        Builds a DISTINCT, unordered subquery of this set's FK values and
        matches ``otherClass.q.id`` against it.
        """
        otherClass = getattr(self.sourceClass, "_SO_class_" + col.foreignKey)
        colName = col.name
        query = self.queryForSelect().newItems([
            sqlbuilder.ColumnAS(getattr(self.sourceClass.q, colName), colName)
        ]).orderBy(None).distinct()
        query = sqlbuilder.Alias(query,
                                 "%s_%s" % (self.sourceClass.__name__,
                                            col.name))
        return otherClass, otherClass.q.id == getattr(query.q, colName)

    def _throughToMultipleJoin(self, join):
        """Return ``(otherClass, clause)`` for a join without ``otherColumn``.

        Matches ``join.otherClass``'s join column against a DISTINCT
        subquery of this set's ids.
        """
        otherClass = join.otherClass
        colName = join.soClass.sqlmeta.style.dbColumnToPythonAttr(
            join.joinColumn)
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]
        ).orderBy(None).distinct()
        query = sqlbuilder.Alias(query,
                                 "%s_%s" % (self.sourceClass.__name__,
                                            join.joinMethodName))
        joinColumn = getattr(otherClass.q, colName)
        return otherClass, joinColumn == query.q.id

    def _throughToRelatedJoin(self, join):
        """Return ``(otherClass, clause)`` for a join with ``otherColumn``.

        Links ``join.otherClass`` to a DISTINCT subquery of this set's ids
        through ``join.intermediateTable``.
        """
        otherClass = join.otherClass
        intTable = sqlbuilder.Table(join.intermediateTable)
        colName = join.joinColumn
        query = self.queryForSelect().newItems(
            [sqlbuilder.ColumnAS(self.sourceClass.q.id, 'id')]
        ).orderBy(None).distinct()
        query = sqlbuilder.Alias(query,
                                 "%s_%s" % (self.sourceClass.__name__,
                                            join.joinMethodName))
        clause = sqlbuilder.AND(
            otherClass.q.id == getattr(intTable, join.otherColumn),
            getattr(intTable, colName) == query.q.id)
        return otherClass, clause