# -*- coding: utf-8 -*-
"""
Django Extensions additional model fields

Some fields might require additional dependencies to be installed.
"""

import re
import string

try:
    import uuid

    HAS_UUID = True
except ImportError:
    HAS_UUID = False

try:
    import shortuuid

    HAS_SHORT_UUID = True
except ImportError:
    HAS_SHORT_UUID = False

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db.models import DateTimeField, CharField, SlugField, Q, UniqueConstraint
from django.db.models.constants import LOOKUP_SEP
from django.template.defaultfilters import slugify
from django.utils.crypto import get_random_string
from django.utils.encoding import force_str


# Default for the ``max_unique_query_attempts`` argument of AutoSlugField and
# RandomCharField. Projects can override it with the
# EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS setting; otherwise it is 100.
MAX_UNIQUE_QUERY_ATTEMPTS = getattr(
    settings,
    "EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS",
    100,
)


class UniqueFieldMixin:
    """
    Mixin for fields that must choose a value no other row already uses.

    Subclasses rely on ``self.attname`` (the field's attribute name) and call
    :py:meth:`find_unique` with an iterator of candidate values.
    """

    def check_is_bool(self, attrname):
        """Raise ``ValueError`` unless ``self.<attrname>`` is a real ``bool``."""
        if not isinstance(getattr(self, attrname), bool):
            raise ValueError("'{}' argument must be True or False".format(attrname))

    @staticmethod
    def _get_fields(model_cls):
        """
        Return ``(field, declaring_model)`` pairs for ``model_cls``'s fields.

        ``declaring_model`` is ``field.model`` when it differs from
        ``model_cls`` (the field is inherited), otherwise ``None``.

        Included fields:

        * non-relational fields;
        * one-to-one relations;
        * many-to-one relations that have a ``related_model``.

        Other relations are skipped.
        """
        return [
            (f, f.model if f.model != model_cls else None)
            for f in model_cls._meta.get_fields()
            if not f.is_relation or f.one_to_one or (f.many_to_one and f.related_model)
        ]

    def get_queryset(self, model_cls, slug_field):
        """
        Return an unfiltered queryset over the model that declares ``slug_field``.

        If the field is inherited from another model, that model's default
        manager is used so rows of sibling subclasses are also considered;
        otherwise ``model_cls``'s default manager is used.
        """
        for field, model in self._get_fields(model_cls):
            if model and field == slug_field:
                return model._default_manager.all()
        return model_cls._default_manager.all()

    def _unique_scopes(self, model_instance):
        """
        Return one ``{field_name: value}`` dict per uniqueness rule containing
        this field.

        Rules come from ``Meta.unique_together`` and from the
        ``UniqueConstraint`` entries of ``Meta.constraints``, when present.
        Each dict holds the instance's current values for the rule's *other*
        fields; a rule consisting only of this field yields ``{}``.
        """
        meta = model_instance._meta
        field_sets = [
            params for params in meta.unique_together if self.attname in params
        ]
        # ``constraints`` only exists on Django 2.2+.
        constraints = getattr(meta, "constraints", None) or ()
        field_sets.extend(
            constraint.fields
            for constraint in constraints
            if isinstance(constraint, UniqueConstraint)
            and self.attname in constraint.fields
        )
        return [
            {
                name: getattr(model_instance, name, None)
                for name in field_names
                if name != self.attname
            }
            for field_names in field_sets
        ]

    def find_unique(self, model_instance, field, iterator, *args):
        """
        Take candidates from ``iterator`` until one is free, then assign it.

        A candidate is rejected when it is falsy or when another row already
        holds it within a uniqueness rule that includes this field:

        * globally, if the field is itself unique, if no ``unique_together`` /
          ``UniqueConstraint`` entry includes it, or if such an entry contains
          only this field;
        * otherwise, when the other row also matches the remaining fields of
          at least one of those entries.

        The instance itself is excluded when it already has a primary key.
        ``*args`` is accepted for backward compatibility and ignored.

        Returns the chosen value, which is also set on ``model_instance``.
        Whatever ``iterator`` raises when it runs out propagates to the caller.
        """
        queryset = self.get_queryset(model_instance.__class__, field)
        if model_instance.pk:
            queryset = queryset.exclude(pk=model_instance.pk)

        # A collision with ANY rule violates that rule, so the scopes are
        # OR-ed. An empty scope (or a unique field) means any row with the
        # same value collides, which an empty Q() expresses. It must not be
        # OR-ed in, because Q() | Q(x) collapses to Q(x).
        scope_query = Q()
        scopes = self._unique_scopes(model_instance)
        if not getattr(field, "unique", False) and scopes and all(scopes):
            for scope in scopes:
                scope_query |= Q(**scope)

        new = next(iterator)
        # exists() issues a LIMIT 1 query instead of loading every match.
        while not new or queryset.filter(scope_query, **{self.attname: new}).exists():
            new = next(iterator)
        setattr(model_instance, self.attname, new)
        return new


class AutoSlugField(UniqueFieldMixin, SlugField):
    """
    AutoSlugField

    By default, sets editable=False, blank=True.

    Required arguments:

    populate_from
        Specifies which field, list of fields, or model method
        the slug will be populated from. A callable is also accepted;
        it is called with the model instance.

        populate_from can traverse a ForeignKey relationship
        by using Django ORM syntax:
            populate_from = 'related_model__field'

    Optional arguments:

    separator
        Defines the used separator (default: '-')

    overwrite
        If set to True, overwrites the slug on every save (default: False)

    overwrite_on_add
        If set to True, a slug already present on an instance being added
        is replaced by a generated one (default: True)

    allow_duplicates
        If set to True, the generated slug is used as-is without checking
        other rows for collisions (default: False)

    max_unique_query_attempts
        Bounds the numbered suffixes tried when looking for a unique slug
        (default: the EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS setting, or 100)

    slugify_function
        Defines the function which will be used to "slugify" a content
        (default: :py:func:`~django.template.defaultfilters.slugify` )

    It is possible to provide custom "slugify" function with
    the ``slugify_function`` function in a model class.

    ``slugify_function`` function in a model class takes priority over
    ``slugify_function`` given as an argument to :py:class:`~AutoSlugField`.

    Example

    .. code-block:: python

        # models.py

        from django.db import models

        from django_extensions.db.fields import AutoSlugField


        class MyModel(models.Model):
            def slugify_function(self, content):
                return content.replace('_', '-').lower()

            title = models.CharField(max_length=42)
            slug = AutoSlugField(populate_from='title')

    Inspired by SmileyChris' Unique Slugify snippet:
    https://www.djangosnippets.org/snippets/690/
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("blank", True)
        kwargs.setdefault("editable", False)

        populate_from = kwargs.pop("populate_from", None)
        if populate_from is None:
            raise ValueError("missing 'populate_from' argument")
        self._populate_from = populate_from

        # Callables are used as-is; anything else must be a single string or
        # a list/tuple of strings (field names or ORM-style paths).
        if not callable(populate_from):
            if not isinstance(populate_from, (list, tuple)):
                populate_from = (populate_from,)

            if not all(isinstance(e, str) for e in populate_from):
                # Wrap in a 1-tuple so that a tuple value is interpolated as a
                # whole instead of being taken as the %-format argument list.
                raise TypeError(
                    "'populate_from' must be str or list[str] or tuple[str], found `%s`"
                    % (self._populate_from,)
                )

        self.slugify_function = kwargs.pop("slugify_function", slugify)
        self.separator = kwargs.pop("separator", "-")
        self.overwrite = kwargs.pop("overwrite", False)
        self.check_is_bool("overwrite")
        self.overwrite_on_add = kwargs.pop("overwrite_on_add", True)
        self.check_is_bool("overwrite_on_add")
        self.allow_duplicates = kwargs.pop("allow_duplicates", False)
        self.check_is_bool("allow_duplicates")
        self.max_unique_query_attempts = kwargs.pop(
            "max_unique_query_attempts", MAX_UNIQUE_QUERY_ATTEMPTS
        )
        super().__init__(*args, **kwargs)

    def _slug_strip(self, value):
        """
        Clean up a slug by removing slug separator characters that occur at
        the beginning or end of a slug.

        If an alternate separator is used, it will also replace any instances
        of the default '-' separator with the new separator.
        """
        re_sep = "(?:-|%s)" % re.escape(self.separator)
        # Collapse runs of separators into a single separator ...
        value = re.sub("%s+" % re_sep, self.separator, value)
        # ... then drop any left at either end.
        return re.sub(r"^%s+|%s+$" % (re_sep, re_sep), "", value)

    @staticmethod
    def slugify_func(content, slugify_function):
        """Return ``slugify_function(content)``, or ``""`` for falsy content."""
        if content:
            return slugify_function(content)
        return ""

    def slug_generator(self, original_slug, start):
        """
        Yield candidate slugs: ``original_slug`` itself, then
        ``original_slug + separator + i`` for ``i`` in
        ``range(start, self.max_unique_query_attempts)``.

        When ``self.slug_len`` (set by :py:meth:`create_slug`) is truthy, the
        base is truncated and re-stripped so that base plus suffix fits it.
        Raises ``RuntimeError`` once the candidates are exhausted.
        """
        yield original_slug
        for i in range(start, self.max_unique_query_attempts):
            slug = original_slug
            end = "%s%s" % (self.separator, i)
            end_len = len(end)
            if self.slug_len and len(slug) + end_len > self.slug_len:
                slug = slug[: self.slug_len - end_len]
                slug = self._slug_strip(slug)
            slug = "%s%s" % (slug, end)
            yield slug
        raise RuntimeError(
            "max slug attempts for %s exceeded (%s)"
            % (original_slug, self.max_unique_query_attempts)
        )

    def create_slug(self, model_instance, add):
        """
        Return the slug to store for ``model_instance``, setting it on the
        instance when a new one is generated.

        An existing slug is kept when ``overwrite`` is False, unless ``add`` is
        True and ``overwrite_on_add`` is set. Otherwise, each ``populate_from``
        entry is slugified and the pieces are joined with ``separator``. The
        result is truncated to the field's ``max_length`` and stripped. Unless
        ``allow_duplicates`` is set, it is then made unique via ``find_unique``.
        """
        slug = getattr(model_instance, self.attname)
        use_existing_slug = False
        if slug and not self.overwrite:
            # Existing slug and not configured to overwrite - Short-circuit
            # here to prevent slug generation when not required.
            use_existing_slug = True

        if self.overwrite_on_add and add:
            use_existing_slug = False

        if use_existing_slug:
            return slug

        # get fields to populate from and slug field to set
        populate_from = self._populate_from
        if not isinstance(populate_from, (list, tuple)):
            populate_from = (populate_from,)

        slug_field = model_instance._meta.get_field(self.attname)
        # A slugify_function defined on the model takes priority.
        slugify_function = getattr(
            model_instance, "slugify_function", self.slugify_function
        )

        # slugify the original field content and set next step to 2
        slug_for_field = lambda lookup_value: self.slugify_func(
            self.get_slug_fields(model_instance, lookup_value),
            slugify_function=slugify_function,
        )
        slug = self.separator.join(map(slug_for_field, populate_from))
        start = 2

        # strip slug depending on max_length attribute of the slug field
        # and clean-up
        self.slug_len = slug_field.max_length
        if self.slug_len:
            slug = slug[: self.slug_len]
        slug = self._slug_strip(slug)
        original_slug = slug

        if self.allow_duplicates:
            setattr(model_instance, self.attname, slug)
            return slug

        return self.find_unique(
            model_instance, slug_field, self.slug_generator(original_slug, start)
        )

    def get_slug_fields(self, model_instance, lookup_value):
        """
        Resolve one ``populate_from`` entry against ``model_instance``.

        A callable entry is called with the instance and its result is
        formatted as a string. A string entry is split on the ORM lookup
        separator and followed attribute by attribute. If the final attribute
        is callable, it is called and its result is formatted as a string.
        Raises ``AttributeError`` naming the failing path element.
        """
        if callable(lookup_value):
            # A function has been provided
            return "%s" % lookup_value(model_instance)

        lookup_value_path = lookup_value.split(LOOKUP_SEP)
        attr = model_instance
        for elem in lookup_value_path:
            try:
                attr = getattr(attr, elem)
            except AttributeError:
                raise AttributeError(
                    "value {} in AutoSlugField's 'populate_from' argument {} returned an error - {} has no attribute {}".format(  # noqa: E501
                        elem, lookup_value, attr, elem
                    )
                )
        if callable(attr):
            return "%s" % attr()

        return attr

    def pre_save(self, model_instance, add):
        """Compute (or keep) the slug and return it as a string."""
        value = force_str(self.create_slug(model_instance, add))
        return value

    def get_internal_type(self):
        return "SlugField"

    def deconstruct(self):
        """
        Record the constructor arguments needed to rebuild this field.

        Options are only emitted when they differ from this field's defaults.
        """
        name, path, args, kwargs = super().deconstruct()
        kwargs["populate_from"] = self._populate_from
        if not self.separator == "-":
            kwargs["separator"] = self.separator
        if self.overwrite is not False:
            kwargs["overwrite"] = True
        if self.overwrite_on_add is not True:
            kwargs["overwrite_on_add"] = False
        if self.allow_duplicates is not False:
            kwargs["allow_duplicates"] = True
        return name, path, args, kwargs


class RandomCharField(UniqueFieldMixin, CharField):
    """
    RandomCharField

    By default, sets editable=False, blank=True, unique=False.

    Required arguments:

    length
        Specifies the length of the field

    Optional arguments:

    unique
        If set to True, duplicate entries are not allowed (default: False)

    lowercase
        If set to True, lowercase the alpha characters (default: False)

    uppercase
        If set to True, uppercase the alpha characters (default: False)

    include_alpha
        If set to True, include alpha characters (default: True)

    include_digits
        If set to True, include digit characters (default: True)

    include_punctuation
        If set to True, include punctuation characters (default: False)

    keep_default
        If set to True, keeps the default initialization value (default: False)

    max_unique_query_attempts
        Number of random values tried before giving up when a unique value
        is required (default: the EXTENSIONS_MAX_UNIQUE_QUERY_ATTEMPTS
        setting, or 100)
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("blank", True)
        kwargs.setdefault("editable", False)

        self.length = kwargs.pop("length", None)
        if self.length is None:
            raise ValueError("missing 'length' argument")
        kwargs["max_length"] = self.length

        self.lowercase = kwargs.pop("lowercase", False)
        self.check_is_bool("lowercase")
        self.uppercase = kwargs.pop("uppercase", False)
        self.check_is_bool("uppercase")
        if self.uppercase and self.lowercase:
            raise ValueError(
                "the 'lowercase' and 'uppercase' arguments are mutually exclusive"
            )
        self.include_digits = kwargs.pop("include_digits", True)
        self.check_is_bool("include_digits")
        self.include_alpha = kwargs.pop("include_alpha", True)
        self.check_is_bool("include_alpha")
        self.include_punctuation = kwargs.pop("include_punctuation", False)
        self.check_is_bool("include_punctuation")
        # Validated like every other boolean option of this field.
        self.keep_default = kwargs.pop("keep_default", False)
        self.check_is_bool("keep_default")
        self.max_unique_query_attempts = kwargs.pop(
            "max_unique_query_attempts", MAX_UNIQUE_QUERY_ATTEMPTS
        )

        # Set unique=False unless it's been set manually.
        if "unique" not in kwargs:
            kwargs["unique"] = False

        super().__init__(*args, **kwargs)

    def random_char_generator(self, chars):
        """
        Yield up to ``max_unique_query_attempts`` random strings of
        ``self.length`` characters drawn from ``chars``, then raise
        ``RuntimeError``.
        """
        for _ in range(self.max_unique_query_attempts):
            # get_random_string already returns a str; no join needed.
            yield get_random_string(self.length, chars)
        raise RuntimeError(
            "max random character attempts exceeded (%s)"
            % self.max_unique_query_attempts
        )

    def in_unique_together(self, model_instance):
        """Return True if this field appears in any ``Meta.unique_together`` entry."""
        return any(
            self.attname in params for params in model_instance._meta.unique_together
        )

    def _in_unique_constraint(self, model_instance):
        """Return True if this field appears in any ``UniqueConstraint`` of the model."""
        # ``constraints`` only exists on Django 2.2+.
        constraints = getattr(model_instance._meta, "constraints", None) or ()
        return any(
            isinstance(constraint, UniqueConstraint)
            and self.attname in constraint.fields
            for constraint in constraints
        )

    def pre_save(self, model_instance, add):
        """
        Return the value to store, generating a random one when needed.

        The current value is kept when it is not ``""`` and either the row is
        being updated (``add`` is False) or ``keep_default`` is set. Otherwise
        a random string is generated from the configured character classes. It
        is made unique via ``find_unique`` when the field is ``unique`` or is
        part of a ``unique_together`` entry or a ``UniqueConstraint``.
        """
        if (not add or self.keep_default) and getattr(
            model_instance, self.attname
        ) != "":
            return getattr(model_instance, self.attname)

        population = ""
        if self.include_alpha:
            if self.lowercase:
                population += string.ascii_lowercase
            elif self.uppercase:
                population += string.ascii_uppercase
            else:
                population += string.ascii_letters

        if self.include_digits:
            population += string.digits

        if self.include_punctuation:
            population += string.punctuation

        random_chars = self.random_char_generator(population)
        needs_unique = (
            self.unique
            or self.in_unique_together(model_instance)
            or self._in_unique_constraint(model_instance)
        )
        if not needs_unique:
            new = next(random_chars)
            setattr(model_instance, self.attname, new)
            return new

        return self.find_unique(
            model_instance,
            model_instance._meta.get_field(self.attname),
            random_chars,
        )

    def get_internal_type(self):
        return "CharField"

    # Backward-compatible alias for the previously misnamed method.
    internal_type = get_internal_type

    def deconstruct(self):
        """
        Record the constructor arguments needed to rebuild this field.

        ``max_length`` is replaced by ``length``. Options are only emitted when
        they differ from this field's defaults.
        """
        name, path, args, kwargs = super().deconstruct()
        kwargs["length"] = self.length
        del kwargs["max_length"]
        if self.lowercase is True:
            kwargs["lowercase"] = self.lowercase
        if self.uppercase is True:
            kwargs["uppercase"] = self.uppercase
        if self.include_alpha is False:
            kwargs["include_alpha"] = self.include_alpha
        if self.include_digits is False:
            kwargs["include_digits"] = self.include_digits
        if self.include_punctuation is True:
            kwargs["include_punctuation"] = self.include_punctuation
        if self.keep_default is True:
            kwargs["keep_default"] = self.keep_default
        if self.unique is True:
            kwargs["unique"] = self.unique
        return name, path, args, kwargs


class CreationDateTimeField(DateTimeField):
    """
    CreationDateTimeField

    By default, sets editable=False, blank=True, auto_now_add=True
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("editable", False)
        kwargs.setdefault("blank", True)
        kwargs.setdefault("auto_now_add", True)
        DateTimeField.__init__(self, *args, **kwargs)

    def get_internal_type(self):
        return "DateTimeField"

    def deconstruct(self):
        """
        Record the constructor arguments needed to rebuild this field.

        This field's defaults (editable=False, blank=True, auto_now_add=True)
        differ from Django's, so values that differ from *this* field's
        defaults must be written out explicitly. Otherwise ``__init__`` would
        silently restore the defaults when the field is rebuilt.
        """
        name, path, args, kwargs = super().deconstruct()
        if self.editable is not False:
            kwargs["editable"] = True
        if self.blank is not True:
            kwargs["blank"] = False
        # Always record auto_now_add. In particular auto_now_add=False must be
        # kept, because __init__ defaults it to True.
        kwargs["auto_now_add"] = bool(self.auto_now_add)
        return name, path, args, kwargs


class ModificationDateTimeField(CreationDateTimeField):
    """
    ModificationDateTimeField

    By default, sets editable=False, blank=True, auto_now=True

    Sets value to now every time the object is saved.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("auto_now", True)
        # Deliberately bypasses CreationDateTimeField.__init__, which would
        # also default auto_now_add=True.
        DateTimeField.__init__(self, *args, **kwargs)

    def get_internal_type(self):
        return "DateTimeField"

    def deconstruct(self):
        """
        Record the constructor arguments needed to rebuild this field.

        ``auto_now`` is always written out. In particular auto_now=False must
        be kept, because ``__init__`` defaults it to True.
        """
        name, path, args, kwargs = super().deconstruct()
        kwargs["auto_now"] = bool(self.auto_now)
        return name, path, args, kwargs

    def pre_save(self, model_instance, add):
        """
        Return the value to store for this field.

        If the instance has ``update_modified`` set to a falsy value, its
        current value is kept unchanged. Otherwise the parent behaviour applies.
        """
        if not getattr(model_instance, "update_modified", True):
            return getattr(model_instance, self.attname)
        return super().pre_save(model_instance, add)


class UUIDVersionError(Exception):
    """Raised when a UUID field is configured with a version it cannot generate."""


class UUIDFieldMixin:
    """
    UUIDFieldMixin

    By default uses UUID version 4 (randomly generated UUID).

    The field support all uuid versions which are natively supported by the uuid python module, except version 2.
    For more information see: https://docs.python.org/lib/module-uuid.html
    """  # noqa: E501

    DEFAULT_MAX_LENGTH = 36

    def __init__(
        self,
        verbose_name=None,
        name=None,
        auto=True,
        version=4,
        node=None,
        clock_seq=None,
        namespace=None,
        uuid_name=None,
        *args,
        **kwargs,
    ):
        if not HAS_UUID:
            raise ImproperlyConfigured(
                "'uuid' module is required for UUIDField. "
                "(Do you have Python 2.5 or higher installed ?)"
            )

        kwargs.setdefault("max_length", self.DEFAULT_MAX_LENGTH)

        if auto:
            # Auto-generated values are filled in on save, so the field is
            # always allowed to be blank and is hidden from forms by default.
            self.empty_strings_allowed = False
            kwargs["blank"] = True
            kwargs.setdefault("editable", False)

        self.auto = auto
        self.version = version
        self.node = node
        self.clock_seq = clock_seq
        self.namespace = namespace
        # The name-based versions (3 and 5) fall back to the `name` argument.
        self.uuid_name = uuid_name or name

        super().__init__(verbose_name=verbose_name, *args, **kwargs)

    def create_uuid(self):
        """
        Generate a UUID according to ``self.version``.

        A falsy version or 4 gives ``uuid4()``. Version 1 gives
        ``uuid1(node, clock_seq)``. Versions 3 and 5 give
        ``uuid3``/``uuid5(namespace, uuid_name)``. Version 2 and any other
        value raise ``UUIDVersionError``.
        """
        if not self.version or self.version == 4:
            return uuid.uuid4()
        elif self.version == 1:
            return uuid.uuid1(self.node, self.clock_seq)
        elif self.version == 2:
            raise UUIDVersionError("UUID version 2 is not supported.")
        elif self.version == 3:
            return uuid.uuid3(self.namespace, self.uuid_name)
        elif self.version == 5:
            return uuid.uuid5(self.namespace, self.uuid_name)
        else:
            raise UUIDVersionError("UUID version %s is not valid." % self.version)

    def pre_save(self, model_instance, add):
        """
        Return the value to store, generating one when ``auto`` is set and the
        current value is empty (``None`` or ``""``). A generated value is also
        set on the instance.
        """
        value = super().pre_save(model_instance, add)

        if self.auto and not value:
            value = force_str(self.create_uuid())
            setattr(model_instance, self.attname, value)

        return value

    def formfield(self, form_class=None, choices_form_class=None, **kwargs):
        """
        Return no form field for auto-generated values; otherwise defer to the
        parent field class.
        """
        if self.auto:
            return None
        # Forward these by keyword, and only when provided.
        # CharField.formfield() accepts keyword arguments only, and some field
        # classes supply their own form_class default that an explicit None
        # would override.
        if form_class is not None:
            kwargs["form_class"] = form_class
        if choices_form_class is not None:
            kwargs["choices_form_class"] = choices_form_class
        return super().formfield(**kwargs)

    def deconstruct(self):
        """
        Record the constructor arguments needed to rebuild this field,
        omitting those that equal their defaults.
        """
        name, path, args, kwargs = super().deconstruct()

        if kwargs.get("max_length", None) == self.DEFAULT_MAX_LENGTH:
            del kwargs["max_length"]
        if self.auto is not True:
            kwargs["auto"] = self.auto
        if self.version != 4:
            kwargs["version"] = self.version
        if self.node is not None:
            kwargs["node"] = self.node
        if self.clock_seq is not None:
            kwargs["clock_seq"] = self.clock_seq
        if self.namespace is not None:
            kwargs["namespace"] = self.namespace
        if self.uuid_name is not None:
            # Must be the configured UUID name, not the model attribute name.
            kwargs["uuid_name"] = self.uuid_name

        return name, path, args, kwargs


class ShortUUIDField(UUIDFieldMixin, CharField):
    """
    ShortUUIDField

    Generates concise (22 characters instead of 36), unambiguous, URL-safe UUIDs.

    Based on `shortuuid`: https://github.com/stochastic-technologies/shortuuid
    """

    DEFAULT_MAX_LENGTH = 22

    def __init__(self, *args, **kwargs):
        # Fail fast, before any field state is initialised, when the optional
        # dependency is missing.
        if not HAS_SHORT_UUID:
            raise ImproperlyConfigured(
                "'shortuuid' module is required for ShortUUIDField. "
                "(Do you have shortuuid installed ?)"
            )
        # UUIDFieldMixin.__init__ defaults max_length to self.DEFAULT_MAX_LENGTH,
        # which resolves to this class's value (22).
        super().__init__(*args, **kwargs)

    def create_uuid(self):
        """
        Generate a short UUID according to ``self.version``.

        A falsy version, 1 or 4 all call ``shortuuid.uuid()`` with no name.
        Version 5 passes ``self.namespace`` as shortuuid's ``name`` argument.
        NOTE(review): this ignores ``uuid_name``; it is kept as-is because
        changing it would change generated values.
        Versions 2 and 3 and any other value raise ``UUIDVersionError``.
        """
        if not self.version or self.version == 4:
            return shortuuid.uuid()
        elif self.version == 1:
            return shortuuid.uuid()
        elif self.version == 2:
            raise UUIDVersionError("UUID version 2 is not supported.")
        elif self.version == 3:
            raise UUIDVersionError("UUID version 3 is not supported.")
        elif self.version == 5:
            return shortuuid.uuid(name=self.namespace)
        else:
            raise UUIDVersionError("UUID version %s is not valid." % self.version)
