from typing import Optional, Callable, Any, Tuple, TypeVar

from ..model import Model

# Type variable bound to Model: lets `with_debug` annotate that it returns the
# same (sub)type of Model that it was given, rather than plain `Model`.
_ModelT = TypeVar("_ModelT", bound=Model)

def do_nothing(*args: Any, **kwargs: Any) -> None:
    """Accept any positional/keyword arguments and return ``None``.

    Used as the default no-op hook for the ``with_debug`` callbacks.
    """
    return None


def with_debug(
    layer: _ModelT,
    name: Optional[str] = None,
    *,
    on_init: Callable[[Model, Any, Any], None] = do_nothing,
    on_forward: Callable[[Model, Any, bool], None] = do_nothing,
    on_backprop: Callable[[Any], None] = do_nothing,
) -> _ModelT:
    """Debugging layer that wraps any layer and allows executing callbacks
    during the forward pass, backward pass and initialization. The callbacks
    will receive the same arguments as the functions they're called in.

    The wrapping is done in place: ``layer``'s forward and init callbacks are
    replaced via ``replace_callbacks`` and the same object is returned.

    layer: The model to instrument.
    name: Currently ignored; the layer keeps its own name.
    on_init: Called as ``on_init(model, X, Y)`` before the original init
        (if the layer had one).
    on_forward: Called as ``on_forward(model, X, is_train)`` before the
        original forward function.
    on_backprop: Called as ``on_backprop(dY)`` before the original backprop
        callback returned by the forward pass.
    RETURNS: The same ``layer`` instance, with its callbacks replaced.
    """
    # Capture the original callbacks before they are overwritten below.
    orig_forward = layer._func
    orig_init = layer.init

    def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
        on_forward(model, X, is_train)
        # Dispatch on the model actually being run, not the captured `layer`:
        # if another Model ends up sharing these callbacks (presumably e.g. a
        # copy), the original forward then operates on that model's own state.
        layer_Y, layer_callback = orig_forward(model, X, is_train=is_train)

        def backprop(dY: Any) -> Any:
            on_backprop(dY)
            return layer_callback(dY)

        return layer_Y, backprop

    def init(model: Model, X: Any, Y: Any) -> None:
        on_init(model, X, Y)
        if orig_init is not None:
            # Same reasoning as in `forward`: initialize the model being
            # initialized, not necessarily the object captured at wrap time.
            orig_init(model, X, Y)

    layer.replace_callbacks(forward, init=init)
    return layer
