File: dn_ctypes.py

package info (click to toggle)
freeipa 4.12.4-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 100,668 kB
  • sloc: python: 298,952; javascript: 71,606; ansic: 49,369; sh: 6,547; makefile: 2,553; xml: 343; sed: 16
file content (168 lines) | stat: -rw-r--r-- 3,905 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#
# Copyright (C) 2019  FreeIPA Contributors see COPYING for license
#
"""ctypes wrapper for libldap_str2dn
"""
from __future__ import absolute_import

import ctypes
import ctypes.util

import six

__all__ = ("str2dn", "dn2str", "DECODING_ERROR", "LDAPError")

# load reentrant ldap client library (libldap_r-*.so.2 or libldap.so.2)
# The first candidate that ctypes.util.find_library() resolves wins. A
# generator expression is used instead of ``next(filter(...))`` because on
# Python 2 ``filter`` returns a list, and ``next()`` rejects lists (they are
# not iterators), which made this import fail there.
ldap_lib_filename = next(
    (
        candidate
        for candidate in map(ctypes.util.find_library, ["ldap_r-2", "ldap"])
        if candidate
    ),
    None,
)

if ldap_lib_filename is None:
    raise ImportError("libldap_r or libldap shared library missing")
try:
    lib = ctypes.CDLL(ldap_lib_filename)
except OSError as e:
    # Surface loader failures as ImportError so importers can treat a
    # broken/missing libldap the same way as a missing one.
    raise ImportError(str(e))

# constants
# AVA flag bits that, judging by their names, are libldap memory-management
# markers; str2dn() strips them from the flags it reports.
LDAP_AVA_FREE_ATTR = 0x0010
LDAP_AVA_FREE_VALUE = 0x0020
# result code that errcheck() maps to the DECODING_ERROR exception
LDAP_DECODING_ERROR = -4

# mask for AVA flags: keeps every bit except the two FREE_* bits above
# (written via De Morgan: ~a & ~b == ~(a | b))
AVA_MASK = ~LDAP_AVA_FREE_ATTR & ~LDAP_AVA_FREE_VALUE


class berval(ctypes.Structure):
    """ctypes mirror of libldap's ``struct berval`` (length + data pointer).

    ``bv_value`` is declared as ``c_char_p``, and ctypes stops reading such a
    field at the first NUL byte. The accessors below therefore read exactly
    ``bv_len`` bytes straight from the stored data pointer, so values that
    contain NUL bytes are returned intact instead of being truncated and
    zero-padded.
    """

    __slots__ = ()
    _fields_ = [("bv_len", ctypes.c_ulong), ("bv_value", ctypes.c_char_p)]

    def _data_address(self):
        """Return the raw address held in ``bv_value`` (``None`` if NULL)."""
        # Read the pointer slot itself; going through ``self.bv_value``
        # would convert it to NUL-terminated bytes.
        return ctypes.c_void_p.from_address(
            ctypes.addressof(self) + berval.bv_value.offset
        ).value

    def __bytes__(self):
        """Return the ``bv_len`` bytes the berval points to.

        Returns ``b""`` for a zero length. Raises ValueError if the length is
        non-zero but the data pointer is NULL, rather than dereferencing it.
        """
        length = self.bv_len
        if not length:
            return b""
        address = self._data_address()
        if not address:
            raise ValueError("berval has NULL bv_value with bv_len=%d" % length)
        return ctypes.string_at(address, length)

    def __str__(self):
        """Return the value decoded as UTF-8 (may raise UnicodeDecodeError)."""
        return self.__bytes__().decode("utf-8")

    if six.PY2:
        __unicode__ = __str__
        __str__ = __bytes__


class LDAPAVA(ctypes.Structure):
    """ctypes mirror of libldap's ``LDAPAVA`` (one attribute/value pair).

    OpenLDAP's <ldap.h> declares ``la_flags`` as plain ``unsigned`` (a C
    ``unsigned int``). Mapping it as ``c_uint16`` only worked on
    little-endian hosts: on big-endian ones the 2-byte read covers the
    high-order half of the field, so the flags always came back as 0.
    ``la_private`` completes the C layout; this module never reads it.
    """

    __slots__ = ()
    _fields_ = [
        ("la_attr", berval),
        ("la_value", berval),
        ("la_flags", ctypes.c_uint),
        ("la_private", ctypes.c_void_p),
    ]


# LDAPAVA* -- one element of an RDN array
_LDAPAVA_P = ctypes.POINTER(LDAPAVA)
# typedef LDAPAVA** LDAPRDN;  (str2dn() walks it until a NULL entry)
LDAPRDN = ctypes.POINTER(_LDAPAVA_P)
# typedef LDAPRDN* LDAPDN;    (str2dn() walks it until a NULL entry)
LDAPDN = ctypes.POINTER(LDAPRDN)


def errcheck(result, func, arguments):
    """ctypes ``errcheck`` hook that turns libldap result codes into exceptions.

    A zero result is returned unchanged. ``LDAP_DECODING_ERROR`` raises
    ``DECODING_ERROR``; any other non-zero code raises ``LDAPError`` carrying
    the message ``ldap_err2string`` reports for that code.
    *func* and *arguments* are part of the ctypes hook signature and unused.
    """
    if result == 0:
        return result
    if result == LDAP_DECODING_ERROR:
        raise DECODING_ERROR
    raise LDAPError(ldap_err2string(result).decode("utf-8"))


# Foreign function prototypes. The integer types mirror the C declarations
# (NOTE(review): taken from OpenLDAP's <ldap.h> -- confirm if porting):
#   int   ldap_str2dn(const char *str, LDAPDN *dn, unsigned flags);
#   void  ldap_dnfree(LDAPDN dn);
#   char *ldap_err2string(int err);
# Previously 16-bit types were used, which silently truncated flag values
# and result codes outside the 16-bit range.
ldap_str2dn = lib.ldap_str2dn
ldap_str2dn.argtypes = (
    ctypes.c_char_p,
    ctypes.POINTER(LDAPDN),
    ctypes.c_uint,
)
ldap_str2dn.restype = ctypes.c_int
# non-zero result codes are converted into LDAPError / DECODING_ERROR
ldap_str2dn.errcheck = errcheck

ldap_dnfree = lib.ldap_dnfree
ldap_dnfree.argtypes = (LDAPDN,)
ldap_dnfree.restype = None

ldap_err2string = lib.ldap_err2string
ldap_err2string.argtypes = (ctypes.c_int,)
ldap_err2string.restype = ctypes.c_char_p


class LDAPError(Exception):
    """Raised when libldap reports a non-zero result code."""


class DECODING_ERROR(LDAPError):
    """Raised when libldap cannot decode (parse) the given DN string."""


# RFC 4514, 2.4: characters escaped with a backslash wherever they occur in
# an attribute value. The RFC requires '"', '+', ',', ';', '<', '>' and '\';
# "'" is not required but was escaped historically and is kept for stable
# output. NUL is handled separately in _escape_dn().
_ESCAPE_CHARS = {'"', "+", ",", ";", "<", ">", "\\", "'"}


def _escape_dn(dn):
    """Escape a single attribute value for use in an RFC 4514 DN string.

    Escapes a leading space or '#', a trailing space, every character in
    ``_ESCAPE_CHARS`` (including the backslash itself, so that existing
    backslashes survive a round trip), and NUL as the hex pair ``\\00``.
    Returns ``""`` for an empty or falsy value.
    """
    if not dn:
        return ""
    result = []
    # a space or number sign occurring at the beginning of the string
    if dn[0] in {"#", " "}:
        result.append("\\")
    for c in dn:
        if c == "\x00":
            # A backslash followed by a literal NUL is not a valid RFC 4514
            # escape pair, and the NUL would also cut the C string handed to
            # libldap short; use the hex-pair form instead.
            result.append("\\00")
        elif c in _ESCAPE_CHARS:
            result.append("\\")
            result.append(c)
        else:
            result.append(c)
    # a space character occurring at the end of the string (a lone space
    # was already escaped by the leading-character rule above)
    if len(dn) > 1 and dn[-1] == " ":
        # the trailing space is the last entry; insert the escape before it
        result.insert(-1, "\\")
    return "".join(result)


def dn2str(dn):
    """Render a parsed DN back into its RFC 4514 string form.

    *dn* is an iterable of RDNs; each RDN is an iterable of
    ``(attr, value, flags)`` triples such as :func:`str2dn` returns. Values
    are escaped with ``_escape_dn``; attribute names are used verbatim and
    the flags element is ignored. AVAs within one RDN are joined with ``+``,
    RDNs with ``,``.
    """
    rdn_parts = []
    for rdn in dn:
        ava_parts = [
            "=".join((attr, _escape_dn(value))) for attr, value, _flag in rdn
        ]
        rdn_parts.append("+".join(ava_parts))
    return ",".join(rdn_parts)


def _null_terminated(array):
    """Yield the entries of a NULL-terminated ctypes pointer array.

    Iteration stops at, and does not yield, the first NULL entry.
    """
    position = 0
    entry = array[position]
    while entry:
        yield entry
        position += 1
        entry = array[position]


def _ava_tuple(ava):
    """Convert one LDAPAVA into an ``(attr, value, flags)`` triple."""
    return (
        six.text_type(ava.la_attr),
        six.text_type(ava.la_value),
        ava.la_flags & AVA_MASK,
    )


def str2dn(dn, flags=0):
    """Parse a DN string with libldap's ``ldap_str2dn``.

    *dn* may be ``None`` (result ``[]``), a text string (encoded as UTF-8
    first) or bytes. *flags* is passed through to ``ldap_str2dn``.

    Returns a list of RDNs, each a list of ``(attr, value, flags)`` triples
    with text attr/value and ``LDAP_AVA_FREE_ATTR``/``LDAP_AVA_FREE_VALUE``
    masked out of the flags. The empty DN parses to ``[]``. Parse failures
    propagate from ``errcheck`` as ``DECODING_ERROR`` / ``LDAPError``. The
    libldap structure is always released with ``ldap_dnfree``.
    """
    if dn is None:
        return []
    if isinstance(dn, six.text_type):
        dn = dn.encode("utf-8")
    # NOTE(review): bytes with an embedded NUL are presumably passed to C
    # as-is, so libldap would only see the part before the NUL -- confirm
    # whether callers can supply such input.

    ldapdn = LDAPDN()
    try:
        ldap_str2dn(dn, ctypes.byref(ldapdn), flags)
        if not ldapdn:
            # empty DN, str2dn("") == []
            return []
        # Build the full result before the finally clause frees the C data.
        return [
            [_ava_tuple(ava_p[0]) for ava_p in _null_terminated(rdn)]
            for rdn in _null_terminated(ldapdn)
        ]
    finally:
        ldap_dnfree(ldapdn)