import random
from copy import deepcopy
from itertools import chain
from collections.abc import Mapping, Sequence

import pytest
import torch
from common_utils import set_rng_seed
from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names


@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
    assert isinstance(model, BackboneWithFPN)
    y = model(x)
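    # One feature map per returned stage ("0"-"3") plus the extra "pool" level
    # appended by the default LastLevelMaxPool extra block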
    assert list(y.keys()) == ["0", "1", "2", "3", "pool"]

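    # Out-of-range arguments are rejected: resnets expose trainable layer
    # counts 0-5 and returnable stages 1-4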
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6)
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5])


@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
    with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1)
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2])
    with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
        mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6])
    model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
    assert isinstance(model_fpn, BackboneWithFPN)
    model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
    assert isinstance(model, torch.nn.Sequential)


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    return int(x)


# Needed by TestFxFeatureExtraction. Checks that node naming conventions
# are respected, in particular the index suffix appended to repeated node names
class TestSubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x


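# Expected node names for TestModule: ops inside the submodule are qualified
# by its path, and repeated ops get an incrementing index suffix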
test_module_nodes = [
    "x",
    "submodule.add",
    "submodule.add_1",
    "submodule.relu",
    "submodule.relu_1",
    "add",
    "add_1",
    "relu",
    "relu_1",
]


class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    model_defaults = {"num_classes": 1}
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """
        Create the feature extractor, forwarding the class-level leaf modules
        to the tracer unless the caller supplies its own tracer_kwargs.
        """
        if "tracer_kwargs" not in kwargs:
            tracer_kwargs = {"leaf_modules": self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop("tracer_kwargs")
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = [
            "getitem",
            "floordiv",
            "size",
            "chunk",
            "_assert",
            "eq",
            "dim",
            "getattr",
        ]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors, as they cause issues
        # when testing the backward pass
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
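        # Sample a subset of nodes to keep runtime manageable; exercising every
        # node of every model would be prohibitively slow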
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that some form of return nodes must be specified
        with pytest.raises(ValueError):
            self._create_feature_extractor(model)
        # Check mutual exclusivity of return_nodes and
        # train_return_nodes / eval_return_nodes
        with pytest.raises(ValueError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check that train_return_nodes and eval_return_nodes must both be specified
        with pytest.raises(ValueError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First double check that this node really doesn't exist: return-node
            # matching truncates at "." boundaries, so "l" only matches a node
            # named exactly "l" or one under an "l." prefix
            if not any(n == "l" or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
            else:  # otherwise skip this check
                raise ValueError

    def test_node_name_conventions(self):
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        # Compare lists directly; zip would silently truncate on a length mismatch
        assert train_nodes == test_module_nodes

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_forward_backward(self, model_name):
        model = models.get_model(model_name, **self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
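        # Reduce all outputs (tensors, or sequences/mappings of tensors) to a
        # single scalar so the backward pass can be exercised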
        out_agg = 0
        for node_out in out.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.float().mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.float().mean()
        out_agg.backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert list(ilg_out.keys()) == list(fgn_out.keys())
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
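        # As above, collapse all outputs to a single scalar for the backward pass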
        out_agg = 0
        for node_out in fgn_out.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.float().mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.float().mean()
        out_agg.backward()

    def test_train_eval(self):
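        # Dropout with p=1.0 zeroes its input in train mode and is the identity
        # in eval mode, so it makes the two traced graphs observably different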
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.float().mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

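        # FX naming conventions: module calls are named after the module path
        # ("dropout") while functional ops take the op name ("add", "mul", "sub")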
        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on self.training is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on self.training is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

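        # LeafModule is registered as a leaf so its internals are never traced,
        # and leaf_function is autowrapped so the int(Proxy) call inside it is
        # deferred to runtime instead of failing during tracing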
        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule is not in the list of nodes
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].float().mean().backward()

    def test_deepcopy(self):
        # Non-regression test for https://github.com/pytorch/vision/issues/8634
        model = models.efficientnet_b3(weights=None)
        extractor = create_feature_extractor(model=model, return_nodes={"classifier.0": "out"})

        extractor.eval()
        extractor.train()
        extractor = deepcopy(extractor)
        extractor.eval()
        extractor.train()
