#
# Copyright (C) 2019  FreeIPA Contributors see COPYING for license
#
"""ctypes wrapper for libldap_str2dn
"""
from __future__ import absolute_import

import ctypes
import ctypes.util

import six

# Names exported by ``from <module> import *``: the two conversion functions
# and the exception classes they can raise.
__all__ = ("str2dn", "dn2str", "DECODING_ERROR", "LDAPError")

# Load the LDAP client library, preferring the reentrant build
# (libldap_r-*.so.2) and falling back to plain libldap (libldap.so.2).
# A generator expression is used instead of next(filter(...)): on Python 2
# filter() returns a list, which next() cannot consume. The lookup stays
# lazy, so the second name is only probed when the first is not found.
ldap_lib_filename = next(
    (
        found
        for found in (
            ctypes.util.find_library(name) for name in ("ldap_r-2", "ldap")
        )
        if found
    ),
    None,
)

if ldap_lib_filename is None:
    raise ImportError("libldap_r or libldap shared library missing")
try:
    lib = ctypes.CDLL(ldap_lib_filename)
except OSError as e:
    # surface loader failures as ImportError so importing this module
    # fails with a single exception type
    raise ImportError(str(e))

# constants
LDAP_AVA_FREE_ATTR = 1 << 4
LDAP_AVA_FREE_VALUE = 1 << 5
LDAP_DECODING_ERROR = -0x04

# mask for AVA flags: clears the two LDAP_AVA_FREE_* bits and keeps the rest
AVA_MASK = ~LDAP_AVA_FREE_ATTR & ~LDAP_AVA_FREE_VALUE


class berval(ctypes.Structure):
    """ctypes mirror of ``struct berval``: a byte string with explicit length.

    ``bytes(bv)`` returns exactly ``bv_len`` bytes starting at ``bv_value``
    (an empty byte string when the pointer is NULL). The text conversion
    decodes those bytes as UTF-8 and raises ``UnicodeDecodeError`` when they
    are not valid UTF-8.
    """

    __slots__ = ()
    # bv_value is an untyped pointer on purpose: a c_char_p field is
    # converted to bytes by scanning for the first NUL, which ignores bv_len
    # and therefore truncates values that contain NUL bytes.
    _fields_ = [("bv_len", ctypes.c_ulong), ("bv_value", ctypes.c_void_p)]

    def __bytes__(self):
        """Return the raw value; bv_len is the only authority on its size."""
        if not self.bv_value:
            # NULL pointer: nothing to read
            return b""
        return ctypes.string_at(self.bv_value, self.bv_len)

    def __str__(self):
        """Return the value decoded as UTF-8 text."""
        return self.__bytes__().decode("utf-8")

    # Python 2 only (str is bytes there; same condition as six.PY2):
    # str() must return bytes and unicode() the decoded text.
    if str is bytes:
        __unicode__ = __str__
        __str__ = __bytes__


class LDAPAVA(ctypes.Structure):
    """ctypes mirror of libldap's ``LDAPAVA`` (one attribute/value pair).

    Instances are only read through pointers filled in by ``ldap_str2dn()``;
    ``la_attr`` and ``la_value`` are ``berval`` structures and ``la_flags``
    holds the AVA flag bits.
    """

    __slots__ = ()
    _fields_ = [
        ("la_attr", berval),
        ("la_value", berval),
        # ``unsigned la_flags`` in ldap.h: it must be read with the width of
        # a C unsigned int. A 16-bit field happens to work on little-endian
        # machines but reads the wrong half of the word on big-endian ones.
        ("la_flags", ctypes.c_uint),
    ]


# typedef LDAPAVA** LDAPRDN;
# An RDN is a pointer to LDAPAVA pointers; str2dn() walks the entries until
# it reaches a NULL pointer.
LDAPRDN = ctypes.POINTER(ctypes.POINTER(LDAPAVA))
# typedef LDAPRDN* LDAPDN;
# A DN is a pointer to LDAPRDN entries, walked the same way until NULL.
LDAPDN = ctypes.POINTER(LDAPRDN)


def errcheck(result, func, arguments):
    """ctypes ``errcheck`` hook that maps libldap return codes to exceptions.

    :param result: integer return code of the wrapped C function
    :param func: the ctypes function object that was called (unused)
    :param arguments: the arguments of the call (unused)
    :returns: ``result`` unchanged when it is zero (success)
    :raises DECODING_ERROR: when ``result`` equals ``LDAP_DECODING_ERROR``
    :raises LDAPError: for every other non-zero code, carrying the text
        returned by ``ldap_err2string()`` decoded as UTF-8
    """
    if result == 0:
        return result
    if result == LDAP_DECODING_ERROR:
        raise DECODING_ERROR
    message = ldap_err2string(result).decode("utf-8")
    raise LDAPError(message)


# The prototypes below use the widths of the C declarations (int / unsigned)
# instead of 16-bit types, so no bits of a flag or return code can be dropped.

# int ldap_str2dn(const char *str, LDAPDN *dn, unsigned flags);
ldap_str2dn = lib.ldap_str2dn
ldap_str2dn.argtypes = (
    ctypes.c_char_p,
    ctypes.POINTER(LDAPDN),
    ctypes.c_uint,
)
ldap_str2dn.restype = ctypes.c_int
# non-zero return codes are turned into exceptions by errcheck()
ldap_str2dn.errcheck = errcheck

# void ldap_dnfree(LDAPDN dn);
ldap_dnfree = lib.ldap_dnfree
ldap_dnfree.argtypes = (LDAPDN,)
ldap_dnfree.restype = None

# char *ldap_err2string(int err);
ldap_err2string = lib.ldap_err2string
ldap_err2string.argtypes = (ctypes.c_int,)
ldap_err2string.restype = ctypes.c_char_p


class LDAPError(Exception):
    """Base class for the errors raised by this module."""


class DECODING_ERROR(LDAPError):
    """Raised when a libldap call returns ``LDAP_DECODING_ERROR``."""


# RFC 4514, 2.4: characters that must be escaped wherever they occur in a
# value and that may be escaped by prefixing a backslash. The backslash
# itself belongs to this set; the apostrophe does not (a "\'" pair is not a
# valid escape sequence). NUL is handled separately in _escape_dn().
_ESCAPE_CHARS = {'"', "+", ",", ";", "<", ">", "\\"}


def _escape_dn(dn):
    """Escape one attribute value for use in a string DN (RFC 4514, 2.4).

    :param dn: attribute value as text
    :returns: the value with a backslash in front of every character from
        ``_ESCAPE_CHARS``, in front of a leading ``#`` or space and in front
        of a trailing space; NUL characters are written as ``\\00``. An
        empty or false value yields ``""``.
    """
    if not dn:
        return ""
    result = []
    # a space or number sign occurring at the beginning of the string
    if dn[0] in {"#", " "}:
        result.append("\\")
    for c in dn:
        if c == "\x00":
            # NUL has no backslash-prefixed form, only the hex pair
            result.append("\\00")
            continue
        if c in _ESCAPE_CHARS:
            result.append("\\")
        result.append(c)
    # a space character occurring at the end of the string; a lone space
    # was already escaped by the leading-character rule above
    if len(dn) > 1 and dn[-1] == " ":
        # insert before last entry
        result.insert(-1, "\\")
    return "".join(result)


def dn2str(dn):
    """Serialize a decomposed DN into its string representation.

    :param dn: iterable of RDNs; every RDN is an iterable of
        ``(attribute, value, flags)`` triples (the flags are ignored)
    :returns: text with AVAs joined as ``attr=value``, multi-valued RDNs
        joined by ``+`` and RDNs joined by ``,``; values are escaped with
        ``_escape_dn()``, attribute names are emitted as given
    """
    return ",".join(
        "+".join(
            "=".join((attr, _escape_dn(value))) for attr, value, _flag in rdn
        )
        for rdn in dn
    )


def str2dn(dn, flags=0):
    """Parse a string DN into a list of RDNs with libldap's ldap_str2dn().

    :param dn: DN as text (encoded to UTF-8 before parsing) or as bytes;
        ``None`` yields an empty list without calling into libldap
    :param flags: passed through unchanged to ``ldap_str2dn()``
    :returns: list of RDNs; every RDN is a list of
        ``(attribute, value, flags)`` tuples, where attribute and value are
        text and the flags have the bits outside ``AVA_MASK`` cleared
    :raises DECODING_ERROR, LDAPError: through the ``errcheck`` hook
        installed on ``ldap_str2dn`` when the C call returns non-zero
    """
    if dn is None:
        return []
    encoded = dn.encode("utf-8") if isinstance(dn, six.text_type) else dn

    def null_terminated(array):
        # ctypes pointers carry no length; stop at the first NULL entry
        for entry in array:
            if not entry:
                return
            yield entry

    def as_tuple(ava_ptr):
        ava = ava_ptr[0]
        return (
            six.text_type(ava.la_attr),
            six.text_type(ava.la_value),
            ava.la_flags & AVA_MASK,
        )

    parsed = LDAPDN()
    try:
        ldap_str2dn(encoded, ctypes.byref(parsed), flags)
        if not parsed:
            # empty DN, str2dn("") == []
            return []
        # everything is copied into Python objects here, before the
        # finally clause releases the C structures
        return [
            [as_tuple(ava_ptr) for ava_ptr in null_terminated(rdn)]
            for rdn in null_terminated(parsed)
        ]
    finally:
        ldap_dnfree(parsed)
