1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310
|
# -*- coding: utf-8 -*-
# This software is distributed under the two-clause BSD license.
# Copyright (c) The django-ldapdb project
import collections
import re
import ldap
from django.db.models import aggregates
from django.db.models.sql import compiler
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE
from django.db.models.sql.where import AND, OR, WhereNode
from ldapdb import escape_ldap_filter
from ldapdb.models.fields import ListField
# Heuristic used by SQLCompiler.execute_sql to recover the LIMIT / OFFSET of a
# COUNT query from its generated SQL text (or the aggregate's subquery SQL).
# Named groups: ``limit`` (may be negative), ``offset``, ``order`` (ASC only).
# NOTE(review): the top-level ``|`` splits the first group into either
# "ORDER BY <col> ASC" or a bare "DESC", so a DESC ordering is matched without
# its ORDER BY/column and ``order`` never captures DESC. LIMIT/OFFSET are only
# recognised right after one of those two alternatives (plus 1-2 whitespace
# chars); SQL containing neither does not match at all.
_ORDER_BY_LIMIT_OFFSET_RE = re.compile(
    r"(?:\bORDER BY\b\s+([\w\.]+)\s(?P<order>\bASC\b)|(\bDESC\b))\s{1,2}(?:\bLIMIT\b\s+(?P<limit>-?\d+))?[\)\s]?(?:\bOFFSET\b\s+(?P<offset>(\d+)))?" # noqa: E501
)
class LdapDBError(Exception):
    """Root exception type for errors raised by the LDAP database backend."""
# Outcome of translating a Django query into an LDAP search: the search base
# DN, the search scope, and the LDAP filter string.
LdapLookup = collections.namedtuple('LdapLookup', ('base', 'scope', 'filterstr'))
def query_as_ldap(query, compiler, connection):
    """Convert a django.db.models.sql.query.Query to a LdapLookup.

    Args:
        query: the Django query to translate.
        compiler: the compiler whose ``compile()`` turns ``query.where`` into
            an LDAP filter fragment plus a list of unescaped parameters.
        connection: the database connection (not used by this function;
            kept for signature compatibility).

    Returns:
        An ``LdapLookup(base, scope, filterstr)``, or None when the query is
        empty or targets a ``migration`` model that has no
        ``object_classes``.

    Raises:
        LdapDBError: for a ``dn`` lookup other than ``exact``.
    """
    if query.is_empty():
        return None

    if query.model._meta.model_name == 'migration' and not hasattr(query.model, 'object_classes'):
        # FIXME(rbarrois): Support migrations
        return None

    # FIXME(rbarrois): this could be an extra Where clause
    filterstr = ''.join(['(objectClass=%s)' % cls for cls in
                         query.model.object_classes])

    # FIXME(rbarrois): Remove this code as part of #101
    # A lone lookup on the ``dn`` column becomes a base-scoped search on that
    # DN instead of a filter under the model's base_dn.
    children = query.where.children
    if len(children) == 1 and not isinstance(children[0], WhereNode):
        lookup = children[0]
        # Not every where-child is a plain column lookup (transforms or raw
        # ``extra()`` clauses have no ``lhs.target``); only inspect the
        # column when it exists rather than failing with AttributeError.
        target = getattr(getattr(lookup, 'lhs', None), 'target', None)
        if getattr(target, 'column', None) == 'dn':
            if lookup.lookup_name != 'exact':
                raise LdapDBError("Unsupported dn lookup: %s" % lookup.lookup_name)
            return LdapLookup(
                base=lookup.rhs,
                scope=ldap.SCOPE_BASE,
                filterstr='(&%s)' % filterstr,
            )

    sql, params = compiler.compile(query.where)
    if sql:
        # compile() returns unescaped parameters; escape each one before
        # interpolating it into the filter.
        filterstr += '(%s)' % (sql % tuple(escape_ldap_filter(param) for param in params))

    return LdapLookup(
        base=query.model.base_dn,
        scope=query.model.search_scope,
        filterstr='(&%s)' % filterstr,
    )
def where_node_as_ldap(where, compiler, connection):
    """Parse a django.db.models.sql.where.WhereNode into an LDAP filter.

    Nested ``WhereNode`` children are compiled through ``compiler.compile()``;
    other children through their own ``as_sql(compiler, connection)``.
    Children that compile to an empty clause are dropped, as Django's own
    ``WhereNode.as_sql`` does. Keeping them would emit an empty ``()``
    component, which is not a valid LDAP filter.

    Returns:
        (clause, [params]): the filter clause (without its outer
        parentheses), with a list of unescaped parameters. ``('', [])`` when
        no child produced a clause.

    Raises:
        LdapDBError: if the node combines several clauses with a connector
            other than AND / OR.
    """
    bits, params = [], []
    for item in where.children:
        if isinstance(item, WhereNode):
            clause, clause_params = compiler.compile(item)
        else:
            clause, clause_params = item.as_sql(compiler, connection)
        if not clause:
            # Nothing to filter on for this child (e.g. an empty nested node).
            continue
        bits.append(clause)
        params.extend(clause_params)

    if not bits:
        return '', []

    # FIXME(rbarrois): shouldn't we flatten recursive AND / OR?
    if len(bits) == 1:
        clause = bits[0]
    elif where.connector == AND:
        clause = '&' + ''.join('(%s)' % bit for bit in bits)
    elif where.connector == OR:
        clause = '|' + ''.join('(%s)' % bit for bit in bits)
    else:
        raise LdapDBError("Unhandled WHERE connector: %s" % where.connector)

    if where.negated:
        clause = ('!(%s)' % clause)

    return clause, params
class SQLCompiler(compiler.SQLCompiler):
    """LDAP-based SQL compiler.

    Queries are translated into LDAP searches with ``query_as_ldap`` and run
    through ``self.connection.search_s``. Ordering, slicing and DISTINCT are
    then applied client-side.
    """

    def compile(self, node, *args, **kwargs):
        """Parse a WhereNode to a LDAP filter string.

        Any other node is delegated to Django's SQL compiler.
        """
        if isinstance(node, WhereNode):
            return where_node_as_ldap(node, self, self.connection)
        return super().compile(node, *args, **kwargs)

    def execute_sql(self, result_type=compiler.SINGLE, chunked_fetch=False,
                    chunk_size=GET_ITERATOR_CHUNK_SIZE):
        """Run the query as an LDAP search and return a single result row.

        ``Count`` expressions in the select list are replaced by the number
        of matching entries (adjusted by ``_ldap_sliced_count``). Every other
        select expression is returned as-is.

        Returns:
            A list with one value per select expression, or None when the
            query cannot be translated or the search returned no entries.

        Raises:
            LdapDBError: if *result_type* is not ``compiler.SINGLE``.
        """
        if result_type != compiler.SINGLE:
            # LdapDBError subclasses Exception, so existing handlers still work.
            raise LdapDBError("LDAP does not support MULTI queries")

        # Setup self.select, self.klass_info, self.annotation_col_map
        # All expected from ModelIterable.__iter__
        self.pre_sql_setup()

        lookup = query_as_ldap(self.query, compiler=self, connection=self.connection)
        if lookup is None:
            return None

        try:
            vals = self.connection.search_s(
                base=lookup.base,
                scope=lookup.scope,
                filterstr=lookup.filterstr,
                attrlist=['dn'],
            )
            # Flatten iterator
            vals = list(vals)
        except ldap.NO_SUCH_OBJECT:
            vals = []

        if not vals:
            return None

        output = []
        self.setup_query()
        for e in self.select:
            if isinstance(e[0], aggregates.Count):
                output.append(self._ldap_sliced_count(len(vals)))
            else:
                output.append(e[0])
        return output

    def _ldap_sliced_count(self, total):
        """Return how many of *total* entries survive the query's LIMIT/OFFSET.

        The bounds are read from the generated SQL, or from
        ``self.query.subquery`` when set, via ``_ORDER_BY_LIMIT_OFFSET_RE``.
        The OFFSET is subtracted first (never going below zero), then a
        non-negative LIMIT caps the result, so the count never exceeds the
        number of entries actually found. A negative LIMIT is ignored
        (presumably the backend's "no limit" marker; confirm).
        """
        sql = self.as_sql()[0]
        subquery = getattr(self.query, 'subquery', None)
        if subquery:
            sql = subquery
        match = _ORDER_BY_LIMIT_OFFSET_RE.search(sql)
        if not match:
            return total

        count = total
        offset = match.group('offset')
        if offset:
            count = max(count - int(offset), 0)
        limit = match.group('limit')
        if limit and int(limit) >= 0:
            count = min(count, int(limit))
        return count

    def results_iter(self, results=None, tuple_expected=False, chunked_fetch=False,
                     chunk_size=GET_ITERATOR_CHUNK_SIZE):
        """Yield one row (a list of values) per matching LDAP entry.

        Ordering, the query's low/high marks and DISTINCT are applied here
        because they cannot be performed server-side. *results*,
        *tuple_expected*, *chunked_fetch* and *chunk_size* are accepted for
        signature compatibility and are not used.
        """
        lookup = query_as_ldap(self.query, compiler=self, connection=self.connection)
        if lookup is None:
            return

        if self.query.select:
            fields = [x.field for x in self.query.select]
        else:
            fields = self.query.model._meta.fields
        attrlist = [x.db_column for x in fields if x.db_column]

        try:
            vals = self.connection.search_s(
                base=lookup.base,
                scope=lookup.scope,
                filterstr=lookup.filterstr,
                attrlist=attrlist,
            )
        except ldap.NO_SUCH_OBJECT:
            return

        vals = self._ldap_sort_entries(vals)

        # FIXME : This is not optimal, we retrieve more results than we
        # need but there is probably no other options as we can't perform
        # ordering server side.
        low_mark = self.query.low_mark
        high_mark = self.query.high_mark
        # The select list does not depend on the entry: compute it once.
        self.setup_query()
        pos = 0
        seen_rows = []
        for dn, attrs in vals:
            if high_mark is not None and pos >= high_mark:
                # pos never decreases, so every remaining entry would be
                # skipped as well.
                return
            if low_mark and pos < low_mark:
                pos += 1
                continue
            row = self._ldap_row(dn, attrs)
            if self.query.distinct:
                # Rows may hold lists (ListField), so they are not hashable.
                if row in seen_rows:
                    continue
                seen_rows.append(row)
            yield row
            pos += 1

    def _ldap_sort_entries(self, vals):
        """Return the ``(dn, attrs)`` pairs in *vals* sorted per the query's ordering.

        Uses ``extra_order_by`` when set, otherwise ``order_by`` (falling
        back to the model's ``Meta.ordering`` when default ordering is
        enabled). A leading ``-`` sorts descending and ``pk`` resolves to the
        primary key's name. Fields are applied last-to-first with stable
        sorts, so the first ordering field takes precedence.
        """
        if self.query.extra_order_by:
            ordering = self.query.extra_order_by
        elif not self.query.default_ordering:
            ordering = self.query.order_by
        else:
            ordering = self.query.order_by or self.query.model._meta.ordering

        for fieldname in reversed(ordering):
            if fieldname.startswith('-'):
                sort_field = fieldname[1:]
                reverse = True
            else:
                sort_field = fieldname
                reverse = False
            if sort_field == 'pk':
                sort_field = self.query.model._meta.pk.name
            field = self.query.model._meta.get_field(sort_field)
            if sort_field == 'dn':
                vals = sorted(vals, key=lambda pair: pair[0], reverse=reverse)
            else:
                vals = sorted(vals, key=self._ldap_sort_key(field), reverse=reverse)
        return vals

    def _ldap_sort_key(self, field):
        """Build a sort key that decodes *field* from an entry's attributes.

        Values exposing ``lower()`` are lower-cased, so string ordering is
        case-insensitive.
        """
        def get_key(entry):
            attr = field.from_ldap(
                entry[1].get(field.db_column, []),
                connection=self.connection,
            )
            if hasattr(attr, 'lower'):
                attr = attr.lower()
            return attr
        return get_key

    def _ldap_row(self, dn, attrs):
        """Convert one LDAP entry into a row matching ``self.select``.

        ``Count`` expressions yield 1 when the counted attribute is present
        (the number of values for a ``ListField``) and 0 otherwise. ``dn``
        fields yield the entry's DN. Other fields are decoded with their
        ``from_ldap``; anything else yields None.
        """
        row = []
        for e in self.select:
            expr = e[0]
            if isinstance(expr, aggregates.Count):
                value = 0
                input_field = expr.get_source_expressions()[0].field
                if input_field.get_attname() == 'dn':
                    value = 1
                elif hasattr(input_field, 'from_ldap'):
                    result = input_field.from_ldap(
                        attrs.get(input_field.db_column, []),
                        connection=self.connection)
                    if result:
                        value = 1
                        if isinstance(input_field, ListField):
                            value = len(result)
                row.append(value)
            elif expr.field.get_attname() == 'dn':
                row.append(dn)
            elif hasattr(expr.field, 'from_ldap'):
                row.append(expr.field.from_ldap(
                    attrs.get(expr.field.db_column, []),
                    connection=self.connection))
            else:
                row.append(None)
        return row

    def has_results(self):
        """Return True if ``results_iter`` yields at least one row."""
        # results_iter is a generator function, so just try to pull one row.
        for _row in self.results_iter():
            return True
        return False
class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
    """Django's insert compiler combined with the LDAP ``SQLCompiler``; adds no behavior of its own."""
class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
    """Delete compiler that removes every LDAP entry matched by the query."""

    def execute_sql(self, result_type=compiler.MULTI):
        """Search for the DNs matched by the query and delete them one at a time.

        Always returns None. Nothing is deleted when the query cannot be
        translated or the search base does not exist.
        """
        ldap_lookup = query_as_ldap(self.query, compiler=self, connection=self.connection)
        if not ldap_lookup:
            return

        search_kwargs = dict(
            base=ldap_lookup.base,
            scope=ldap_lookup.scope,
            filterstr=ldap_lookup.filterstr,
            attrlist=['dn'],
        )
        try:
            entries = self.connection.search_s(**search_kwargs)
        except ldap.NO_SUCH_OBJECT:
            return

        # FIXME : there is probably a more efficient way to do this
        for entry_dn, _entry_attrs in entries:
            self.connection.delete_s(entry_dn)
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
    """Django's update compiler combined with the LDAP ``SQLCompiler``; adds no behavior of its own."""
class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
    """Aggregate compiler that only reports integer aggregate values."""

    def execute_sql(self, result_type=compiler.SINGLE):
        """Run the aggregate query and keep only the ``int`` values of its row.

        Returns:
            An iterator over the ``int`` values of the underlying result
            row, or None when the underlying ``execute_sql`` returned None
            (query not translatable, or no matching entries).
        """
        # Return only number values through the aggregate compiler
        output = super().execute_sql(result_type)
        if output is None:
            # filter() cannot iterate None; report "no result" instead of
            # raising TypeError.
            return None
        return filter(lambda a: isinstance(a, int), output)
|