# Owner(s): ["oncall: jit"]

import os
import sys
import io

import torch
import warnings
from contextlib import redirect_stderr
from torch.testing import FileCheck

# Make the helper files in test/ importable. The test/ root is taken to be the
# parent of the directory holding this file (after resolving symlinks), and it
# is appended to the module search path.
pytorch_test_dir = os.path.abspath(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == "__main__":
    # This module only defines test cases; refuse direct execution and point
    # the user at the aggregate test/test_jit.py runner instead.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:"
        "\n\n\tpython test/test_jit.py TESTNAME\n\ninstead."
    )


class TestWarn(JitTestCase):
    """Check how TorchScript surfaces ``warnings.warn`` calls to Python.

    Each test scripts a small function that warns, runs it, and counts the
    ``UserWarning`` messages that reach Python's warnings machinery.
    """

    def _run_capturing_warnings(self, run):
        """Call ``run()`` and return the warnings it raised as text.

        The result has one ``"<Category>: <message>"`` line per warning, in
        emission order, so it can be matched with ``FileCheck``.

        Warnings are recorded with ``warnings.catch_warnings(record=True)``
        instead of being read back from stderr. A test runner that installs
        its own warning hook (pytest's warnings plugin does) would otherwise
        keep them off stderr. A process-wide ``-W``/``PYTHONWARNINGS``
        setting could also hide them or turn them into exceptions.
        """
        with warnings.catch_warnings(record=True) as caught:
            # "always" disables Python's own once-per-location suppression, so
            # any deduplication observed here comes from TorchScript itself.
            warnings.simplefilter("always")
            run()
        return "\n".join(
            f"{record.category.__name__}: {record.message}"
            for record in caught)

    def _check_warning_counts(self, run, *expected):
        """Run ``run()`` and assert how often each warning message appears.

        ``expected`` holds ``(message, count)`` pairs. Each message must be
        emitted as a ``UserWarning`` exactly ``count`` times. The pairs are
        chained into one FileCheck in the order given, so a later message is
        looked for after the matches of the earlier ones.
        """
        checker = FileCheck()
        for message, count in expected:
            checker = checker.check_count(
                str="UserWarning: " + message,
                count=count,
                exactly=True)
        checker.run(self._run_capturing_warnings(run))

    def test_warn(self):
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        self._check_warning_counts(fn, ("I am warning you", 1))

    def test_warn_only_once(self):
        # A warning inside a loop is reported once, not once per iteration.
        @torch.jit.script
        def fn():
            for _ in range(10):
                warnings.warn("I am warning you")

        self._check_warning_counts(fn, ("I am warning you", 1))

    def test_warn_only_once_in_loop_func(self):
        # Same as above, with the warn call in a helper invoked from the loop.
        def w():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w()

        self._check_warning_counts(fn, ("I am warning you", 1))

    def test_warn_once_per_func(self):
        # Two distinct warn sites with identical text are reported separately.
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            w1()
            w2()

        self._check_warning_counts(fn, ("I am warning you", 2))

    def test_warn_once_per_func_in_loop(self):
        # Two distinct warn sites inside a loop: one report per site in total.
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w1()
                w2()

        self._check_warning_counts(fn, ("I am warning you", 2))

    def test_warn_multiple_calls_multiple_warnings(self):
        # Separate invocations of the same scripted function each warn again.
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        def call_twice():
            fn()
            fn()

        self._check_warning_counts(call_twice, ("I am warning you", 2))

    def test_warn_multiple_calls_same_func_diff_stack(self):
        # One helper called from two scripted functions warns once for each
        # caller, in call order.
        def warn(caller: str):
            warnings.warn("I am warning you from " + caller)

        @torch.jit.script
        def foo():
            warn("foo")

        @torch.jit.script
        def bar():
            warn("bar")

        def call_foo_then_bar():
            foo()
            bar()

        self._check_warning_counts(
            call_foo_then_bar,
            ("I am warning you from foo", 1),
            ("I am warning you from bar", 1))
