import copy
import operator
import sys
import unicodedata
from collections.abc import Iterator
from dataclasses import asdict, dataclass, field, fields, replace
from enum import IntEnum
from typing import Any, Optional, TypeVar

from xsdata.codegen.exceptions import CodegenError
from xsdata.formats.converter import converter
from xsdata.formats.dataclass.models.elements import XmlType
from xsdata.models.enums import DataType, Namespace, Tag
from xsdata.models.mixins import ElementBase
from xsdata.utils import namespaces, text

# Map of the xml tags an attr may be produced from to the xml type the attr
# is generated as. Tags that are not listed here have no xml type, the
# lookup in ``Attr.xml_type`` returns None for them.
xml_type_map = {
    Tag.ANY: XmlType.WILDCARD,
    Tag.ANY_ATTRIBUTE: XmlType.ATTRIBUTES,
    Tag.ATTRIBUTE: XmlType.ATTRIBUTE,
    Tag.CHOICE: XmlType.ELEMENTS,
    Tag.ELEMENT: XmlType.ELEMENT,
}

# Tags of the classes considered root/global classes, checked by
# ``Class.is_complex_type``.
GLOBAL_TYPES = (
    Tag.ELEMENT,
    Tag.COMPLEX_TYPE,
    Tag.BINDING_OPERATION,
    Tag.BINDING_MESSAGE,
    Tag.MESSAGE,
)

T = TypeVar("T", bound="CodegenModel")


@dataclass
class CodegenModel:
    """Base class shared by all the codegen models.

    Provides the deep copy helpers used throughout the codegen process.
    """

    def clone(self: T, **kwargs: Any) -> T:
        """Create an independent deep copy of this instance.

        Args:
            **kwargs: Optional field overrides, applied to the copy with
                ``dataclasses.replace``.

        Returns:
            The deep copied instance, with the overrides applied if any.
        """
        duplicate = copy.deepcopy(self)
        if not kwargs:
            return duplicate

        return replace(duplicate, **kwargs)

    def swap(self, source: "CodegenModel") -> None:
        """Overwrite every field of this instance with a deep copy from source.

        Each field value is deep copied individually.

        Args:
            source: The instance to read the field values from.
        """
        for model_field in fields(self):
            name = model_field.name
            setattr(self, name, copy.deepcopy(getattr(source, name)))


@dataclass
class Restrictions(CodegenModel):
    """Class field validation restrictions.

    Args:
        min_occurs: The minimum number of occurrences
        max_occurs: The maximum number of occurrences
        min_exclusive: The lower exclusive bound for numeric values
        min_inclusive: The lower inclusive bound for numeric values
        min_length: The minimum length of characters or list items allowed
        max_exclusive: The upper exclusive bound for numeric values
        max_inclusive: The upper inclusive bound for numeric values
        max_length: The max length of characters or list items allowed
        total_digits: The exact number of digits allowed for numeric values
        fraction_digits: The maximum number of decimal places allowed
        length: The exact number of characters or list items allowed
        white_space: Specifies how white space is handled
        pattern: Defines the exact sequence of characters that are acceptable
        explicit_timezone: Require or prohibit the time zone offset in date/time
        nillable: Specifies whether nil content is allowed
        sequence: The sequence reference number of the attr
        tokens: Specifies whether the value needs tokenization
        format: The output format used for byte and datetime types
        choice: The choice reference number of the attr
        group: The group reference number of the attr
        process_contents: Specifies the content processed mode: strict, lax, skip
        path: The coded attr path in the source document
    """

    min_occurs: int | None = field(default=None)
    max_occurs: int | None = field(default=None)
    min_exclusive: str | None = field(default=None)
    min_inclusive: str | None = field(default=None)
    min_length: int | None = field(default=None)
    max_exclusive: str | None = field(default=None)
    max_inclusive: str | None = field(default=None)
    max_length: int | None = field(default=None)
    total_digits: int | None = field(default=None)
    fraction_digits: int | None = field(default=None)
    length: int | None = field(default=None)
    white_space: str | None = field(default=None)
    pattern: str | None = field(default=None)
    explicit_timezone: str | None = field(default=None)
    nillable: bool | None = field(default=None)
    sequence: int | None = field(default=None, compare=False)
    tokens: bool | None = field(default=None)
    format: str | None = field(default=None)
    choice: int | None = field(default=None, compare=False)
    group: int | None = field(default=None)
    process_contents: str | None = field(default=None)
    path: list[tuple[str, int, int, int]] = field(default_factory=list)

    @property
    def is_list(self) -> bool:
        """Return True when max occurs is set and greater than one."""
        upper = self.max_occurs
        return upper is not None and upper > 1

    @property
    def is_optional(self) -> bool:
        """Return True when min occurs is exactly zero."""
        return self.min_occurs == 0

    @property
    def is_prohibited(self) -> bool:
        """Return True when max occurs is exactly zero."""
        return self.max_occurs == 0

    def merge(self, source: "Restrictions") -> None:
        """Copy the restrictions of another instance into this one.

        Value restrictions (bounds, lengths, digits, white space, pattern,
        explicit timezone and process contents) are overridden whenever
        the source defines them. The source path is prepended to this
        path. Sequence, choice, tokens, format and group keep this
        instance's value unless it is falsy. Occurrences are inherited
        only when they are not set on this instance.

        Args:
            source: The source instance to merge properties from
        """
        for name in (
            "min_exclusive",
            "min_inclusive",
            "min_length",
            "max_exclusive",
            "max_inclusive",
            "max_length",
            "total_digits",
            "fraction_digits",
            "length",
            "white_space",
            "pattern",
            "explicit_timezone",
            "process_contents",
        ):
            incoming = getattr(source, name)
            if incoming is not None:
                setattr(self, name, incoming)

        self.path = source.path + self.path
        self.sequence = self.sequence or source.sequence
        self.choice = self.choice or source.choice
        self.tokens = self.tokens or source.tokens
        self.format = self.format or source.format
        self.group = self.group or source.group

        # Occurrences defined on this instance win over the source ones.
        if self.min_occurs is None:
            self.min_occurs = source.min_occurs
        if self.max_occurs is None:
            self.max_occurs = source.max_occurs

    def asdict(self, types: list[type] | None = None) -> dict:
        """Collect the restrictions relevant for code generation.

        List attrs report their non-trivial occurrence bounds, single
        required attrs report ``required=True``. The rest of the fields
        follow in declaration order, skipping unset values, the reference
        numbers, the occurrences and the path. Process contents is only
        kept when it is ``skip``. Exclusive/inclusive bounds are converted
        with the given types, when any.

        Args:
            types: An optional list of attr python types

        Returns:
            A key-value of map of the attr restrictions for generation.
        """
        result: dict = {}
        sorted_types = converter.sort_types(types) if types else []

        if self.is_list:
            lower, upper = self.min_occurs, self.max_occurs
            if lower is not None and lower > 0:
                result["min_occurs"] = lower
            if upper is not None and upper < sys.maxsize:
                result["max_occurs"] = upper
        elif self.min_occurs == self.max_occurs == 1 and not (
            self.nillable or self.tokens
        ):
            result["required"] = True

        excluded = ("choice", "group", "min_occurs", "max_occurs", "path")
        for name, value in asdict(self).items():
            if value is None or name in excluded:
                continue

            if name == "process_contents" and value != "skip":
                continue

            if types and name.endswith("clusive"):
                value = converter.deserialize(value, sorted_types)

            result[name] = value

        return result

    @classmethod
    def from_element(cls, element: ElementBase) -> "Restrictions":
        """Build a new instance from the restrictions of a xsd model.

        Args:
            element: A element base instance.

        Returns:
            The new restrictions instance
        """
        options = element.get_restrictions()
        return cls(**options)


@dataclass(unsafe_hash=True)
class AttrType(CodegenModel):
    """Class field typing information.

    Args:
        qname: The namespace qualified name
        alias: The type alias
        reference: The type reference number
        native: Specifies if it's python native type
        forward: Specifies if it's a forward reference
        circular: Specifies if it's a circular reference
        substituted: Specifies if it has been processed for substitution groups
    """

    qname: str
    alias: str | None = field(default=None, compare=False)
    reference: int = field(default=0, compare=False)
    native: bool = field(default=False)
    forward: bool = field(default=False)
    circular: bool = field(default=False)
    substituted: bool = field(default=False, compare=False)

    @property
    def datatype(self) -> DataType | None:
        """Return the matching datatype for native types, None otherwise."""
        if not self.native:
            return None

        return DataType.from_qname(self.qname)

    @property
    def name(self) -> str:
        """Return the local name part of the qualified name."""
        qname = self.qname
        return namespaces.local_name(qname)

    def is_dependency(self, allow_circular: bool) -> bool:
        """Return whether this type is a dependency.

        The type must be a reference to a user type, not a forward
        reference, and not a circular one unless that's allowed.

        Args:
            allow_circular: Allow circular references as dependencies

        Returns:
            The bool result.
        """
        if self.forward or self.native:
            return False

        if self.circular and not allow_circular:
            return False

        return True


@dataclass
class Attr(CodegenModel):
    """Class field model representation.

    Args:
        tag: The xml tag that produced this attr
        name: The final attr name
        local_name: The original attr name
        wrapper: The wrapper element name
        index: The index position of this attr in the class
        default: The default value
        fixed: Specifies if the default value is fixed
        mixed: Specifies if the attr supports mixed content
        types: The attr types list
        choices: The attr choice list
        namespace: The attr namespace
        help: The attr help text
        restrictions: The attr restrictions instance
        parent: The parent class qualified name of the attr
        substitution: The substitution group this attr belongs to
    """

    tag: str
    name: str = field(compare=False)
    local_name: str = field(default="")
    wrapper: str | None = field(default=None)
    index: int = field(compare=False, default_factory=int)
    default: str | None = field(default=None, compare=False)
    fixed: bool = field(default=False, compare=False)
    mixed: bool = field(default=False, compare=False)
    types: list[AttrType] = field(default_factory=list, compare=False)
    choices: list["Attr"] = field(default_factory=list, compare=False)
    namespace: str | None = field(default=None)
    help: str | None = field(default=None, compare=False)
    restrictions: Restrictions = field(default_factory=Restrictions, compare=False)
    parent: str | None = field(default=None, compare=False)
    substitution: str | None = field(default=None, compare=False)

    def __post_init__(self):
        """Post init processing.

        Default the local name to the attr name. When the name has no
        alphanumeric characters at all, replace it with the unicode names
        of its characters joined with underscores.
        """
        if not self.local_name:
            self.local_name = self.name

        if text.alnum(self.name) == "":
            # unicodedata.name raises ValueError for characters without a
            # unicode name (e.g. tab, newline and other control characters),
            # fall back to their code point instead of crashing.
            self.name = "_".join(
                unicodedata.name(char, f"U{ord(char):04X}") for char in self.name
            )

    @property
    def key(self) -> str:
        """Generate a key for this attr.

        Concatenate the tag/namespace/local_name.
        This key is used to find duplicates, it's not
        supposed to be unique.

        Returns:
            The duplicate detection key for this attr.
        """
        return f"{self.tag}.{self.namespace}.{self.local_name}"

    @property
    def qname(self) -> str:
        """Return the fully qualified name of the attr."""
        return namespaces.build_qname(self.namespace, self.local_name)

    @property
    def is_attribute(self) -> bool:
        """Return whether this attr represents a xml attribute node."""
        return self.tag in (Tag.ATTRIBUTE, Tag.ANY_ATTRIBUTE)

    @property
    def is_element(self) -> bool:
        """Return whether this attr represents a xml element."""
        return self.tag == Tag.ELEMENT

    @property
    def is_enumeration(self) -> bool:
        """Return whether this attr is an enumeration member."""
        return self.tag == Tag.ENUMERATION

    @property
    def is_dict(self) -> bool:
        """Return whether this attr is derived from xs:anyAttribute."""
        return self.tag == Tag.ANY_ATTRIBUTE

    @property
    def is_factory(self) -> bool:
        """Return whether this attribute is a list of items or a mapping."""
        return self.is_list or self.is_dict or self.is_tokens

    @property
    def is_forward_ref(self) -> bool:
        """Return whether any attr types is a forward or circular reference."""
        return any(tp.circular or tp.forward for tp in self.types)

    @property
    def is_circular_ref(self) -> bool:
        """Return whether any attr types is a circular reference."""
        return any(tp.circular for tp in self.types)

    @property
    def is_group(self) -> bool:
        """Return whether this attr is a reference to a group class."""
        return self.tag in (Tag.ATTRIBUTE_GROUP, Tag.GROUP)

    @property
    def is_list(self) -> bool:
        """Return whether this attr requires a list of values."""
        return self.restrictions.is_list

    @property
    def is_prohibited(self) -> bool:
        """Return whether this attr is prohibited."""
        return self.restrictions.is_prohibited

    @property
    def is_nameless(self) -> bool:
        """Return whether this attr is neither a xml attribute nor element."""
        return self.tag not in (Tag.ATTRIBUTE, Tag.ELEMENT)

    @property
    def is_nillable(self) -> bool:
        """Return whether this attr supports nil values."""
        return self.restrictions.nillable is True

    @property
    def is_optional(self) -> bool:
        """Return whether this attr is not required."""
        return self.restrictions.is_optional

    @property
    def is_suffix(self) -> bool:
        """Return whether this attr is supposed to be generated last."""
        return self.index == sys.maxsize

    @property
    def is_xsi_type(self) -> bool:
        """Return whether this attr represents a xsi:type attribute."""
        return self.namespace == Namespace.XSI.uri and self.name == "type"

    @property
    def is_tokens(self) -> bool:
        """Return whether this attr supports token values."""
        return self.restrictions.tokens is True

    @property
    def is_wildcard(self) -> bool:
        """Return whether this attr supports any content."""
        return self.tag in (Tag.ANY_ATTRIBUTE, Tag.ANY)

    @property
    def is_any_type(self) -> bool:
        """Return whether any of the native types is the python object type."""
        return any(tp is object for tp in self.get_native_types())

    @property
    def native_types(self) -> list[type]:
        """Return a list of all the builtin data types, without duplicates."""
        return list(set(self.get_native_types()))

    @property
    def user_types(self) -> Iterator[AttrType]:
        """Yield an iterator of all the user defined types."""
        for tp in self.types:
            if not tp.native:
                yield tp

    @property
    def slug(self) -> str:
        """Return the slugified name of the attr."""
        return text.alnum(self.name)

    @property
    def xml_type(self) -> str | None:
        """Return the xml type this attribute is mapped to, if any."""
        return xml_type_map.get(self.tag)

    def get_native_types(self) -> Iterator[type]:
        """Yield an iterator of all the native attr types."""
        for tp in self.types:
            datatype = tp.datatype
            if datatype:
                yield datatype.type

    def can_be_restricted(self) -> bool:
        """Return whether this attr can be restricted.

        Attrs mapped to a xml attribute, or to no xml type at all,
        can't be restricted.
        """
        # xml_type returns XmlType values, compare against the same namespace.
        return self.xml_type not in (XmlType.ATTRIBUTE, None)


@dataclass(unsafe_hash=True)
class Extension(CodegenModel):
    """Base class model representation.

    Instances are hashable, the hash is derived from the tag and the type.

    Args:
        tag: The xml tag that produced this extension
        type: The extension type
        restrictions: The extension restrictions instance
    """

    tag: str
    type: AttrType
    # Excluded from the generated hash: Restrictions is a plain mutable
    # dataclass and therefore unhashable. It still takes part in equality.
    restrictions: Restrictions = field(hash=False)


class Status(IntEnum):
    """Class process status enumeration.

    Values are grouped per processing step in tens. The ``*ING`` member
    (``X0``) marks a step in progress and the ``*ED`` member (``X1``)
    marks it done. Being an IntEnum, members compare as integers.
    Presumably the steps run in ascending order; confirm against the
    process handlers.
    """

    RAW = 0
    UNGROUPING = 10
    UNGROUPED = 11
    FLATTENING = 20
    FLATTENED = 21
    SANITIZING = 30
    SANITIZED = 31
    RESOLVING = 40
    RESOLVED = 41
    CLEANING = 50
    CLEANED = 51
    FINALIZING = 60
    FINALIZED = 61


@dataclass
class Class(CodegenModel):
    """Class model representation.

    Args:
        qname: The namespace qualified name
        tag: The xml tag that produced this class
        location: The schema/document location uri
        mixed: Specifies whether this class supports mixed content
        abstract: Specifies whether this is an abstract class
        nillable: Specifies whether this class supports nil content
        local_type: Specifies if this class was an inner type at some point
        status: The processing status of the class
        container: The xml container of the class, schema, override, redefine
        package: The designated package of the class
        module: The designated module of the class
        namespace: The class namespace
        help: The help text
        meta_name: The xml element name of the class
        default: The default value
        fixed: Specifies whether the default value is fixed
        substitutions: The list of all the substitution groups this class belongs to
        extensions: The list of all the extension instances
        attrs: The list of all the attr instances
        inner: The list of all the inner class instances
        ns_map: The namespace prefix-URI map
        parent: The parent outer class
    """

    qname: str
    tag: str
    location: str = field(compare=False)
    mixed: bool = field(default=False)
    abstract: bool = field(default=False)
    nillable: bool = field(default=False)
    local_type: bool = field(default=False)
    status: Status = field(default=Status.RAW)
    container: str | None = field(default=None)
    package: str | None = field(default=None)
    module: str | None = field(default=None)
    namespace: str | None = field(default=None)
    help: str | None = field(default=None)
    meta_name: str | None = field(default=None)
    default: Any = field(default=None, compare=False)
    fixed: bool = field(default=False, compare=False)
    substitutions: list[str] = field(default_factory=list)
    extensions: list[Extension] = field(default_factory=list)
    attrs: list[Attr] = field(default_factory=list)
    inner: list["Class"] = field(default_factory=list)
    ns_map: dict = field(default_factory=dict)
    parent: Optional["Class"] = field(default=None, compare=False)

    @property
    def name(self) -> str:
        """Shortcut for the class local name."""
        return namespaces.local_name(self.qname)

    @property
    def slug(self) -> str:
        """Return a slugified version of the class name."""
        return text.alnum(self.name)

    @property
    def ref(self) -> int:
        """Return the id reference of this instance."""
        return id(self)

    @property
    def target_namespace(self) -> str | None:
        """Return the class target namespace."""
        return namespaces.target_uri(self.qname)

    @property
    def has_suffix_attr(self) -> bool:
        """Return whether it includes a suffix attr."""
        return any(attr.is_suffix for attr in self.attrs)

    @property
    def has_help_attr(self) -> bool:
        """Return whether at least one of attrs has non-blank help content."""
        return any(attr.help and attr.help.strip() for attr in self.attrs)

    @property
    def is_element(self) -> bool:
        """Return whether this class represents a xml element."""
        return self.tag == Tag.ELEMENT

    @property
    def is_enumeration(self) -> bool:
        """Return whether the class has attrs and all are enumeration members."""
        return len(self.attrs) > 0 and all(attr.is_enumeration for attr in self.attrs)

    @property
    def is_complex_type(self) -> bool:
        """Return whether this class represents a root/global class.

        Global classes are the only classes that get generated by default.
        """
        return self.tag in GLOBAL_TYPES

    @property
    def is_group(self) -> bool:
        """Return whether this class is derived from a xs:group/attributeGroup."""
        return self.tag in (Tag.ATTRIBUTE_GROUP, Tag.GROUP)

    @property
    def is_nillable(self) -> bool:
        """Return whether this class represents a nillable xml element."""
        return self.nillable or any(x.restrictions.nillable for x in self.extensions)

    @property
    def is_mixed(self) -> bool:
        """Return whether this class supports mixed content."""
        return self.mixed or any(x.mixed for x in self.attrs)

    @property
    def is_restricted(self) -> bool:
        """Return whether this class includes any restriction extensions."""
        return any(ext.tag == Tag.RESTRICTION for ext in self.extensions)

    @property
    def is_service(self) -> bool:
        """Return whether this instance is derived from a wsdl:operation."""
        return self.tag == Tag.BINDING_OPERATION

    @property
    def references(self) -> Iterator[int]:
        """Yield all the non-zero type reference numbers of the class."""
        for tp in self.types():
            if tp.reference:
                yield tp.reference

    @property
    def target_module(self) -> str:
        """Return the designated full module path.

        Raises:
            CodegenError: if the target was not designated a module.
        """
        if self.package and self.module:
            return f"{self.package}.{self.module}"

        if self.module:
            return self.module

        raise CodegenError(
            "Type has not been assigned to a module yet!", type=self.qname
        )

    def dependencies(self, allow_circular: bool = False) -> Iterator[str]:
        """Yields all class dependencies.

        Omit circular and forward references by default.

        Collect:
            * base classes
            * attribute types
            * attribute choice types
            * recursively go through the inner classes
            * Ignore inner class references
            * Ignore native types.

        Args:
            allow_circular: Allow circular references

        Yields:
            The qualified names of the unique dependency types.
        """
        for tp in set(self.types()):
            if tp.is_dependency(allow_circular=allow_circular):
                yield tp.qname

    def types(self) -> Iterator[AttrType]:
        """Yields all class types."""
        for _, tp in self.types_with_parents():
            yield tp

    def types_with_parents(self) -> Iterator[tuple[CodegenModel, AttrType]]:
        """Yields all class types with their parent codegen instance.

        Extensions come first, then attrs and their choices, and finally
        the inner classes, recursively.
        """
        for ext in self.extensions:
            yield ext, ext.type

        for attr in self.attrs:
            for tp in attr.types:
                yield attr, tp

            for choice in attr.choices:
                for tp in choice.types:
                    yield choice, tp

        for inner in self.inner:
            yield from inner.types_with_parents()

    def children(self) -> Iterator[CodegenModel]:
        """Yield all codegen children, recursively including inner classes."""
        for attr in self.attrs:
            yield attr
            yield attr.restrictions

            yield from attr.types

            for choice in attr.choices:
                yield choice
                yield choice.restrictions

                yield from choice.types

        for ext in self.extensions:
            yield ext
            yield ext.type
            yield ext.restrictions

        for inner in self.inner:
            yield from inner.children()

    def has_forward_ref(self) -> bool:
        """Return whether this class has any forward references.

        Check the attrs, their choices and, recursively, the inner classes.
        Extension types are not checked.
        """
        for attr in self.attrs:
            if attr.is_forward_ref or any(
                choice.is_forward_ref for choice in attr.choices
            ):
                return True

        return any(inner.has_forward_ref() for inner in self.inner)

    def parent_names(self) -> list[str]:
        """Return the outer class names, outermost first."""
        result = []
        target = self.parent
        while target is not None:
            result.append(target.name)
            target = target.parent

        return list(reversed(result))


@dataclass
class Import:
    """Python import statement model representation.

    Args:
        qname: The qualified name of the imported class
        source: The absolute module path
        alias: Specifies an alias to avoid naming conflicts
    """

    qname: str
    source: str
    alias: str | None = field(default=None)

    @property
    def name(self) -> str:
        """Return the name of the imported class."""
        return namespaces.local_name(self.qname)

    @property
    def slug(self) -> str:
        """Return a slugified version of the imported class name."""
        return text.alnum(self.name)


# Getters used all over the codegen process, e.g. as ``key=`` functions
# for sorting, grouping and filtering the codegen models.
get_location = operator.attrgetter("location")
get_name = operator.attrgetter("name")
get_qname = operator.attrgetter("qname")
get_tag = operator.attrgetter("tag")
# Dotted paths read the nested ``restrictions`` attribute of an attr.
get_restriction_choice = operator.attrgetter("restrictions.choice")
get_restriction_sequence = operator.attrgetter("restrictions.sequence")
get_slug = operator.attrgetter("slug")
get_target_namespace = operator.attrgetter("target_namespace")
# These read the boolean properties of the same name, they don't call anything.
is_enumeration = operator.attrgetter("is_enumeration")
is_group = operator.attrgetter("is_group")
