# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import itertools
import torch
from functorch import vmap
import torch.utils._pytree as pytree
from functorch_additional_op_db import additional_op_db
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db
import os
import unittest
from torch.testing._internal.common_device_type import toleranceOverride
from collections import namedtuple

# True only when the environment variable FUNCTORCH_TEST_FBCODE is exactly '1'
# (the flag is presumably set by fbcode test runs).
IS_FBCODE = os.environ.get('FUNCTORCH_TEST_FBCODE') == '1'


def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
    """Reference implementation of a single vmap level using a Python for-loop.

    For each index in ``range(batch_size)`` every argument leaf with a non-None
    in_dim is sliced with ``select(in_dim, idx)``, ``op`` is called on the
    slices, and the per-sample results are stacked along ``out_dim``.

    Args:
        op: callable applied to each sample.
        in_dims: pytree with the same structure as ``batched_args``; each leaf is
            the dim to slice along, or ``None`` to pass the leaf unchanged.
        out_dim: dim along which per-sample outputs are stacked.
        batch_size: number of samples (loop iterations).
        *batched_args: positional arguments for ``op``.
        **kwarg_values: keyword arguments forwarded unchanged to every call.

    Returns:
        A tensor when ``op`` returns a tensor; otherwise a list holding one
        stacked tensor per element of ``op``'s sequence output.
    """
    # The argument structure is identical for every sample, so flatten once.
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims, dims_spec = pytree.tree_flatten(in_dims)
    assert args_spec == dims_spec

    outs = []
    for idx in range(batch_size):
        new_args = [a.select(in_dim, idx) if in_dim is not None else a
                    for a, in_dim in zip(flat_args, flat_dims)]
        outs.append(op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values))

    if isinstance(outs[0], torch.Tensor):
        # Honor out_dim here too (it used to be ignored for single-tensor outputs).
        return torch.stack(outs, out_dim)
    return [torch.stack([out[idx] for out in outs], out_dim)
            for idx in range(len(outs[0]))]


# Like the `loop` helper, but for two nested levels of vmap. If more levels are ever needed, it is probably
# possible to generalize `loop` itself, but that seemed too complicated for now.
def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2, *batched_args, **kwarg_values):
    """Reference implementation of two nested vmap levels using Python for-loops.

    The outer loop (``batch_size1`` iterations) slices ``batched_args`` along
    ``in_dims1`` and stacks its results along ``out_dim2``. The inner loop
    (``batch_size2`` iterations) slices the already-sliced args along
    ``in_dims2`` and stacks its results along ``out_dim1``. This is presumably
    meant to match ``vmap(vmap(op, in_dims2, out_dims=out_dim1), in_dims1,
    out_dims=out_dim2)``.

    Returns:
        A tensor when ``op`` returns a tensor; otherwise a list holding one
        stacked tensor per element of ``op``'s sequence output.
    """
    def stack_outputs(outs, dim):
        # Stack per-sample outputs: one tensor for tensor outputs, otherwise a
        # list with one stacked tensor per output element.
        if isinstance(outs[0], torch.Tensor):
            return torch.stack(outs, dim)
        return [torch.stack([out[idx] for out in outs], dim) for idx in range(len(outs[0]))]

    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
    flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
    assert args_spec == dims_spec1
    assert args_spec == dims_spec2
    assert len(flat_dims1) == len(flat_dims2)

    loop_out = []
    for idx1 in range(batch_size1):
        arg_split = [a.select(in_dim1, idx1) if in_dim1 is not None else a
                     for a, in_dim1 in zip(flat_args, flat_dims1)]
        out_split = []
        for idx2 in range(batch_size2):
            new_args = [a.select(in_dim2, idx2) if in_dim2 is not None else a
                        for a, in_dim2 in zip(arg_split, flat_dims2)]
            out_split.append(op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values))
        loop_out.append(stack_outputs(out_split, out_dim1))

    # Previously this level checked isinstance(loop_out, Tensor) (always False for
    # a list), so tensor outputs came back as a list of slices instead of a tensor.
    return stack_outputs(loop_out, out_dim2)


def is_valid_inplace_sample_input(sample_input, op, inplace_variant):
    """Return whether ``sample_input`` is usable for testing ``inplace_variant``.

    A sample is rejected when no in-place variant exists, when
    ``sample_input.broadcasts_input`` is set, or when running the out-of-place
    ``op`` on it yields a dtype different from ``sample_input.input.dtype``.
    """
    if inplace_variant is None or sample_input.broadcasts_input:
        return False

    # Run the out-of-place op once and compare its result dtype with the input's.
    positional = (sample_input.input,) + sample_input.args
    result_dtype = op(*positional, **sample_input.kwargs).dtype
    return result_dtype == sample_input.input.dtype


# This is kind of dangerous; please think carefully before using it.
# Known risks:
# - The returned value is shared between calls and must not be mutated, so it is best to return
#   immutable types (e.g. prefer tuples to lists).
# - Don't memoize on tensors in a global context; the cache will keep them alive forever.
def memoize(fn):
    """Cache ``fn``'s results keyed on its positional-argument tuple.

    Only positional arguments are supported, and they must be hashable. The
    cache is never evicted and lives as long as the returned wrapper. Cached
    values are returned by reference, not copied.
    """
    memo = {}

    # functools.wraps keeps fn's __name__/__doc__, so tracebacks and test
    # output show the real function instead of "wrapped".
    @functools.wraps(fn)
    def wrapped(*args):
        if args not in memo:
            memo[args] = fn(*args)
        return memo[args]
    return wrapped


# NB: This is O(2 ** num_tensors).
# num_tensors ranges from 1 to 10, with 2-4 being most common.
# Try not to make it more expensive if you're modifying it.
@memoize
def get_bdim_choices(num_tensors):
    """Return the in_dim combinations to try for ``num_tensors`` tensor args.

    The first combination batches every tensor at dim 0. It is followed by
    every per-tensor choice of ``-1`` / ``None`` in ``itertools.product``
    order, except the all-``None`` combination (nothing batched). The result
    is a tuple of tuples, so the memoized value is immutable.
    """
    all_zeros = (0,) * num_tensors
    neg_one_or_none = list(itertools.product((-1, None), repeat=num_tensors))
    # product() emits the all-None combination last; it is dropped below.
    assert neg_one_or_none[-1] == (None,) * num_tensors
    return (all_zeros,) + tuple(neg_one_or_none[:-1])

# NB: This is O(2 ** num_tensors).
# num_tensors ranges from 1 to 10, with 2-4 being most common.
# Try not to make it more expensive if you're modifying it.
def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=None, *args):
    """Return in_dim combinations for batch/instance norm in training mode.

    Called with the op's positional args after ``num_tensors``: the first one
    (the input) is ignored, and the next two are ``running_mean`` and
    ``running_var``. The remaining args are ignored.

    - If either running stat is ``None``, the input is never batched: the
      first tensor's bdim is always ``None``.
    - Otherwise the input (position 0) may be batched only when both
      positions 1 and 2 are batched too. This assumes those positions hold
      running_mean / running_var, i.e. they are the 2nd and 3rd tensors
      (TODO confirm for all callers).

    The all-``None`` combination is always dropped.
    """
    options = (-1, None)
    choices = []

    if running_mean is None or running_var is None:
        # instance norm turns these into unbatched 0 tensors, so we cannot batch
        # the input if either is not specified
        choices.append((None,) + (0,) * (num_tensors - 1))
        choices.extend((None,) + choice
                       for choice in itertools.product(options, repeat=num_tensors - 1))
    else:
        # Batch norm doesn't work if the input is batched but running_mean/var are
        # unbatched, so this tests all other cases.
        choices.append((0,) * num_tensors)
        for choice in itertools.product(options, repeat=num_tensors):
            input_bdim, running_mean_bdim, running_var_bdim = choice[0], choice[1], choice[2]
            # Compare with None explicitly: 0 is a valid (but falsy) batch dim.
            if input_bdim is not None and (running_mean_bdim is None or running_var_bdim is None):
                continue
            choices.append(choice)

    assert choices[-1] == (None,) * num_tensors
    return tuple(choices[:-1])


def add_batch_dim(arg, bdim, batch_size=3):
    """Give tensor ``arg`` a materialized batch dim of size ``batch_size``.

    Each slice along the new dim equals ``arg``. Only ``bdim`` 0 (leading) and
    -1 (trailing) are supported.

    Returns:
        ``(batched_tensor, bdim)``.
    """
    assert bdim in (0, -1)
    assert isinstance(arg, torch.Tensor)
    if bdim == -1:
        # Append a size-1 trailing dim, broadcast it to batch_size, then make it
        # contiguous so the result is a regular dense tensor.
        expanded = arg.unsqueeze(-1).expand(*arg.shape, batch_size)
        return (expanded.contiguous(), bdim)
    if bdim == 0:
        repeats = [batch_size] + [1] * len(arg.shape)
        return (arg.repeat(repeats), bdim)


def construct_in_dims(bdim_choice_for_tensors, is_tensors):
    """Expand a per-tensor bdim choice into a per-argument in_dims tuple.

    ``bdim_choice_for_tensors`` has one entry per tensor argument. Entries are
    consumed in order, once for each True in ``is_tensors``; non-tensor
    positions get ``None``.
    """
    tensor_bdims = iter(bdim_choice_for_tensors)
    # A list comprehension (not a generator expression) so that a StopIteration
    # from next() propagates unchanged, as it did with the explicit loop.
    return tuple([next(tensor_bdims) if is_tensor else None for is_tensor in is_tensors])


def is_batch_norm_training(op_name, kwarg_values):
    """Return whether a batch_norm/instance_norm call runs in training mode.

    Returns False for any other op. The mode is taken from the single plain
    ``bool`` among ``kwarg_values``. If there is none, the default applies:
    instance_norm defaults to training, batch_norm does not.
    """
    # instance norm calls batch norm, so both are treated the same way
    if op_name not in ("nn.functional.batch_norm", "nn.functional.instance_norm"):
        return False

    # batch norm and instance norm require the value to be a plain bool
    bool_kwargs = [value for value in kwarg_values.values() if isinstance(value, bool)]
    if not bool_kwargs:
        return op_name == "nn.functional.instance_norm"
    assert len(bool_kwargs) == 1
    return bool_kwargs[0]


def generate_vmap_inputs(arg_values, kwarg_values, is_batch_norm_and_training=False, batch_size=2):
    """Yield ``(batched_args, in_dims, kwarg_values)`` for every bdim choice.

    ``arg_values`` is pytree-flattened. Each tensor leaf is batched (with
    ``batch_size`` copies) according to a choice from ``get_bdim_choices``, or
    from ``get_bdim_choices_batch_norm`` when ``is_batch_norm_and_training``.
    Non-tensor leaves get an in_dim of ``None``. ``kwarg_values`` is yielded
    unchanged.

    Each (tensor, bdim) pair is batched at most once per call, and the result
    is reused across choices.
    """
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(flat_arg, torch.Tensor) for flat_arg in flat_args]
    num_tensors = sum(is_tensors)

    if is_batch_norm_and_training:
        # For Batch Norm, if there's only an input, we can't batch it since
        # running_mean/var will be seen as unbatched tensors
        if num_tensors == 1:
            return
        bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values)
    else:
        bdim_choices = get_bdim_choices(num_tensors)

    @memoize
    def batch_tensor(tensor, bdim):
        assert isinstance(tensor, torch.Tensor)
        assert bdim is not None
        return add_batch_dim(tensor, bdim, batch_size)[0]

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)
        flat_batched_args = tuple(
            flat_arg if in_dim is None else batch_tensor(flat_arg, in_dim)
            for flat_arg, in_dim in zip(flat_args, flat_in_dims))
        yield (pytree.tree_unflatten(flat_batched_args, arg_spec),
               pytree.tree_unflatten(flat_in_dims, arg_spec),
               kwarg_values)


def clone_if_tensor(x):
    """Return ``x.clone()`` for tensors; return any other value as-is."""
    return x.clone() if isinstance(x, torch.Tensor) else x


def compute_quantities_for_vmap_test(
        op, orig_batched_args, orig_kwarg_values, in_dims,
        out_dim=0, batch_size=2, compute_loop_out=True,
        clone_inputs=False):
    """Yield ``(expected, actual)`` pairs for checking ``vmap`` on ``op``.

    Two pairs are yielded:

    1. ``(loop_out, batched_out)``: the for-loop reference from ``loop`` (or
       ``None`` when ``compute_loop_out`` is False) vs.
       ``vmap(op, in_dims, out_dims=out_dim)``.
    2. ``(expected, output)``: ``batched_out`` rearranged to
       ``(batch, 1, ...)`` vs. a nested vmap whose inner level only batches a
       dummy size-1 tensor. This exercises dispatching to a batching rule with
       no batched inputs at the inner level.

    When ``clone_inputs`` is True, each computation gets fresh clones of the
    tensor inputs.
    """

    def maybe_clone_inputs():
        # With clone_inputs, give every computation its own copies of the tensor
        # inputs so an in-place op cannot affect later computations.
        if clone_inputs:
            batched_args = pytree.tree_map(clone_if_tensor, orig_batched_args)
            kwarg_values = pytree.tree_map(clone_if_tensor, orig_kwarg_values)
            return batched_args, kwarg_values
        return orig_batched_args, orig_kwarg_values

    batched_args, kwarg_values = maybe_clone_inputs()
    if compute_loop_out:
        loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
    else:
        loop_out = None
    # Used for debugging the resulting operations
    # from functorch import make_fx
    # def f(a):
    #     return op(a)
    # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
    # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
    batched_args, kwarg_values = maybe_clone_inputs()
    batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
    yield (loop_out, batched_out)

    # Tests case where we dispatch to a batching rule with no bdims
    # This should be handled by autogenerated plumbing. For vmap support
    # added via a manual plumbing you may need to handle this specially.
    def add_bdim_if_tensor(x):
        # The nested vmap below uses the default out_dims=0 at both levels, so
        # its result is laid out as (batch, 1, ...). batched_out has its batch
        # dim at out_dim; move it to the front first (a no-op for out_dim=0).
        if isinstance(x, torch.Tensor):
            return x.movedim(out_dim, 0).unsqueeze(1)
        return x

    def f(dummy, *args, **kwargs):
        return op(*args, **kwargs)

    dummy = torch.ones(batch_size, 1)
    expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

    inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
    outer_in_dims = (0,) + in_dims
    batched_args, kwarg_values = maybe_clone_inputs()
    output = vmap(vmap(f, inner_in_dims), outer_in_dims)(dummy, *batched_args, **kwarg_values)
    yield (expected, output)


def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True):
    """Yield ``(expected, actual)`` vmap test pairs for every bdim choice.

    ``generate_vmap_inputs`` enumerates batched variants of ``arg_values``.
    For each variant, this yields every pair produced by
    ``compute_quantities_for_vmap_test`` (out_dim 0, batch size 2).
    """
    out_dim = 0
    batch_size = 2

    # Forward batch_size explicitly so the inputs are built with the same batch
    # size that the for-loop reference iterates over.
    generator = generate_vmap_inputs(arg_values, kwarg_values, is_batch_norm_and_training, batch_size)
    for batched_args, in_dims, batched_kwarg_values in generator:
        yield from compute_quantities_for_vmap_test(
            op, batched_args, batched_kwarg_values, in_dims, out_dim, batch_size, compute_loop_out)


def opinfo_in_dict(opinfo, d):
    """Return whether ``d`` contains ``opinfo.name`` or ``"<name>.<variant_test_name>"``."""
    if opinfo.name in d:
        return True
    return f'{opinfo.name}.{opinfo.variant_test_name}' in d


# One decorator to attach to an OpInfo's tests: built by decorate()/xfail()/skip()
# and consumed by skipOps().
DecorateMeta = namedtuple(
    "DecorateMeta",
    ("op_name", "variant_name", "decorator", "device_type", "dtypes"),
)


def decorate(op_name, variant_name='', *, decorator=None, device_type=None, dtypes=None):
    """Build a ``DecorateMeta`` applying ``decorator`` to one OpInfo's tests.

    ``decorator`` is required. ``device_type`` and ``dtypes`` optionally
    restrict where the decorator applies.
    """
    assert decorator is not None
    return DecorateMeta(op_name, variant_name, decorator, device_type, dtypes)


def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Return a ``DecorateMeta`` marking the OpInfo's tests as expected failures."""
    return decorate(op_name, variant_name,
                    decorator=unittest.expectedFailure,
                    device_type=device_type, dtypes=dtypes)


def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Return a ``DecorateMeta`` that unconditionally skips the OpInfo's tests."""
    return decorate(op_name, variant_name,
                    decorator=unittest.skip("Skipped!"),
                    device_type=device_type, dtypes=dtypes)


def skipOps(test_case_name, base_test_name, to_skip):
    """Attach the decorators in ``to_skip`` to the matching OpInfos.

    For each ``DecorateMeta`` in ``to_skip``, exactly one OpInfo in
    ``op_db + additional_op_db`` must match its (name, variant_name). That
    OpInfo gets a ``DecorateInfo`` appended to ``decorators``, scoped to
    ``test_case_name`` / ``base_test_name`` and the meta's device_type/dtypes.
    The OpInfo objects are modified in place when this is called.

    Returns:
        An identity decorator, so this can be stacked on a test function.
    """
    all_opinfos = op_db + additional_op_db
    for meta in to_skip:
        matches = [candidate for candidate in all_opinfos
                   if candidate.name == meta.op_name and
                   candidate.variant_test_name == meta.variant_name]
        assert len(matches) > 0, f"Couldn't find OpInfo for {meta}"
        assert len(matches) == 1, (
            "OpInfos should be uniquely determined by their (name, variant_name). "
            f"Got more than one result for ({meta.op_name}, {meta.variant_name})"
        )
        opinfo = matches[0]
        extra = DecorateInfo(meta.decorator,
                             test_case_name, base_test_name,
                             device_type=meta.device_type,
                             dtypes=meta.dtypes)
        opinfo.decorators = tuple(opinfo.decorators) + (extra,)

    # The returned decorator leaves fn untouched; all work happened above.
    def identity(fn):
        return fn
    return identity


def expectedFailureIf(condition):
    """Return a decorator that applies ``unittest.expectedFailure`` only if ``condition`` is truthy."""
    def maybe_mark(fn):
        return unittest.expectedFailure(fn) if condition else fn
    return maybe_mark


def tol2(op_name, variant_name, override_dct, *, device_type=None):
    """Bundle a tolerance override as an ``(op_name, variant_name, override_dct, device_type)`` tuple."""
    entry = op_name, variant_name, override_dct, device_type
    return entry


def tol1(op_name, override_dct, *, device_type=None):
    """Like ``tol2`` but for an OpInfo with no variant name (``''``)."""
    return tol2(op_name, variant_name='', override_dct=override_dct, device_type=device_type)


def opsToleranceOverride(test_case_name, base_test_name, overrides):
    """Attach ``toleranceOverride`` decorators to the matching OpInfos.

    Each entry of ``overrides`` is an ``(op_name, variant_name, override_dct,
    device_type)`` tuple, as built by ``tol1``/``tol2``. Exactly one OpInfo in
    ``op_db + additional_op_db`` must match each (op_name, variant_name). That
    OpInfo gets a ``DecorateInfo(toleranceOverride(override_dct), ...)``
    appended to ``decorators``, scoped to ``test_case_name`` /
    ``base_test_name`` / ``device_type``. The OpInfo objects are modified in
    place.

    Returns:
        An identity decorator, so this can be stacked on a test function.
    """
    all_opinfos = op_db + additional_op_db
    for entry in overrides:
        # Unpack into distinct names; the loop variable used to be rebound here,
        # so error messages printed only the tolerance dict.
        op_name, variant_name, override_dct, device_type = entry
        matching_opinfos = [o for o in all_opinfos
                            if o.name == op_name and o.variant_test_name == variant_name]
        assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {entry}"
        assert len(matching_opinfos) == 1, (
            "OpInfos should be uniquely determined by their (name, variant_name). "
            f"Got more than one result for ({op_name}, {variant_name})"
        )
        opinfo = matching_opinfos[0]
        decorators = list(opinfo.decorators)
        decorators.append(DecorateInfo(
            toleranceOverride(override_dct),
            test_case_name, base_test_name, device_type=device_type))
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped


class DisableVmapFallback:
    """Context manager that disables the vmap fallback inside its body.

    The previous setting is saved on entry and restored on exit. ``__exit__``
    returns None, so exceptions raised in the body propagate after the
    setting is restored.
    """

    def __enter__(self):
        functorch_c = torch._C._functorch
        # Save the current setting so it is restored rather than forced on.
        self.prev_state = functorch_c._is_vmap_fallback_enabled()
        functorch_c._set_vmap_fallback_enabled(False)

    def __exit__(self, *exc_info):
        torch._C._functorch._set_vmap_fallback_enabled(self.prev_state)


def check_vmap_fallback(test_case, thunk, opinfo, dry_run=False):
    """Run ``thunk`` with the vmap fallback disabled.

    If ``thunk`` raises, the exception is re-raised unless ``dry_run`` is set.
    In a dry run it is swallowed and an ``xfail(...)`` line for ``opinfo`` is
    printed instead. ``test_case`` is not used.
    """
    try:
        with DisableVmapFallback():
            thunk()
    except Exception:
        if not dry_run:
            raise
        line = (f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),"
                if opinfo.variant_test_name
                else f"xfail('{opinfo.name}'),")
        print(line)
