from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar

import srsly

from ..model import Model
from ..shims import TensorFlowShim, keras_model_fns, maybe_handshake_model
from ..util import xp2tensorflow, tensorflow2xp, assert_tensorflow_installed
from ..util import is_tensorflow_array, convert_recursive, is_xp_array
from ..types import ArrayXd, ArgsKwargs
from ..compat import tensorflow as tf

# Generic input/output types used in the `Model[InT, OutT]` annotations below.
InT = TypeVar("InT")
OutT = TypeVar("OutT")
# Type of the object passed through the `keras_subclass` decorator; the
# decorator's signature (`Callable[[InFunc], InFunc]`) returns the same type.
InFunc = TypeVar("InFunc")
# Types of the sample X / Y arrays passed to `keras_subclass`.
XType = TypeVar("XType", bound=ArrayXd)
YType = TypeVar("YType", bound=ArrayXd)


def keras_subclass(
    name: str,
    X: XType,
    Y: YType,
    input_shape: Tuple[int, ...],
    compile_args: Optional[Dict[str, Any]] = None,
) -> Callable[[InFunc], InFunc]:
    """Decorate a custom keras subclassed model with enough information to
    serialize and deserialize it reliably in the face of the many restrictions
    on keras subclassed models.

    name (str): The unique namespace string to use to represent this model class.
    X (Any): A sample X input for performing a forward pass on the network.
    Y (Any): A sample Y input for performing a backward pass on the network.
    input_shape (Tuple[int, ...]): A set of input shapes for building the network.
    compile_args (Optional[Dict[str, Any]]): Arguments to pass directly to the
        keras `model.compile` call. They are merged over the defaults
        {"optimizer": "adam", "loss": "mse"}; keys given here take precedence.

    RETURNS (Callable): A class decorator. It attaches the information above to
        the class and returns the same class.
    RAISES (ValueError): At instantiation time of the decorated class, if the
        constructor arguments cannot be serialized to JSON.
    """

    compile_defaults = {"optimizer": "adam", "loss": "mse"}
    # `compile_defaults` is a new dict on every call, and the merge builds a new
    # dict, so the caller's `compile_args` is never mutated.
    if compile_args is None:
        compile_args = compile_defaults
    else:
        compile_args = {**compile_defaults, **compile_args}

    def call_fn(clazz):
        # Expose the decorator arguments as read-only properties on instances.
        clazz.catalogue_name = property(lambda inst: name)
        clazz.eg_shape = property(lambda inst: input_shape)
        clazz.eg_compile = property(lambda inst: compile_args)
        clazz.eg_x = property(lambda inst: X)
        clazz.eg_y = property(lambda inst: Y)

        # The keras_model_fns(name) decorator is applied for its side effect;
        # presumably it registers this factory under `name` (see ..shims).
        @keras_model_fns(name)
        def create_component(*call_args, **call_kwargs):
            return clazz(*call_args, **call_kwargs)

        # Capture construction args and store them on the instance
        wrapped_init = clazz.__init__

        def __init__(self, *args, **kwargs):
            wrapped_init(self, *args, **kwargs)
            # The args are stored so the model can be re-created from its
            # configuration, which requires them to be JSON-serializable.
            try:
                srsly.json_dumps(args)
                srsly.json_dumps(kwargs)
            except Exception as _err:
                # Catch Exception (not BaseException) so KeyboardInterrupt /
                # SystemExit are not disguised as a serialization error.
                raise ValueError(
                    "In order to serialize Keras Subclass models, the constructor "
                    "arguments must be serializable. This allows thinc to recreate "
                    "the code-based model with the same configuration.\n"
                    f"The encountered error is: {_err}"
                ) from _err
            self.eg_args = ArgsKwargs(args, kwargs)

        clazz.__init__ = __init__

        return clazz

    return call_fn


def TensorFlowWrapper(
    tensorflow_model: Any,
    convert_inputs: Optional[Callable] = None,
    convert_outputs: Optional[Callable] = None,
    optimizer: Optional[Any] = None,
    model_class: Type[Model] = Model,
    model_name: str = "tensorflow",
) -> Model[InT, OutT]:
    """Wrap a TensorFlow model, so that it has the same API as Thinc models.
    To optimize the model, you'll need to create a TensorFlow optimizer and call
    optimizer.apply_gradients after each batch.

    tensorflow_model: A `tf.keras.models.Model` instance to wrap. It is passed
        through `maybe_handshake_model` before being placed in a TensorFlowShim.
    convert_inputs: Callable `(model, X, is_train) -> (X_tf, get_dX)`. When
        omitted, the module's `_convert_inputs` is used.
    convert_outputs: Callable `(model, Y_tf, is_train) -> (Y, get_dY_tf)`.
        When omitted, the module's `_convert_outputs` is used.
    optimizer: Forwarded to the TensorFlowShim as its `optimizer`.
    model_class: The Model subclass to instantiate.
    model_name: Name given to the created model.

    RAISES (ValueError): If `tensorflow_model` is not a tf.keras.models.Model.
    """
    assert_tensorflow_installed()
    if not isinstance(tensorflow_model, tf.keras.models.Model):
        raise ValueError(
            f"Expected tf.keras.models.Model, got: {type(tensorflow_model)}"
        )
    shim = TensorFlowShim(maybe_handshake_model(tensorflow_model), optimizer=optimizer)
    attrs = {
        "convert_inputs": (
            _convert_inputs if convert_inputs is None else convert_inputs
        ),
        "convert_outputs": (
            _convert_outputs if convert_outputs is None else convert_outputs
        ),
    }
    return model_class(model_name, forward, shims=[shim], attrs=attrs)


def forward(model: Model[InT, OutT], X: InT, is_train: bool) -> Tuple[OutT, Callable]:
    """Return the output of the wrapped TensorFlow model for the given input,
    along with a callback to handle the backward pass.

    The input is converted with `model.attrs["convert_inputs"]`, run through
    the model's first shim, and converted back with
    `model.attrs["convert_outputs"]`. In training mode the shim returns
    `(output, backprop)`; otherwise it returns only the output. Calling the
    returned callback after a non-training pass therefore raises RuntimeError.
    """
    convert_inputs = model.attrs["convert_inputs"]
    convert_outputs = model.attrs["convert_outputs"]
    shim = model.shims[0]
    X_tensorflow, get_dX = convert_inputs(model, X, is_train)
    if is_train:
        Y_tensorflow, tensorflow_backprop = shim(X_tensorflow, is_train)
    else:
        Y_tensorflow = shim(X_tensorflow, is_train)
        # The shim provides no gradient callback outside training; keep the
        # name bound so `backprop` can report the misuse clearly instead of
        # failing with a NameError on an unbound closure variable.
        tensorflow_backprop = None
    Y, get_dY_tensorflow = convert_outputs(model, Y_tensorflow, is_train)

    def backprop(dY: OutT) -> InT:
        if tensorflow_backprop is None:
            raise RuntimeError(
                "Cannot backpropagate through a TensorFlowWrapper forward pass "
                "that was run with is_train=False."
            )
        dY_tensorflow = get_dY_tensorflow(dY)
        dX_tensorflow = tensorflow_backprop(dY_tensorflow)
        return get_dX(dX_tensorflow)

    return Y, backprop


# Default conversion functions
# These are nearly identical to the PyTorch wrapper's defaults, but the
# duplication is deliberate: a shared abstraction could get messy and might
# later need to be undone, since the two frameworks can always differ in details.


def _convert_inputs(model, X, is_train):
    """Default input converter: turn xp arrays in `X` into TensorFlow tensors.

    Arrays anywhere inside `X` are converted recursively (with
    `requires_grad=is_train`), and the result is normalized to an ArgsKwargs:
    - ArgsKwargs input is used as-is;
    - a dict becomes the keyword arguments;
    - a tuple or list becomes the positional arguments (always a tuple);
    - anything else becomes the single positional argument.

    RETURNS: `(ArgsKwargs, reverse_conversion)`. `reverse_conversion` converts
        the gradient returned by the shim back to xp arrays. It then reshapes
        that gradient to match the original input form: the whole ArgsKwargs,
        its `.kwargs`, its `.args`, or its `.args[0]`.
    """

    def xp2tensorflow_(x):
        return xp2tensorflow(x, requires_grad=is_train)

    def to_xp(dXtf):
        # Shared tf -> xp conversion used by every reverse_conversion below.
        return convert_recursive(is_tensorflow_array, tensorflow2xp, dXtf)

    converted = convert_recursive(is_xp_array, xp2tensorflow_, X)
    if isinstance(converted, ArgsKwargs):
        return converted, to_xp
    elif isinstance(converted, dict):

        def reverse_conversion(dXtf):
            return to_xp(dXtf).kwargs

        return ArgsKwargs(args=tuple(), kwargs=converted), reverse_conversion
    elif isinstance(converted, (tuple, list)):

        def reverse_conversion(dXtf):
            return to_xp(dXtf).args

        # ArgsKwargs.args is a tuple in every other branch; normalize lists too.
        return ArgsKwargs(args=tuple(converted), kwargs={}), reverse_conversion
    else:

        def reverse_conversion(dXtf):
            return to_xp(dXtf).args[0]

        return ArgsKwargs(args=(converted,), kwargs={}), reverse_conversion


def _convert_outputs(model, Ytf, is_train):
    """Default output converter.

    Recursively convert TensorFlow tensors in `Ytf` to xp arrays. Return them
    together with a callback that converts xp gradients `dY` back into
    TensorFlow tensors.
    """

    def xp_grads_to_tensorflow(dY):
        return convert_recursive(is_xp_array, xp2tensorflow, dY)

    return (
        convert_recursive(is_tensorflow_array, tensorflow2xp, Ytf),
        xp_grads_to_tensorflow,
    )
