# Owner(s): ["module: dynamo"]

import unittest
from contextlib import contextmanager
from importlib import import_module

import torch
import torch._prims_common as utils
from torch._dynamo.utils import preserve_rng_state
from torch._inductor import config
from torch._inductor.compiler_bisector import CompilerBisector
from torch._inductor.test_case import TestCase
from torch.library import _scoped_library, Library
from torch.testing._internal.inductor_utils import HAS_CUDA


aten = torch.ops.aten

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


@requires_cuda
class TestCompilerBisector(TestCase):
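    """Tests for CompilerBisector.do_bisect: each test injects a known bug and
    checks that bisection identifies the responsible backend and subsystem."""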
    test_ns = "_test_bisector"

    def tearDown(self):
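        # Drop any ops registered under the test namespace so tests stay isolated.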
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, "lib"):
            del self.lib.m
            del self.lib

    def get_op(self, name):
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib

    def test_bad_decomp(self):
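        # Swap in an exponential_ decomposition that always produces NaN; the
        # bisector should land on the decomposition subsystem.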
        mod = import_module("torch._inductor.compile_fx")

        def bad_exp_decomp(self, rate=1, generator=None):
            assert generator is None
            torch._check(
                not utils.is_complex_dtype(self.dtype)
                and not utils.is_integer_dtype(self.dtype)
                and not utils.is_boolean_dtype(self.dtype),
                lambda: f"Exponential distribution is a continuous probability distribution. \
                dtype must be a floating point but you specified {self.dtype}",
            )
            torch._check(
                rate > 0.0,
                lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
            )
            return torch.rand_like(self) * float("nan")

        @contextmanager
        def patch_exp_decomp():
            from torch._inductor.compile_fx import select_decomp_table as old_decomp

            def get_decomp():
                out = old_decomp()
                out = out.copy()
                out[aten.exponential.default] = bad_exp_decomp
                return out

            torch._inductor.compile_fx.select_decomp_table = get_decomp
            try:
                yield

            finally:
                torch._inductor.compile_fx.select_decomp_table = old_decomp

        def vq(x):
            return (x + 3).exponential_() * 10.5

        def test_fn():
            torch._dynamo.reset()
            with patch_exp_decomp():
                vq_compiled = torch.compile(vq)
                x = torch.randn(4, 400, 256).cuda()
                with preserve_rng_state():
                    out = vq(x)
                out_compiled = vq_compiled(x)

            return not out_compiled.isnan().any()

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "aot_eager_decomp_partition")
        self.assertEqual(out.subsystem, "decomposition")
        self.assertEqual(out.bisect_number, 1)
        self.assertTrue("aten.exponential" in out.debug_info)

    def test_joint_graph(self):
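        # Install a joint-graph post pass that rewrites `x + 1` into `x + 2`,
        # so compiled output diverges and bisection should blame
        # joint_graph_passes.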

        def pass_fn(graph: torch.fx.Graph):
            nodes = graph.find_nodes(
                op="call_function", target=torch.ops.aten.add.Tensor
            )
            assert len(nodes) == 1
            args = list(nodes[0].args)
            args[1] = 2
            nodes[0].args = tuple(args)

        config.joint_custom_post_pass = pass_fn

        def foo(x):
            return x + 1

        def test_fn():
            torch._dynamo.reset()

            inp = torch.rand([10], device="cuda")

            out = foo(inp)
            out_c = torch.compile(foo)(inp)

            return torch.allclose(out, out_c)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "joint_graph_passes")
        self.assertEqual(out.bisect_number, 4)
        self.assertTrue("joint_custom_post_pass" in out.debug_info)

    def test_rng(self):
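        # Eager and Inductor RNG implementations differ unless Inductor falls
        # back to eager random, so bisection should point at
        # inductor_fallback_random.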
        def foo():
            return torch.rand([10], device="cuda") + 1

        def test_fn():
            torch._dynamo.reset()

            with preserve_rng_state():
                out = foo()
            with preserve_rng_state():
                out_c = torch.compile(foo)()

            return torch.allclose(out, out_c)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "inductor_fallback_random")
        self.assertTrue("inductor_fallback_random" in out.debug_info)

    def test_crossref(self):
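        # Custom op whose meta kernel returns a view while the CPU kernel
        # returns a fresh tensor; the crossref backend's extra checks should
        # flag the mismatch.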
        test_ns = "bisect_ops"
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")
            op = self.get_op("foo")

            class Foo(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                    with torch._C._AutoDispatchBelowAutograd():
                        with torch._C._ExcludeDispatchKeyGuard(
                            torch._C.DispatchKeySet(
                                torch._C.DispatchKey.ADInplaceOrView
                            )
                        ):
                            return op(x)

                @staticmethod
                def backward(ctx, gx):
                    return gx

            def foo_impl(x):
                return x.view_as(x).clone()

            def foo_meta(x):
                return x.view_as(x)

            lib.impl("foo", Foo.apply, "Autograd")
            lib.impl("foo", foo_impl, "CPU")
            lib.impl("foo", foo_meta, "Meta")

            x = torch.tensor(3.14159 / 3, requires_grad=True)

            def test_fn():
                torch._dynamo.reset()

                try:
                    torch.testing.assert_close(torch.compile(op)(x), op(x))
                except Exception:
                    return False
                return True

            out = CompilerBisector.do_bisect(test_fn)
            self.assertEqual(out.backend, "aot_eager_decomp_partition_crossref")

    def test_emulate_precision_casts(self):
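        # bfloat16 intermediates round differently in eager and Inductor unless
        # Inductor emulates eager's precision casts.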
        def test_fn():
            torch._dynamo.reset()

            def calculate_scale(inp):
                amax = torch.abs(torch.max(inp))
                scale = 448.0 / torch.clamp(amax, min=1e-12)
                scale = scale.to(torch.float32)
                return scale

            dtype = torch.bfloat16
            torch.manual_seed(0)
            inp = torch.randn(16, 16, 768, dtype=dtype, device="cuda")
            eager_scale = calculate_scale(inp)
            compile_scale = torch.compile(calculate_scale)(inp)

            return torch.equal(eager_scale, compile_scale)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "inductor_emulate_precision_casts")

    def test_bad_lowering(self):
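        # Inject an accuracy bug into the relu lowering; bisecting over
        # lowerings should isolate it.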
        def test_fn():
            torch._dynamo.reset()
            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):

                def my_func(x):
                    return ((x * -1) - 0.01).relu()

                inp = torch.rand([100], device="cuda")

                return torch.allclose(torch.compile(my_func)(inp), my_func(inp))

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "lowerings")
        self.assertEqual(out.bisect_number, 2)
        self.assertTrue("relu" in out.debug_info)

    def test_eager_backend(self):
        # A test_fn that always fails should bisect to the first backend
        # (eager) with no subsystem at fault.
        def test_fn():
            return False

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "eager")
        self.assertIsNone(out.subsystem)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
