from typing import Any, Dict, List, Optional
import catalogue
import contextlib
import copy
from io import BytesIO
import numpy

from ..backends import Ops, get_current_ops
from ..optimizers import Optimizer
from ..types import ArgsKwargs, ArrayXd
from ..util import get_array_module
from .shim import Shim
from ..compat import tensorflow as tf
from ..compat import cupy, h5py

# Registry of Keras model factory functions, keyed by a model's
# `catalogue_name`. `TensorFlowShim.to_bytes` checks that a factory is
# registered, and `TensorFlowShim.from_bytes` looks it up to rebuild models
# that could not be saved as h5. With entry_points=True, catalogue presumably
# also picks up factories advertised via package entry points (see catalogue docs).
keras_model_fns = catalogue.create("thinc", "keras", entry_points=True)


def maybe_handshake_model(keras_model):
    """Call the required predict/compile/build APIs to initialize a model if it
    is a subclass of tf.keras.Model. This is required to be able to call set_weights
    on subclassed layers.

    Models whose ``get_config()`` succeeds are returned untouched. Other
    (subclassed) models must expose the attributes the ``keras_subclass``
    decorator is expected to provide: ``catalogue_name``, ``eg_x``, ``eg_y``,
    ``eg_shape`` and ``eg_compile``. They are compiled with ``eg_compile``,
    built with ``eg_shape`` and run once on ``eg_x`` so their weights exist.

    keras_model: The Keras model to initialize.
    RETURNS: The same model object, initialized in place.
    RAISES ValueError: If a subclassed model is missing a required attribute.
    """
    try:
        keras_model.get_config()
        return keras_model
    except (AttributeError, NotImplementedError):
        # Subclassed models don't implement get_config
        pass

    # `eg_compile` is read below, so it must be validated here as well;
    # otherwise a missing attribute surfaces as an opaque AttributeError.
    for prop_name in ["catalogue_name", "eg_x", "eg_y", "eg_shape", "eg_compile"]:
        if not hasattr(keras_model, prop_name):
            raise ValueError(
                "Keras subclassed models are not whole-model serializable by "
                "TensorFlow. To work around this, you must decorate your keras "
                "model subclasses with the 'keras_subclass' decorator. The decorator "
                "requires a single X/Y input of fake-data that can be used to initialize "
                "your subclass model properly when loading the saved version."
            )

    ops: Ops = get_current_ops()
    if ops.device_type == "cpu":
        device = "CPU"
    else:  # pragma: no cover
        device = tf.test.gpu_device_name()

    compile_args = keras_model.eg_compile

    with tf.device(device):
        # Calling predict creates layers and weights for subclassed models
        keras_model.compile(**compile_args)
        keras_model.build(keras_model.eg_shape)
        keras_model.predict(keras_model.eg_x)
        # Made public in 2.2.x
        if hasattr(keras_model, "_make_train_function"):
            keras_model._make_train_function()
        else:
            keras_model.make_train_function()
    return keras_model


class TensorFlowShim(Shim):
    """Interface between a TensorFlow model and a Thinc Model. This container is
    *not* a Thinc Model subclass itself.

    Reference for custom training:
    https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough
    """

    # Gradients w.r.t. the trainable variables, accumulated by the backprop
    # callbacks returned from `begin_update`. They are stored as tf.Variables so
    # repeated backward passes can be summed in place; `finish_update` resets
    # this to None.
    gradients: Optional[List["tf.Variable"]]

    def __init__(self, model: Any, config=None, optimizer: Any = None):
        super().__init__(model, config, optimizer)
        self.gradients = None

    def __str__(self):
        """Return the Keras ``summary()`` of the wrapped model as one string."""
        lines: List[str] = []

        def accumulate(line: str):
            lines.append(line)

        self._model.summary(print_fn=accumulate)
        return "\n".join(lines)

    def __call__(self, X: ArgsKwargs, is_train: bool):
        """Dispatch to `begin_update` when training, otherwise to `predict`."""
        if is_train:
            return self.begin_update(X)
        else:
            return self.predict(X)

    def predict(self, X: ArgsKwargs):
        """Run the model with the Keras learning phase set to 0 (inference).

        The previous learning phase is restored even if the model call raises.
        """
        old_phase = tf.keras.backend.learning_phase()
        tf.keras.backend.set_learning_phase(0)
        try:
            return self._model(*X.args, **X.kwargs)
        finally:
            tf.keras.backend.set_learning_phase(old_phase)

    def begin_update(self, X: ArgsKwargs):
        """Run the model with learning phase 1 and return ``(output, backprop)``.

        Operations are recorded on a GradientTape that stays open until
        ``backprop(d_output)`` is called. ``backprop`` returns the gradients
        w.r.t. each positional input as an ArgsKwargs, and adds the gradients
        w.r.t. the model's trainable variables to ``self.gradients``.
        """
        tf.keras.backend.set_learning_phase(1)
        tape = tf.GradientTape()
        tape.__enter__()
        try:
            tape.watch(X.args)  # watch the input layers
            output = self._model(*X.args, **X.kwargs)
        except Exception:
            # Stop recording if the forward pass fails; otherwise the tape
            # would stay active for unrelated later operations.
            tape.__exit__(None, None, None)
            raise

        def backprop(d_output):
            # d_output is the derivative of the loss wrt the output (d_loss/d_output)
            tape.__exit__(None, None, None)
            # Differentiate w.r.t. every positional input (d_loss/d_input)
            # followed by the trainable variables, then split the result at
            # the number of inputs.
            n_inputs = len(X.args)
            wrt_tensors = list(X.args)
            wrt_tensors.extend(self._model.trainable_variables)
            all_gradients = tape.gradient(
                output, wrt_tensors, output_gradients=d_output
            )
            dX = all_gradients[:n_inputs]
            opt_grads = all_gradients[n_inputs:]
            if self.gradients is not None:
                # Accumulate into the variables created by an earlier pass.
                assert len(opt_grads) == len(self.gradients), "gradients must match"
                for accumulated, new_grad in zip(self.gradients, opt_grads):
                    accumulated.assign_add(new_grad)
            else:
                # Create variables from the grads to allow accumulation
                self.gradients = [tf.Variable(f) for f in opt_grads]
            return ArgsKwargs(args=tuple(dX), kwargs={})

        return output, backprop

    def finish_update(self, optimizer: Optimizer):
        """Apply the accumulated gradients to the trainable variables.

        All parameters and gradients are flattened into one array each and
        passed to ``optimizer`` under the key ``(self.id, "tensorflow-shim")``.
        The updated parameters are then reshaped and assigned back, and the
        accumulated gradients are cleared.

        RAISES ValueError: If no gradients have been accumulated.
        """
        if self.gradients is None:
            raise ValueError(
                "There are no gradients for optimization. Be sure to call begin_update"
                " before calling finish_update."
            )
        variables = self._model.trainable_variables
        assert len(self.gradients) == len(variables)
        if not variables:
            # Nothing to optimize (and nothing to concatenate); just reset.
            self.gradients = None
            return
        params = []
        grads = []
        shapes = []
        for grad, variable in zip(self.gradients, variables):
            param = variable.numpy()
            grad_array = grad.numpy()
            shapes.append((param.size, param.shape))
            params.append(param.ravel())
            grads.append(grad_array.ravel())
        xp = get_array_module(params[0])
        flat_params, _ = optimizer(
            (self.id, "tensorflow-shim"), xp.concatenate(params), xp.concatenate(grads)
        )
        start = 0
        for variable, (size, shape) in zip(variables, shapes):
            variable.assign(flat_params[start : start + size].reshape(shape))
            start += size
        self.gradients = None

    def _load_weights_from_state_dict(
        self, state_dict: Optional[Dict[str, ArrayXd]] = None
    ):
        """Set every layer's weights from ``state_dict`` (weight name -> array).

        When ``state_dict`` is None, the model's *current* weights are used,
        which makes the call a no-op; pass an explicit snapshot (e.g. from
        `_create_state_dict`) to restore weights onto a rebuilt model.
        """
        if state_dict is None:
            state_dict = self._create_state_dict()
        for layer in self._model.layers:
            layer.set_weights([state_dict[weight.name] for weight in layer.weights])

    # Create a state dict similar to PyTorch
    def _create_state_dict(self):
        """Snapshot all layer weights as ``{weight.name: numpy array}``."""
        return {
            weight.name: weight.numpy()
            for layer in self._model.layers
            for weight in layer.weights
        }

    @contextlib.contextmanager
    def use_params(self, params):
        """Temporarily load weights from ``params`` while the context is active.

        Only keys that start with ``f"tensorflow_{self.id}_"`` are used; the
        prefix is stripped to obtain the weight name. CuPy arrays are converted
        to NumPy. The previous weights are restored on exit, even if the body
        raises. If no key matches, the model is left untouched.
        """
        key_prefix = f"tensorflow_{self.id}_"
        # state dict stores key as name and value as numpy array
        state_dict = {}
        for k, v in params.items():
            if hasattr(k, "startswith") and k.startswith(key_prefix):
                if cupy is not None and isinstance(v, cupy.ndarray):  # pragma: no cover
                    v = cupy.asnumpy(v)
                assert isinstance(v, numpy.ndarray)
                # Strip only the leading prefix, not later occurrences.
                state_dict[k[len(key_prefix) :]] = v
        if not state_dict:
            yield
            return
        backup = self._create_state_dict()
        self._load_weights_from_state_dict(state_dict)
        try:
            yield
        finally:
            self._load_weights_from_state_dict(backup)

    def _clone_model(self):
        """Rebuild the model from its JSON config and restore its weights.

        Similar to tf.keras.models.clone_model(), but that changes the names of
        tf.Variables; this method preserves them. It is used by `to_device` to
        recreate the model under a device scope.
        """
        model_json_config = self._model.to_json()
        # Snapshot the trained weights before the old model is replaced;
        # reading them from the rebuilt model would only yield fresh
        # initializations.
        state_dict = self._create_state_dict()
        tf.keras.backend.clear_session()
        self._model = tf.keras.models.model_from_json(model_json_config)
        self._load_weights_from_state_dict(state_dict)

    def copy(self):
        """Return a deep copy of this shim with a rebuilt model and the same weights.

        The original shim keeps its own model.
        """
        model_json_config = self._model.to_json()
        # Snapshot the weights before rebuilding so the copy gets this model's
        # trained values rather than freshly initialized ones.
        state_dict = self._create_state_dict()
        original_model = self._model
        # Detach the Keras model so deepcopy doesn't try to copy it (presumably
        # Keras models can't be deep-copied), and always re-attach it afterwards.
        self._model = None
        try:
            tf.keras.backend.clear_session()
            copied = copy.deepcopy(self)
        finally:
            self._model = original_model
        copied._model = tf.keras.models.model_from_json(model_json_config)
        copied._load_weights_from_state_dict(state_dict)
        return copied

    def to_device(self, device_type: str, device_id: int):  # pragma: no cover
        """Recreate the model under a CPU or ``/GPU:<device_id>`` device scope.

        Any other ``device_type`` is ignored.
        """
        if device_type == "cpu":
            with tf.device("/CPU"):  # pragma: no cover
                self._clone_model()
        elif device_type == "gpu":
            with tf.device(f"/GPU:{device_id}"):
                self._clone_model()

    def to_bytes(self):
        """Serialize the model.

        RETURNS: The whole model saved as h5 bytes. If Keras raises
            NotImplementedError while saving (e.g. for subclassed models), a
            ``(catalogue_name, weights)`` tuple is returned instead.
        RAISES ValueError: If h5 saving fails and the model has no
            ``catalogue_name``, or the name isn't registered in `keras_model_fns`.
        """
        filelike = BytesIO()
        try:
            with h5py.File(filelike, "w") as f:
                self._model.save(f, save_format="h5")
            return filelike.getvalue()
        except NotImplementedError:
            if not hasattr(self._model, "catalogue_name"):
                raise ValueError(
                    "Couldn't serialize to h5, and model has no factory "
                    "function for component serialization."
                )
        # Check the factory function and throw ValueError if it doesn't exist
        keras_model_fns.get(self._model.catalogue_name)
        return self._model.catalogue_name, self._model.get_weights()

    def from_bytes(self, data):
        """Load a model produced by `to_bytes`.

        ``data`` is either h5 bytes, which replace the model entirely, or a
        ``(catalogue_name, weights)`` tuple. For a tuple, a model is first
        created from the registered factory (called with no arguments) if this
        shim has none, and then the weights are set.
        """
        ops: Ops = get_current_ops()
        if ops.device_type == "cpu":
            device = "CPU"
        else:  # pragma: no cover
            device = tf.test.gpu_device_name()

        # Plain bytes
        if isinstance(data, (str, bytes)):
            tf.keras.backend.clear_session()
            with h5py.File(BytesIO(data), "r") as f:
                with tf.device(device):
                    self._model = tf.keras.models.load_model(f)
            return
        # We only have to create the model if it doesn't already exist.
        catalogue_name, model_weights = data
        if self._model is None:
            model_fn = keras_model_fns.get(catalogue_name)
            tf.keras.backend.clear_session()
            with tf.device(device):
                # There is no existing model to read constructor args
                # (`eg_args`) from, so the factory is called without arguments.
                new_model = model_fn()
            self._model = maybe_handshake_model(new_model)
        self._model.set_weights(model_weights)
