1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
|
#
# Copyright (C) 2019 FreeIPA Contributors see COPYING for license
#
"""ctypes wrapper for libldap's ldap_str2dn(), plus a pure-Python dn2str()
"""
from __future__ import absolute_import
import ctypes
import ctypes.util
import six
__all__ = ("str2dn", "dn2str", "DECODING_ERROR", "LDAPError")

# Load the LDAP client library, preferring the reentrant libldap_r
# (libldap_r-*.so.2) and falling back to libldap (libldap.so.2).
# On Python 2, ``map``/``filter`` return lists, and a list is not an
# iterator; wrap the result in ``iter()`` so ``next()`` works on both
# Python 2 and Python 3.
ldap_lib_filename = next(
    iter(filter(None, map(ctypes.util.find_library, ["ldap_r-2", "ldap"]))),
    None,
)

if ldap_lib_filename is None:
    raise ImportError("libldap_r or libldap shared library missing")
try:
    lib = ctypes.CDLL(ldap_lib_filename)
except OSError as e:
    # Report load failures as ImportError, the same error type as the
    # "library missing" case above.
    raise ImportError(str(e))
# Constants mirrored from libldap.
# Bits that may be set in LDAPAVA.la_flags:
LDAP_AVA_FREE_ATTR = 0x10
LDAP_AVA_FREE_VALUE = 0x20
# Return code of ldap_str2dn() that errcheck maps to DECODING_ERROR.
LDAP_DECODING_ERROR = -4

# AND-ing la_flags with this clears both *_FREE_* bits and keeps every
# other flag bit (De Morgan form of ~(ATTR | VALUE)).
AVA_MASK = ~LDAP_AVA_FREE_ATTR & ~LDAP_AVA_FREE_VALUE
class berval(ctypes.Structure):
    """ctypes mirror of libldap's ``struct berval``: a length plus a data pointer.

    ``bv_value`` is declared as ``c_char_p``. Reading that attribute
    converts the pointed-to memory as a NUL-terminated C string, so
    ``bv_len`` is the authoritative length of the data.
    """

    __slots__ = ()
    _fields_ = [("bv_len", ctypes.c_ulong), ("bv_value", ctypes.c_char_p)]

    def __bytes__(self):
        """Return exactly ``bv_len`` bytes of the value.

        The raw pointer stored in the ``bv_value`` field is read directly,
        because ``self.bv_value`` would stop at the first NUL byte and
        silently zero-fill the remainder of values with embedded NULs.
        A NULL data pointer yields ``b""``.
        """
        # Address of the bv_value slot inside this struct, then the pointer
        # value stored in it (None when NULL).
        slot = ctypes.addressof(self) + type(self).bv_value.offset
        address = ctypes.c_void_p.from_address(slot).value
        if not address:
            return b""
        return ctypes.string_at(address, self.bv_len)

    def __str__(self):
        """Return the value decoded as UTF-8."""
        return self.__bytes__().decode("utf-8")

    if six.PY2:
        # On Python 2, unicode() yields text and str() yields the raw bytes.
        __unicode__ = __str__
        __str__ = __bytes__
class LDAPAVA(ctypes.Structure):
    """ctypes mirror of libldap's ``LDAPAVA`` (one attribute=value assertion).

    ``la_flags`` is declared as ``unsigned`` (i.e. ``unsigned int``) in
    <ldap.h>. Mapping it to a 16-bit type reads only half of the field.
    On big-endian platforms that half is the high one, which is zero for
    the small flag values used here, so every flag would be lost.
    ``la_private`` is listed so the structure size matches the C layout.
    """

    __slots__ = ()
    _fields_ = [
        ("la_attr", berval),
        ("la_value", berval),
        ("la_flags", ctypes.c_uint),
        ("la_private", ctypes.c_void_p),
    ]


# typedef LDAPAVA** LDAPRDN;
LDAPRDN = ctypes.POINTER(ctypes.POINTER(LDAPAVA))
# typedef LDAPRDN* LDAPDN;
LDAPDN = ctypes.POINTER(LDAPRDN)
def errcheck(result, func, arguments):
    """ctypes ``errcheck`` hook installed on ``ldap_str2dn``.

    A zero return code is passed through unchanged. ``LDAP_DECODING_ERROR``
    raises ``DECODING_ERROR``. Any other non-zero code raises ``LDAPError``
    carrying the UTF-8-decoded text from ``ldap_err2string``.
    ``func`` and ``arguments`` are part of the ctypes hook signature and
    are not used.
    """
    if result == 0:
        return result
    if result == LDAP_DECODING_ERROR:
        raise DECODING_ERROR
    message = ldap_err2string(result).decode("utf-8")
    raise LDAPError(message)
# Foreign prototypes. The ctypes types follow the C declarations exactly;
# the previous 16-bit types only matched the C ``int``/``unsigned`` by
# accident of ABI register handling.

# int ldap_str2dn(const char *str, LDAPDN *dn, unsigned flags);
ldap_str2dn = lib.ldap_str2dn
ldap_str2dn.argtypes = (
    ctypes.c_char_p,
    ctypes.POINTER(LDAPDN),
    ctypes.c_uint,
)
ldap_str2dn.restype = ctypes.c_int
# Convert non-zero return codes into Python exceptions.
ldap_str2dn.errcheck = errcheck

# void ldap_dnfree(LDAPDN dn);
ldap_dnfree = lib.ldap_dnfree
ldap_dnfree.argtypes = (LDAPDN,)
ldap_dnfree.restype = None

# char *ldap_err2string(int err);
ldap_err2string = lib.ldap_err2string
ldap_err2string.argtypes = (ctypes.c_int,)
ldap_err2string.restype = ctypes.c_char_p
class LDAPError(Exception):
    """Error reported by libldap; the message comes from ldap_err2string()."""


class DECODING_ERROR(LDAPError):
    """Raised when ldap_str2dn() returns LDAP_DECODING_ERROR."""
# RFC 4514, 2.4: characters that must be escaped anywhere in an attribute
# value. All of them except NUL are escaped by prefixing a backslash. NUL is
# not a valid "\<char>" pair and must be written as the hex pair "\00".
# The backslash itself must be escaped. The apostrophe is not special in
# RFC 4514, and "\'" is not a valid escape sequence.
_ESCAPE_CHARS = {'"', "+", ",", ";", "<", ">", "\\", "\x00"}


def _escape_dn(dn):
    """Escape an attribute value for use in a string DN (RFC 4514, 2.4).

    Escapes the special characters in ``_ESCAPE_CHARS`` (NUL as ``\\00``),
    a leading space or number sign, and a trailing space. An empty value
    returns ``""``.
    """
    if not dn:
        return ""
    result = []
    # a space or number sign occurring at the beginning of the string
    if dn[0] in {"#", " "}:
        result.append("\\")
    for c in dn:
        if c == "\x00":
            result.append("\\00")
            continue
        if c in _ESCAPE_CHARS:
            result.append("\\")
        result.append(c)
    # a space character occurring at the end of the string (a single-space
    # value was already escaped by the leading-character rule above)
    if len(dn) > 1 and result[-1] == " ":
        # insert before last entry
        result.insert(-1, "\\")
    return "".join(result)
def dn2str(dn):
    """Serialize a DN in the shape returned by ``str2dn()`` back to a string.

    ``dn`` is an iterable of RDNs. Each RDN is an iterable of
    ``(attr, value, flags)`` triples. Each AVA is rendered as
    ``attr=<escaped value>``. AVAs within an RDN are joined with ``+``
    and RDNs with ``,``. Flags are ignored.
    """
    rdn_strings = []
    for rdn in dn:
        ava_strings = [
            "=".join((attr, _escape_dn(value))) for attr, value, _flags in rdn
        ]
        rdn_strings.append("+".join(ava_strings))
    return ",".join(rdn_strings)
def _null_terminated(array):
    """Yield the entries of a ctypes pointer array up to the first NULL one."""
    for entry in array:
        if not entry:
            return
        yield entry


def _ava_to_tuple(ava):
    """Copy one LDAPAVA struct into an ``(attr, value, flags)`` tuple."""
    return (
        six.text_type(ava.la_attr),
        six.text_type(ava.la_value),
        ava.la_flags & AVA_MASK,
    )


def str2dn(dn, flags=0):
    """Parse a DN string with libldap's ``ldap_str2dn()``.

    ``dn`` may be text (encoded to UTF-8 first) or bytes. ``None`` returns
    ``[]``. ``flags`` is passed straight through to ``ldap_str2dn()``.
    The result is a list of RDNs. Each RDN is a list of
    ``(attr, value, flags)`` tuples whose flags have the
    ``LDAP_AVA_FREE_*`` bits masked off. Failures surface as exceptions
    raised by the ``errcheck`` hook attached to ``ldap_str2dn``.
    """
    if dn is None:
        return []
    if isinstance(dn, six.text_type):
        dn = dn.encode("utf-8")

    ldapdn = LDAPDN()
    try:
        ldap_str2dn(dn, ctypes.byref(ldapdn), flags)
        if not ldapdn:
            # empty DN, str2dn("") == []
            return []
        # Build the whole result eagerly: every value must be copied into
        # Python objects before ``finally`` releases the C structure.
        return [
            [_ava_to_tuple(ava_p[0]) for ava_p in _null_terminated(rdn)]
            for rdn in _null_terminated(ldapdn)
        ]
    finally:
        ldap_dnfree(ldapdn)
|