from typing import Any, Callable, Optional

from ..compat import torch
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim, TorchScriptShim
from .pytorchwrapper import forward, convert_pytorch_default_inputs
from .pytorchwrapper import convert_pytorch_default_outputs


def TorchScriptWrapper_v1(
    torchscript_model: Optional["torch.jit.ScriptModule"] = None,
    convert_inputs: Optional[Callable] = None,
    convert_outputs: Optional[Callable] = None,
    mixed_precision: bool = False,
    grad_scaler: Optional[PyTorchGradScaler] = None,
    device: Optional["torch.device"] = None,
) -> Model[Any, Any]:
    """Wrap a TorchScript module so that it exposes the Thinc `Model` API.

    torchscript_model:
        The TorchScript module to wrap. `None` is also accepted, which
        constructs a shim that a serialized model can be deserialized into.
    convert_inputs:
        Callable that converts the inputs and gradients passed to the model
        into Torch tensors. Defaults to `convert_pytorch_default_inputs`.
    convert_outputs:
        Callable that converts the model outputs and gradients from Torch
        tensors to Thinc arrays. Defaults to `convert_pytorch_default_outputs`.
    mixed_precision:
        Enable mixed-precision. This changes whitelisted ops to run
        in half precision for better performance and lower memory use.
    grad_scaler:
        The gradient scaler to use for mixed-precision training. When this is
        `None` and mixed precision is enabled, a gradient scaler with the
        default configuration is used.
    device:
        The PyTorch device to run the model on. When this is `None`, the
        default device for the currently active Thinc ops is used.

    Returns a `Model` named "pytorch_script" that uses the shared PyTorch
    wrapper `forward`, holds a single `TorchScriptShim`, and starts with its
    `nI`/`nO` dimensions unset.
    """
    # Fall back to the stock PyTorch converters for any that were not given.
    input_converter = (
        convert_pytorch_default_inputs if convert_inputs is None else convert_inputs
    )
    output_converter = (
        convert_pytorch_default_outputs if convert_outputs is None else convert_outputs
    )

    script_shim = TorchScriptShim(
        model=torchscript_model,
        mixed_precision=mixed_precision,
        grad_scaler=grad_scaler,
        device=device,
    )
    wrapper_attrs = {
        "convert_inputs": input_converter,
        "convert_outputs": output_converter,
    }
    # Dimensions are left unset here; they are not inferred by this factory.
    unset_dims = {"nI": None, "nO": None}

    return Model(
        "pytorch_script",
        forward,
        attrs=wrapper_attrs,
        shims=[script_shim],
        dims=unset_dims,
    )


def pytorch_to_torchscript_wrapper(model: Model) -> Model[Any, Any]:
    """Convert a PyTorch wrapper to a TorchScript wrapper. The embedded PyTorch
    `Module` is converted to `ScriptModule` using `torch.jit.script`.

    The "convert_inputs" and "convert_outputs" callables stored in
    `model.attrs` are reused. The shim's mixed-precision flag, gradient scaler
    and device also carry over to the new wrapper.

    model:
        A PyTorch wrapper model. Its first shim must be a `PyTorchShim` that
        wraps a `torch.nn.Module`. Its `attrs` must contain "convert_inputs"
        and "convert_outputs".

    Returns the model produced by `TorchScriptWrapper_v1` for the scripted
    module.

    Raises `ValueError` in three cases: the model has no shims, its first
    shim is not a `PyTorchShim`, or that shim does not wrap a
    `torch.nn.Module`.
    """
    # An empty shim list would otherwise surface as a bare IndexError; route
    # it through the same descriptive ValueError as a wrong shim type.
    shim = model.shims[0] if model.shims else None
    if not isinstance(shim, PyTorchShim):
        raise ValueError("Expected PyTorchShim when converting a PyTorch wrapper")

    convert_inputs = model.attrs["convert_inputs"]
    convert_outputs = model.attrs["convert_outputs"]

    pytorch_model = shim._model
    if not isinstance(pytorch_model, torch.nn.Module):
        raise ValueError("PyTorchShim does not wrap a PyTorch module")

    torchscript_model = torch.jit.script(pytorch_model)
    # Carry the shim's training/runtime settings over so the TorchScript
    # wrapper is configured the same way as the original PyTorch wrapper.
    grad_scaler = shim._grad_scaler
    mixed_precision = shim._mixed_precision
    device = shim.device

    return TorchScriptWrapper_v1(
        torchscript_model,
        convert_inputs=convert_inputs,
        convert_outputs=convert_outputs,
        mixed_precision=mixed_precision,
        grad_scaler=grad_scaler,
        device=device,
    )
