# Owner(s): ["oncall: export"]


def random_dag(n: int, rng=None):
    """
    Util to generate a random DAG with n nodes.

    The nodes are numbered 0, 1, ..., n-1. The DAG is generated by randomly
    choosing a subset of edges from the complete graph on n nodes, such that
    for each edge (i, j) we have i < j. Every candidate edge is kept
    independently with probability 1/2. Because edges only point from a lower
    to a higher node number, the result is always acyclic.

    Args:
        n: number of nodes. ``n <= 0`` yields an empty dict.
        rng: optional source of randomness exposing ``choice`` (e.g. a seeded
            ``random.Random``) so a generated DAG can be reproduced. Defaults
            to the global ``random`` module, i.e. the previous behavior.

    Returns:
        A dict mapping every node i to the ascending list of its successors.
    """
    import random

    if rng is None:
        rng = random

    edges = {}
    for i in range(n):
        # Candidates are visited in ascending j, so the draw order (and hence
        # the result for a given seed) matches the original nested loop.
        edges[i] = [j for j in range(i + 1, n) if rng.choice([True, False])]

    return edges


class Block:
    """
    Util to generate a block of Python-formatted code.

    Lines are accumulated in order, each stored with a trailing newline;
    ``str(block)`` (or ``repr(block)``) returns the concatenated source text.
    """

    def __init__(self):
        # One entry per physical line of code, each terminated by "\n".
        self._code = []

    def __repr__(self):
        return "".join(self._code)

    def new_line(self, line: str):
        """
        Add a new line of code. The line is automatically suffixed
        with a newline character.

        A string containing embedded newlines is stored as several lines, so
        that nesting this block via new_block() indents every one of them.
        """
        self._code.extend(part + "\n" for part in line.split("\n"))

    def new_block(self, block: "Block", indent: str = "  "):
        """
        Add a new block of code, nested one level deeper than this block.

        Every line of ``block`` is prefixed with ``indent`` (two spaces by
        default), so repeatedly nesting blocks yields consistently indented
        Python source.
        """
        self._code.extend(indent + line for line in block._code)


class TestGenerator:
    """
    Abstract base class for generating test code.

    Users should subclass this class and implement the test_name() and
    test_body() methods. The test_name() method should return a string
    that uniquely identifies the test. The test_body() method should
    yield blocks of code.
    """

    def __init__(self):
        self._count = 0

    def _generate_test_name(self):
        self._count += 1
        return f"{self.test_name()}_{self._count}"

    def generate_test(self):
        test_name = self._generate_test_name()

        code = Block()
        code.new_line(f"def {test_name}():")
        for block in self.test_body():
            code.new_block(block)
        code.new_line(f"{test_name}()")
        return str(code)

    def test_name(self):
        raise NotImplementedError

    def test_body(self):
        raise NotImplementedError


class NNModuleGenerator:
    """
    Abstract base class for generating the source of a torch.nn.Module.

    Subclasses implement gen_init_body(i) and gen_forward_body(i). Each returns
    a block of code: the body of ``__init__(self)`` and the body of
    ``forward(self, x)``, respectively, for the module class named ``N{i}``.
    gen_nn_module(i) assembles both into a complete class definition.
    """

    def gen_init_body(self, i: int):
        raise NotImplementedError

    def gen_forward_body(self, i: int):
        raise NotImplementedError

    def gen_nn_module(self, i: int):
        # Method definitions, built first and then nested under the class line.
        methods = Block()
        methods.new_line("def __init__(self):")
        methods.new_block(self.gen_init_body(i))
        methods.new_line("def forward(self, x):")
        methods.new_block(self.gen_forward_body(i))

        module_src = Block()
        module_src.new_line(f"class N{i}(torch.nn.Module):")
        module_src.new_block(methods)
        return module_src


class Unflatten(TestGenerator):
    """
    Generates test that unflattens a model with several nn.Modules that call
    each other. The modules are generated by calling the nn_module_generator()
    method.

    The model is exported (non-strict, preserving the call signature of every
    submodule) and then unflattened. Both the exported module and the
    unflattened model are then compared against the eager model.
    """

    def __init__(self, n: int):
        super().__init__()
        # Number of generated module classes, N0 ... N{n-1}.
        self.n = n

    @staticmethod
    def _path(i: int, j: int) -> str:
        """
        Return the attribute path from module N{i} down to its descendant N{j}.

        Generated module N{k} stores N{k+1} as attribute ``n{k+1}``, so the
        path is ``n{i+1}.n{i+2}...n{j}``; e.g. ``_path(0, 3) == "n1.n2.n3"``.

        Raises:
            ValueError: if ``j <= i`` (N{j} is not a descendant of N{i}).
        """
        if j <= i:
            raise ValueError(f"N{j} is not a descendant of N{i}")
        return ".".join(f"n{k}" for k in range(i + 1, j + 1))

    def nn_module_generator(self):
        path = self._path

        class GenNNModule(NNModuleGenerator):
            def __init__(self, n: int):
                super().__init__()
                self.n = n
                # calls[i]: descendants j > i that N{i}.forward calls.
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                code = Block()
                code.new_line("super().__init__()")
                # Chain N{i} -> N{i+1}; the last module has no child.
                if i < self.n - 1:
                    code.new_line(f"self.n{i+1} = N{i+1}()")
                return code

            def gen_forward_body(self, i: int):
                code = Block()
                for j in self.calls[i]:
                    code.new_line(f"x = self.{path(i, j)}(x + 1)")
                code.new_line("return x + 1")
                return code

        return GenNNModule(self.n)

    def test_name(self):
        return f"{self.__class__.__name__}_{self.n}"

    def test_body(self):
        nn_module_generator = self.nn_module_generator()
        # Emit the module classes from N{n-1} down to N0.
        for i in reversed(range(self.n)):
            yield nn_module_generator.gen_nn_module(i)

        # FQNs of every submodule below N0, e.g. ("n1", "n1.n2"). Rendering a
        # real tuple with repr() yields valid source for any n: "()" when
        # n == 1 and "('n1',)" (trailing comma included) when n == 2.
        # Hand-quoting is error-prone: a comma inside the quotes turns the
        # argument into one concatenated string instead of a tuple of FQNs.
        fqns = tuple(self._path(0, j) for j in range(1, self.n))

        def gen_main():
            code = Block()
            code.new_line("inp = (torch.ones(1),)")
            code.new_line("eager = N0()(*inp)")
            code.new_line(
                f"ep = torch.export.export(N0(), inp, strict=False, preserve_module_call_signature={fqns!r})"
            )
            code.new_line("epm = ep.module()")
            code.new_line("ufm = torch.export.unflatten(ep)")
            code.new_line("assert torch.allclose(epm(*inp), eager)")
            code.new_line("assert torch.allclose(ufm(*inp), eager)")
            return code

        yield gen_main()


class ConstantUnflatten(Unflatten):
    """
    Variant of Unflatten whose generated modules also read tensor constants.

    Every generated module N{i} stores ``self.const = torch.ones(1)``. Its
    forward() first adds the ``const`` of randomly chosen descendants to ``x``,
    then makes randomly chosen calls into descendants, and returns ``x + 1``.
    """

    def nn_module_generator(self):
        class GenNNModule(NNModuleGenerator):
            def __init__(self, n):
                super().__init__()
                self.n = n
                # Two separate random DAGs, drawn in this order: which
                # descendants' constants N{i} reads, and which it calls.
                self.accesses = random_dag(self.n)
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                lines = ["super().__init__()", "self.const = torch.ones(1)"]
                if i + 1 < self.n:
                    lines.append(f"self.n{i+1} = N{i+1}()")
                code = Block()
                for line in lines:
                    code.new_line(line)
                return code

            def gen_forward_body(self, i: int):
                def attr_path(src, dst):
                    # Dotted attribute chain from N{src} down to N{dst}.
                    return ".".join(f"n{k}" for k in range(src + 1, dst + 1))

                statements = [
                    f"x = x + self.{attr_path(i, j)}.const" for j in self.accesses[i]
                ]
                statements += [
                    f"x = self.{attr_path(i, j)}(x + 1)" for j in self.calls[i]
                ]
                statements.append("return x + 1")

                code = Block()
                for stmt in statements:
                    code.new_line(stmt)
                return code

        return GenNNModule(self.n)


class BufferUnflatten(Unflatten):
    """
    Variant of Unflatten whose generated modules also read and mutate buffers.

    Every generated module N{i} registers
    ``self.buf = torch.nn.Buffer(torch.ones(1))``. Its forward() does four
    things, in order:

    1. Adds the ``buf`` of randomly chosen descendants to ``x``.
    2. Makes randomly chosen calls into descendants.
    3. Increments the ``buf`` of randomly chosen descendants in place.
    4. Returns ``x + 1``.
    """

    def nn_module_generator(self):
        class GenNNModule(NNModuleGenerator):
            def __init__(self, n):
                super().__init__()
                self.n = n
                # Three separate random DAGs, drawn in this exact order.
                self.accesses = random_dag(self.n)
                self.mutations = random_dag(self.n)
                self.calls = random_dag(self.n)

            def gen_init_body(self, i: int):
                lines = [
                    "super().__init__()",
                    "self.buf = torch.nn.Buffer(torch.ones(1))",
                ]
                if i + 1 < self.n:
                    lines.append(f"self.n{i+1} = N{i+1}()")
                code = Block()
                for line in lines:
                    code.new_line(line)
                return code

            def gen_forward_body(self, i: int):
                def attr_path(src, dst):
                    # Dotted attribute chain from N{src} down to N{dst}.
                    return ".".join(f"n{k}" for k in range(src + 1, dst + 1))

                statements = [
                    f"x = x + self.{attr_path(i, j)}.buf" for j in self.accesses[i]
                ]
                statements += [
                    f"x = self.{attr_path(i, j)}(x + 1)" for j in self.calls[i]
                ]
                statements += [
                    f"self.{attr_path(i, j)}.buf.add_(1)" for j in self.mutations[i]
                ]
                statements.append("return x + 1")

                code = Block()
                for stmt in statements:
                    code.new_line(stmt)
                return code

        return GenNNModule(self.n)
