# This code should be shared between cwrap and ATen preprocessing.
# For now it lives in one place, but it is currently a copy of the code in cwrap.

import copy
from typing import Any, Dict, Iterable, List, Union

# An argument description. parse_arguments normalizes string-form arguments
# (and dicts carrying an "arg" entry) into dicts with "type" and "name" keys.
Arg = Dict[str, Any]


def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
    """Normalize a list of argument declarations into argument dicts.

    Each element of *args* may be:

    * a string ``"<type> <name>"``, which becomes a new
      ``{"type": ..., "name": ...}`` dict;
    * a dict with an ``"arg"`` entry of the same ``"<type> <name>"`` form,
      which is rewritten **in place**: ``"type"``/``"name"`` are set from it
      and ``"arg"`` is deleted;
    * any other dict, which is passed through unchanged.

    Strings are split at their first space, so everything after it becomes
    the name (a string with no space yields an empty name).

    Returns:
        A new list; dict inputs appear in it as the same (possibly mutated)
        objects.

    Raises:
        AssertionError: if an element is neither a ``str`` nor a ``dict``.
    """
    new_args = []
    for arg in args:
        if isinstance(arg, str):
            # Simple arg declaration of form "<type> <name>"
            t, _, name = arg.partition(" ")
            new_args.append({"type": t, "name": name})
        elif isinstance(arg, dict):
            if "arg" in arg:
                arg["type"], _, arg["name"] = arg["arg"].partition(" ")
                del arg["arg"]
            new_args.append(arg)
        else:
            # Name the offending value so a bad declaration can be located.
            raise AssertionError(
                "expected each argument to be a str or dict, got {!r} ({})".format(
                    arg, type(arg).__name__
                )
            )
    return new_args


# A declaration dict describing one function (keys such as "name",
# "schema_string", "arguments", "options"); set_declaration_defaults fills in
# its defaults and expands it into per-overload "options".
Declaration = Dict[str, Any]


def set_declaration_defaults(declaration: Declaration) -> None:
    """Normalize *declaration* in place and expand it into per-overload options.

    Steps, in order:

    1. Fill in defaults for ``schema_string`` (``""``), ``arguments``
       (``[]``), ``return`` (``"void"``), ``cname`` (the ``name``) and
       ``backends`` (``["CPU", "CUDA"]``).
    2. Derive ``api_name``, ``type_wrapper_name`` (``name`` or
       ``name_overloadname``), ``operator_name_with_overload`` (the schema
       text before ``(``) and the ``unqual_*`` variants with the leading
       ``namespace::`` removed.
    3. If there is no ``options`` key, move ``arguments`` and
       ``schema_order_arguments`` into a single option.
    4. Run every option's argument lists through ``parse_arguments``.
    5. Propagate every declaration-level key (except ``options``) into each
       option that does not already define it; values are shared, not copied.

    Raises:
        KeyError: if ``name`` is missing, or ``schema_order_arguments`` is
            missing from the declaration (when it has no options) or from an
            option.
        IndexError: if a non-empty ``schema_string`` has no ``namespace::``
            qualifier before its ``(``.
        AssertionError: if ``api_name`` is already present.
    """
    # A missing schema string happens for legacy TH bindings like
    # _thnn_conv_depthwise2d_backward.
    declaration.setdefault("schema_string", "")
    declaration.setdefault("arguments", [])
    declaration.setdefault("return", "void")
    if "cname" not in declaration:
        declaration["cname"] = declaration["name"]
    declaration.setdefault("backends", ["CPU", "CUDA"])
    assert "api_name" not in declaration
    declaration["api_name"] = declaration["name"]
    # NB: keep this in sync with gen_autograd.py
    if declaration.get("overload_name"):
        declaration["type_wrapper_name"] = "{}_{}".format(
            declaration["name"], declaration["overload_name"]
        )
    else:
        declaration["type_wrapper_name"] = declaration["name"]

    # TODO: Uggggh, parsing the schema string here, really???
    schema_string = declaration["schema_string"]
    operator_name_with_overload = schema_string.split("(")[0]
    declaration["operator_name_with_overload"] = operator_name_with_overload
    if schema_string:
        # Strip only the leading "namespace::" qualifier. maxsplit=1 keeps any
        # later "::" in the schema (e.g. a default value written as
        # "at::kFloat") intact; a plain split("::")[1] would truncate there.
        declaration["unqual_schema_string"] = schema_string.split("::", 1)[1]
        declaration["unqual_operator_name_with_overload"] = (
            operator_name_with_overload.split("::", 1)[1]
        )
    else:
        declaration["unqual_schema_string"] = ""
        declaration["unqual_operator_name_with_overload"] = ""

    # Simulate multiple dispatch, even if it's not necessary: a declaration
    # without explicit options becomes a single option holding deep copies of
    # its argument lists.
    if "options" not in declaration:
        declaration["options"] = [
            {
                "arguments": copy.deepcopy(declaration["arguments"]),
                "schema_order_arguments": copy.deepcopy(
                    declaration["schema_order_arguments"]
                ),
            }
        ]
        del declaration["arguments"]
        del declaration["schema_order_arguments"]

    # Parse arguments (some of them can be strings)
    for option in declaration["options"]:
        option["arguments"] = parse_arguments(option["arguments"])
        option["schema_order_arguments"] = parse_arguments(
            option["schema_order_arguments"]
        )

    # Propagate defaults from declaration to options
    for option in declaration["options"]:
        for k, v in declaration.items():
            # TODO(zach): why does cwrap not propagate 'name'? I need it
            # propagated for ATen
            if k != "options":
                option.setdefault(k, v)


# TODO(zach): added an option to remove keyword handling for C++, which cannot
# support it.

# One overload ("option") of a declaration, i.e. an entry of
# declaration["options"]; expected to carry an "arguments" list.
Option = Dict[str, Any]


def filter_unique_options(
    options: Iterable[Option],
    allow_kwarg: bool,
    type_to_signature: Dict[str, str],
    remove_self: bool,
) -> List[Option]:
    """Keep only options whose call signature is distinguishable from earlier ones.

    Options are visited in order. An option's signature is the ``#``-joined
    list of its positional argument types, each mapped through
    *type_to_signature*. ``CONSTANT`` arguments are skipped, and so is the
    argument named ``self`` when *remove_self* is true.

    If that signature is already taken and *allow_kwarg* is true, 1, 2, ...
    trailing arguments are treated as keyword-only. Their ``name#type`` pairs
    (raw types, ``CONSTANT`` skipped) are appended after a ``#-#`` marker.
    This continues until the signature becomes unique, and the argument dicts
    that had to become keyword-only get ``kwarg_only = True`` set in place.

    Options that never reach a unique signature are dropped.

    Returns:
        The kept options, in their original order.
    """

    def make_signature(option: Option, num_kwarg_only: int) -> str:
        arguments = option["arguments"]
        split_at = len(arguments) - num_kwarg_only
        positional_part = "#".join(
            type_to_signature.get(arg["type"], arg["type"])
            for arg in arguments[:split_at]
            if arg["type"] != "CONSTANT"
            and not (remove_self and arg["name"] == "self")
        )
        if num_kwarg_only == 0:
            return positional_part
        keyword_part = "#".join(
            arg["name"] + "#" + arg["type"]
            for arg in arguments[split_at:]
            if arg["type"] != "CONSTANT"
        )
        return positional_part + "#-#" + keyword_part

    seen_signatures = set()
    kept = []
    for option in options:
        # Only the all-positional signature (0 keyword-only args) is tried
        # unless allow_kwarg is set.
        max_kwarg_only = len(option["arguments"]) if allow_kwarg else 0
        for num_kwarg_only in range(max_kwarg_only + 1):
            candidate = make_signature(option, num_kwarg_only)
            if candidate in seen_signatures:
                continue
            seen_signatures.add(candidate)
            if num_kwarg_only > 0:
                for arg in option["arguments"][-num_kwarg_only:]:
                    arg["kwarg_only"] = True
            kept.append(option)
            break
    return kept


def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
    """Sort ``declaration["options"]`` in place by the length of each option's arguments.

    With the default ``reverse=True`` the options with the most arguments come
    first. ``list.sort`` is stable (also with ``reverse=True``), so options
    with equal argument counts keep their relative order.
    """
    options = declaration["options"]
    options.sort(key=lambda option: len(option["arguments"]), reverse=reverse)


class Function:
    """A function parsed from a header: its name plus an ordered argument list."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.arguments: List["Argument"] = []

    def add_argument(self, arg: "Argument") -> None:
        """Append *arg*, which must be an ``Argument`` instance."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self) -> str:
        # Prototype-like rendering without a return type: "name(arg, arg, ...)".
        params = ", ".join(map(repr, self.arguments))
        return self.name + "(" + params + ")"


class Argument:
    """One parameter of a parsed function: its C type, its name, and whether it is optional."""

    def __init__(self, _type: str, name: str, is_optional: bool) -> None:
        self.type, self.name, self.is_optional = _type, name, is_optional

    def __repr__(self) -> str:
        # "<type> <name>", e.g. "THTensor* input".
        return " ".join((self.type, self.name))


def parse_header(path: str) -> List[Function]:
    """Parse the THNN function declarations in the C header at *path*.

    The parser is line based and assumes the THNN header layout:

    * A line such as ``TH_API void THNN_(name)(`` or ``TH_API void THNN_name(``
      starts a new function. The ``TORCH_CUDA_CPP_API`` and
      ``TORCH_CUDA_CU_API`` markers are also accepted in place of ``TH_API``.
    * Every other non-empty line holds comma-separated ``<type> <name>``
      parameters of the most recently started function.
    * Leading ``*``s on a parameter name are moved onto its type.
    * A parameter is optional when the ``//`` comment on its line contains
      ``[OPTIONAL]``.
    * Lines beginning with ``#`` (preprocessor directives) are ignored.

    Returns:
        The parsed functions, in the order they appear in the header.

    Raises:
        ValueError: if a parameter fragment is not exactly ``<type> <name>``.
        IndexError: if a parameter line appears before any function line.
    """
    # Declaration prefixes that start a new THNN function.
    function_prefixes = (
        "TH_API void THNN_",
        "TORCH_CUDA_CPP_API void THNN_",
        "TORCH_CUDA_CU_API void THNN_",
    )

    with open(path, "r") as f:
        raw_lines = f.read().split("\n")

    # Flatten the header into (declaration fragment, line comment) pairs: one
    # pair per comma-separated piece of every non-directive line.
    fragments = []
    for raw in raw_lines:
        # Skip empty lines and preprocessor directives.
        if not raw or raw.startswith("#"):
            continue
        code, _, comment = raw.partition("//")
        # Drop the "," / ");" that terminates a parameter line.
        code = code.strip().rstrip(");").rstrip(",")
        comment = comment.strip()
        for piece in code.split(","):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    functions = []
    for decl, comment in fragments:
        prefix = next((p for p in function_prefixes if decl.startswith(p)), None)
        if prefix is not None:
            fn_name = decl[len(prefix) :]
            if fn_name[0] == "(" and fn_name[-2] == ")":
                # "THNN_(name)(" form: drop the parentheses around the name
                # and the "(" that opens the parameter list.
                fn_name = fn_name[1:-2]
            else:
                # "THNN_name(" form: drop the trailing "(".
                fn_name = fn_name[:-1]
            functions.append(Function(fn_name))
            continue

        words = decl.split()
        if len(words) != 2:
            raise ValueError(
                "{}: cannot parse parameter {!r}; expected '<type> <name>'".format(
                    path, decl
                )
            )
        param_type, name = words
        # "T *x" -> ("T*", "x"); "T **x" -> ("T**", "x"). Every leading "*"
        # belongs to the type, not just the first one.
        bare_name = name.lstrip("*")
        param_type += "*" * (len(name) - len(bare_name))
        functions[-1].add_argument(
            Argument(param_type, bare_name, "[OPTIONAL]" in comment)
        )
    return functions
