# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import unittest

from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_ARM64, parametrize
import torch
from torch import Tensor
import functools
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_device_type import \
    toleranceOverride, tol
from functorch_additional_op_db import additional_op_db
from torch.testing._internal.common_methods_invocations import op_db
from common_utils import (
    get_fallback_and_vmap_exhaustive,
    generate_vmap_inputs,
    decorate,
    xfail,
    skip,
    skipOps,
    tol1,
    tol2,
    opsToleranceOverride,
    check_vmap_fallback,
    is_batch_norm_training,
    is_valid_inplace_sample_input,
    loop2,
)

from torch.testing._internal.opinfo.core import SampleInput
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
from functorch import grad, vjp, vmap, jacrev, jacfwd
import torch.autograd.forward_ad as fwAD
from functorch._src.eager_transforms import _as_tuple, jvp
aten = torch.ops.aten


# Version of autograd.grad with some differences:
#   - pytree inputs is allowed (but leaves of the pytree have to all
#     be tensors)
#   - if an input is not used as part of derivatives, we will return a
#     zero-filled tensor for the result
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)


def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()
    if is_iterable_of_tensors(arg):
        if all([is_differentiable_arg(a) for a in arg]):
            return True
        if all([not is_differentiable_arg(a) for a in arg]):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
def normalize_op_input_output2(f, args, kwargs, output_process_fn_grad=None, requires_grad=True):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(i for i, arg in enumerate(flat_args) if diff_arg(arg, requires_grad=requires_grad))
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            result = tuple(r for r in result if torch.is_floating_point(r))
            assert len(result) > 0
        return result
    return wrapped, primals


# TODO: consolidate with normalize_op_input_output2
def normalize_op_input_output3(f, args, kwargs, sample_args, output_process_fn_grad=None):
    flat_args, args_spec = tree_flatten(args)
    flat_sample_args, _ = tree_flatten(sample_args)
    diff_argnums = tuple(i for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args))
                         if diff_arg(sample, requires_grad=True))
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            result = tuple(r for r in result if torch.is_floating_point(r))
            assert len(result) > 0
        return result
    return wrapped, primals


def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f, args, sample.kwargs, sample.output_process_fn_grad, requires_grad=requires_grad
    )


def ref_vjp(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(_as_tuple(result), primals, _as_tuple(cotangents))

    return result, wrapped


def simulate_jvp(f, primals, tangents):
    primals_out, tangents_out = torch.autograd.functional.jvp(f, primals, tangents)
    return primals_out, tangents_out


def ref_jvp(f, primals, tangents):
    with fwAD.dual_level():
        duals = tuple(fwAD.make_dual(p, t) for p, t in zip(primals, tangents))
        result_duals = f(*duals)
        result_duals, spec = tree_flatten(result_duals)
        primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals))
        return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec)


def get_sample_cotangents(f, sample):
    fn, primals = normalize_op_input_output(f, sample)
    output = fn(*primals)
    return tree_map(torch.randn_like, output)


# returns a new function g(*args, *cotangents)
# that computes vjps and (*args, cotangents)
def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents):
    args = tuple([sample.input] + list(sample.args))
    kwargs = sample.kwargs
    flat_args, args_spec = tree_flatten(args)
    flat_cotangents, cotangents_spec = tree_flatten(cotangents)

    @functools.wraps(f)
    def wrapped(*args):
        assert len(args) == len(flat_args) + len(flat_cotangents)
        actual_args = args[:len(flat_args)]
        cotangents = args[len(flat_args):]
        actual_args = tree_unflatten(actual_args, args_spec)
        cotangents = tree_unflatten(cotangents, cotangents_spec)

        fn, primals = normalize_op_input_output3(f, actual_args, kwargs,
                                                 flat_args,
                                                 sample.output_process_fn_grad)
        _, vjp_fn = vjp(fn, *primals)
        return vjp_fn(cotangents)

    return wrapped, tuple(flat_args + flat_cotangents)


# Returns a new function g(*args, *cotangents) that computes vjps and
# sample (*args, *cotangents)
def get_vjpfull_variant(f, sample):
    fn, primals = normalize_op_input_output(f, sample)
    result = fn(*primals)
    cotangents = _as_tuple(
        tree_map(lambda x: torch.randn_like(x, requires_grad=True), result))
    num_primals = len(primals)
    args = (*primals, *cotangents)

    @functools.wraps(f)
    def wrapped(*args):
        primals = args[:num_primals]
        cotangents = args[num_primals:]
        result, vjp_fn = vjp(fn, *primals)
        if isinstance(result, torch.Tensor):
            assert len(cotangents) == 1
            cotangents = cotangents[0]
        return vjp_fn(cotangents)

    return wrapped, args


def get_jvp_variant(f, sample):
    # We want this higher-order variant of jvp, so that it can
    # be used to wrap vmap
    fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
    tangents = _as_tuple(
        tree_map(lambda x: torch.randn_like(x), primals))

    @functools.wraps(f)
    def wrapped(*args):
        tangents = args
        primals_out, tangents_out = jvp(fn, primals, tangents)

        if isinstance(primals_out, torch.Tensor):
            return (primals_out, tangents_out)
        else:
            flat_primals_out, _ = tree_flatten(primals_out)
            flat_tangents_out, _ = tree_flatten(tangents_out)
            return tuple(flat_primals_out + flat_tangents_out)

    return wrapped, tangents


def get_jvp_variant_primals_tangents(f, sample):
    # We want this higher-order variant of jvp, so that it can
    # be used to wrap vmap
    fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
    tangents = _as_tuple(
        tree_map(lambda x: torch.randn_like(x), primals))

    @functools.wraps(f)
    def wrapped(*args):
        primals_in = args[:len(primals)]
        tangents_in = args[len(primals):]
        primals_out, tangents_out = jvp(fn, primals_in, tangents_in)

        if isinstance(primals_out, torch.Tensor):
            return (primals_out, tangents_out)
        else:
            flat_primals_out, _ = tree_flatten(primals_out)
            flat_tangents_out, _ = tree_flatten(tangents_out)
            return tuple(flat_primals_out + flat_tangents_out)

    return wrapped, primals + tangents


def is_inplace(op, variant):
    if hasattr(variant, "__wrapped__"):
        return variant.__wrapped__ is op.get_inplace()
    return variant is op.get_inplace()


vjp_fail = {
    xfail('tensor_split'),  # data_ptr composite compliance
}

aliasing_ops = {
    'T',
    'broadcast_to',
    'conj',
    'contiguous',
    'diagonal',  # linalg.diagonal is an alias
    'expand',
    'flatten',
    'imag',
    'mH',  # adjoint is an alias
    'mT',
    'movedim',  # moveaxis is an alias
    'narrow',
    'permute',
    'positive',
    # 'ravel', is composite implict autograd and may call clone
    'real',
    'reshape',
    'resolve_conj',
    'resolve_neg',
    'select',
    'squeeze',
    'transpose',  # swapdims and swapaxes are aliases
    'unflatten',
    'unfold',
    'unsqueeze',
    'view',
    'view_as',
    'view_as_complex',
    'view_as_real',
}

aliasing_ops_list_return = {
    'chunks',
    'dsplit',
    'hsplit',
    'split',
    'unbind',
    'vsplit',
    # 'tensor_split' not composite compliant, see vjp_fail
}


class TestOperators(TestCase):
    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_grad', vjp_fail.union({
        xfail('linalg.eig'),  # diagonal_scatter does not support complex
        xfail('chalf', '', device_type='cpu'),  # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
        skip('as_strided_scatter', ''),  # silent incorrectness; seems flaky
        xfail('sparse.sampled_addmm', ''),  # RuntimeError: Sparse CSR tensors do not have strides
        xfail('to_sparse', ''),  # Could not run 'aten::sum.dim_IntList'
    }))
    @opsToleranceOverride('TestOperators', 'test_grad', (
        tol1('nn.functional.binary_cross_entropy_with_logits',
             {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
    ))
    def test_grad(self, device, dtype, op):
        if op.name in vjp_fail:
            self.skipTest("Skipped; Expected failures")
            return

        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped for redundancy. test_vjp handles in-place testing.")
            return

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs

            diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg))
            assert len(diff_argnums) > 0
            diff_args = tuple(args[i] for i in diff_argnums)

            def wrapped_fn(*args, **kwargs):
                result = op(*args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    result = sample.output_process_fn_grad(result)

                # Reduce into single value for grad
                if isinstance(result, torch.Tensor):
                    return result.sum()
                result = sum([res.sum() for res in result])
                return result

            result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
            expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args)

            self.assertEqual(result, expected)

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_jvp', set({
        # Composite ops that do bad things. Need to be fixed in PyTorch core.
        # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
        xfail('tensor_split'),

        # BUG: silent incorrectness: runs and produces numerical differences
        skip('nn.functional.max_unpool1d'),  # fails everywhere except on mac
        skip('nn.functional.max_unpool2d'),  # fails everywhere except on windows
        skip('nn.functional.max_unpool3d'),  # fails everywhere except on mac
        xfail("native_batch_norm"),

        xfail('nn.functional.rrelu')  # in-place test errors out with no formula implemented
    }))
    @opsToleranceOverride('TestOperators', 'test_jvp', (
        tol1('nn.functional.conv_transpose3d',
             {torch.float32: tol(atol=1e-04, rtol=1.3e-06)}, device_type='cuda'),
        tol1('nn.functional.binary_cross_entropy_with_logits',
             {torch.float32: tol(atol=4e-04, rtol=4e-04)}),
    ))
    def test_jvp(self, device, dtype, op):
        # TODO: get rid of vjp_decomp when we add decomposition support to
        # PyTorch's forward-mode ad. Currently the decomposition support only
        # works for functorch.jvp
        VJP_DECOMP = {
            'nn.functional.logsigmoid',
        }
        if op.name in VJP_DECOMP:
            fixme_ref_jvp_local = simulate_jvp
        else:
            fixme_ref_jvp_local = ref_jvp

        if not op.supports_forward_ad and op.name not in VJP_DECOMP:
            self.skipTest("Skipped! Forward AD not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        outplace_variant = op if not is_inplace(op, op.get_op()) else None
        inplace_variant = op.inplace_variant if op.supports_inplace_autograd else None

        for sample in samples:
            args = (sample.input,) + sample.args
            kwargs = sample.kwargs
            if outplace_variant:
                self.jvp_opinfo_test(outplace_variant, args, kwargs,
                                     sample.output_process_fn_grad,
                                     clone_inputs=False,
                                     fixme_ref_jvp_local=fixme_ref_jvp_local)
            if is_valid_inplace_sample_input(sample, op, inplace_variant):
                self.jvp_opinfo_test(inplace_variant, args, kwargs,
                                     sample.output_process_fn_grad,
                                     clone_inputs=True,
                                     fixme_ref_jvp_local=fixme_ref_jvp_local)

    def jvp_opinfo_test(self, fn, args, kwargs, output_process_fn,
                        clone_inputs, fixme_ref_jvp_local):
        # NB: we used requires_grad=True to determine where the primals are,
        # but don't need that information otherwise
        fn, primals = normalize_op_input_output2(
            fn, args, kwargs, output_process_fn, requires_grad=True)
        orig_primals = tree_map(lambda x: x.detach(), primals)
        orig_tangents = tree_map(lambda x: torch.randn_like(x), primals)

        def maybe_clone_inputs():
            if clone_inputs:
                primals = tree_map(torch.clone, orig_primals)
                tangents = tree_map(torch.clone, orig_tangents)
                return primals, tangents
            return orig_primals, orig_tangents

        primals, tangents = maybe_clone_inputs()
        expected_primal_outs, expected_tangent_outs = \
            fixme_ref_jvp_local(fn, primals, tangents)

        primals, tangents = maybe_clone_inputs()
        primal_outs, tangent_outs = jvp(fn, primals, tangents)

        self.assertEqual(primal_outs, expected_primal_outs)
        self.assertEqual(tangent_outs, expected_tangent_outs)

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_vjp', vjp_fail.union({
        skip('as_strided_scatter', ''),  # silent incorrectness; also might be flaky
        xfail('sparse.sampled_addmm', ''),
    }))
    @opsToleranceOverride('TestOperators', 'test_vjp', (
        tol1('nn.functional.conv_transpose3d',
             {torch.float32: tol(atol=5e-05, rtol=9e-05)}, device_type='cuda'),
        tol1('nn.functional.binary_cross_entropy_with_logits',
             {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
    ))
    def test_vjp(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        def _test(_op, inplace=False):
            for sample in samples:
                if inplace and not is_valid_inplace_sample_input(sample, op, op.inplace_variant):
                    continue
                fn, primals = normalize_op_input_output(_op, sample)
                result = fn(*primals)
                cotangents = tree_map(lambda x: torch.randn_like(x), result)

                out, vjp_fn = vjp(fn, *primals)
                self.assertEqual(out, result)
                result_vjps = vjp_fn(cotangents)

                _, vjp_fn = ref_vjp(fn, *primals)
                expected_vjps = vjp_fn(cotangents)

                self.assertEqual(result_vjps, expected_vjps)

        _test(op)
        for a_op in op.aliases:
            _test(a_op)
        if op.inplace_variant:
            def f(inp, *args, **kwargs):
                return op.inplace_variant(inp.clone(), *args, **kwargs)
            _test(f, inplace=True)

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_vjpvjp', vjp_fail.union({
        skip('nn.functional.max_unpool1d'),  # silent incorrectness; Flaky
        skip('nn.functional.max_unpool2d'),  # silent incorrectness; Flaky
        xfail('nn.functional.ctc_loss'),  # Not Implemented
        xfail('native_layer_norm', ''),  # Expected a proper Tensor but got None for argument #1 'other'
        xfail('sparse.sampled_addmm', ''),  # sparse tensors have no strides
        # AssertionError: Tensor-likes are not close!
        # Mismatched elements: 1 / 15 (6.7%)
        # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
        # Greatest relative difference: 1.7933241714393998e-06 at index (2, 4) (up to 1.3e-06 allowed)
        # The failure occurred for item [0]
        xfail('masked.prod')
    }))
    @opsToleranceOverride('TestOperators', 'test_vjpvjp', (
        tol1('nn.functional.conv_transpose3d',
             {torch.float32: tol(atol=5e-05, rtol=9e-05)}, device_type='cuda'),
        tol1('prod',
             {torch.float32: tol(atol=2e-05, rtol=1e-04)}),
        tol1('masked.cumprod',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
        tol1('cumprod',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
        tol1('linalg.vander',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
        tol2('linalg.det', 'singular',
             {torch.float32: tol(atol=2e-05, rtol=2e-05)}),
    ))
    def test_vjpvjp(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return
        if not op.supports_gradgrad:
            self.skipTest("Skipped! Operation does not support gradgrad")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        def test(_op, inplace=False):
            for sample in samples:
                if inplace and not is_valid_inplace_sample_input(sample, op, op.inplace_variant):
                    continue
                fn, args = get_vjpfull_variant(_op, sample)
                result = fn(*args)
                cotangents = tree_map(lambda x: torch.randn_like(x), result)

                # Compute vjp of vjp
                _, vjp_fn = vjp(fn, *args)
                result_vjps = vjp_fn(cotangents)

                # Compute ref_vjp of vjp. We could have done ref_vjp of ref_vjp,
                # but since we're confident that vjp works by itself, this is
                # an equivalent way to test that.
                _, vjp_fn = ref_vjp(fn, *args)
                expected_vjps = vjp_fn(cotangents)

                self.assertEqual(result_vjps, expected_vjps)

        test(op)
        if op.inplace_variant:
            def fn(inp, *args, **kwargs):
                return op.inplace_variant(inp.clone(), *args, **kwargs)
            test(fn, inplace=True)

    @skipOps('TestOperators', 'test_vmapvjpvjp', vjp_fail.union({
        skip("atleast_1d"),  # Takes too long
        skip("atleast_2d"),  # Takes too long
        skip("atleast_3d"),  # Takes too long
        xfail("as_strided"),  # incorrect output
        xfail("as_strided_scatter"),  # incorrect output
        skip("bernoulli"),  # calls random op
        xfail("bfloat16"),  # rank 4 tensor for channels_last
        xfail("chalf"),  # rank 4 tensor for channels_last
        xfail("double"),  # rank 4 tensor for channels_last
        xfail("float"),  # rank 4 tensor for channels_last
        xfail("half"),  # rank 4 tensor for channels_last
        # It looks like you're either (1) calling .item() on a Tensor or
        # (2) attempting to use a Tensor in some data-dependent control flow or
        # (3) encountering this error in PyTorch internals.
        xfail("index_reduce"),
        xfail("linalg.eig"),  # vmap over torch.allclose
        xfail("linalg.eigvals"),  # vmap over torch.allclose
        xfail("linalg.householder_product"),  # vmap: inplace into a regular tensor
        xfail("nanquantile", device_type='cpu'),  # vmap not implemented for at::equal.
        xfail("native_layer_norm"),  # vmap: inplace into a regular tensor
        # got a batched tensor as input while the running_mean or running_var,
        # which will be updated in place, were not batched.
        xfail("nn.functional.batch_norm"),
        xfail("nn.functional.binary_cross_entropy"),  # vmap: inplace into a regular tensor
        xfail("nn.functional.ctc_loss"),  # derivate not implemented for _ctc_loss_backward
        skip("nn.functional.dropout"),  # calls random op
        skip("nn.functional.dropout2d"),  # calls random op
        skip("nn.functional.dropout3d"),  # calls random op
        skip("nn.functional.feature_alpha_dropout", "with_train"),  # calls random op
        skip("nn.functional.fractional_max_pool2d"),  # calls random op
        skip("nn.functional.fractional_max_pool3d"),  # calls random op
        skip('nn.functional._scaled_dot_product_attention'),  # randomness
        # It looks like you're either (1) calling .item() on a Tensor or
        # (2) attempting to use a Tensor in some data-dependent control flow or
        # (3) encountering this error in PyTorch internals.
        xfail("nn.functional.gaussian_nll_loss"),
        # got a batched tensor as input while the running_mean or running_var,
        # which will be updated in place, were not batched.
        xfail("nn.functional.instance_norm"),
        xfail("nn.functional.layer_norm"),  # vmap: inplace into a regular tensor
        # RuntimeError: NYI: querying is_contiguous inside of vmap
        # for memory_format other than torch.contiguous_formats
        xfail("nn.functional.max_pool2d"),
        # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only
        # supported with memory_format torch.preserve_format or
        # torch.contiguous_format (got ChannelsLast)
        xfail("nn.functional.max_unpool2d"),
        # RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only
        # supported with memory_format torch.preserve_format
        # or torch.contiguous_format (got ChannelsLast)s
        xfail("nn.functional.max_unpool2d", "grad"),
        xfail("nn.functional.rrelu"),  # RuntimeError: vmap: we do not yet support aten::rrelu_with_noise.
        xfail("normal"),  # calls random op
        xfail("normal", "number_mean"),  # calls random op
        xfail("pca_lowrank"),  # calls random op
        xfail("put"),  # vmap: inplace into a regular tensor
        xfail("quantile", device_type='cpu'),  # Batching rule not implemented for `at::equal`
        xfail("scatter_reduce", "prod"),  # vmap (looks like you are calling item/data-dependent)
        xfail("sparse.sampled_addmm"),  # RuntimeError: Sparse CSR tensors do not have strides
        xfail("svd_lowrank"),  # calls random op
        xfail("take"),  # vmap: inplace into a regular tensor
        xfail("to"),  # rank 4 tensor for channels_last
        xfail("view_as_complex"),  # RuntimeError: Tensor must have a last dimension with stride 1
        xfail("masked.softmax", device_type='cuda'),  # Mismatch in values!
        xfail("masked.softmin", device_type='cuda'),  # Mismatch in values!
        # got a batched tensor as input while the running_mean or running_var,
        # which will be updated in place, were not batched.
        xfail("nn.functional.batch_norm", 'without_cudnn'),
        # view doesn't work on sparse
        xfail("to_sparse"),
        xfail("native_batch_norm"),
    }))
    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @opsToleranceOverride('TestOperators', 'test_vmapvjpvjp', (
        tol1('linalg.svd',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
        tol1('linalg.lu_factor',
             {torch.float32: tol(atol=2e-03, rtol=2e-02)}),
        tol1('svd',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
    ))
    def test_vmapvjpvjp(self, device, dtype, op):
        # Since, we test `vjpvjp` independently,
        # for this test, we just verify that vmap
        # of `vjpvjp` is correct.
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return
        if not op.supports_gradgrad:
            self.skipTest("Skipped! Operation does not support gradgrad")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        # TODO: test in-place
        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        for sample in samples:
            fn, args = get_vjpfull_variant(op, sample)
            result = fn(*args)
            cotangents = tree_map(lambda x: torch.randn_like(x), result)
            cotangents, _ = tree_flatten(cotangents)
            num_args = len(args)

            args_and_cotangents = tuple(args) + tuple(cotangents)

            def vjp_of_vjp(*args_and_cotangents):
                args = args_and_cotangents[:num_args]
                cotangents = args_and_cotangents[num_args:]
                result, vjp_fn = vjp(fn, *args)
                result_vjps = vjp_fn(cotangents)
                result, _ = tree_flatten(result)
                result_vjps, _ = tree_flatten(result_vjps)
                return (*result, *result_vjps)

            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
            generator = get_fallback_and_vmap_exhaustive(
                vjp_of_vjp, args_and_cotangents, {}, is_batch_norm_and_training=is_batch_norm_and_training)
            for loop_out, batched_out in generator:
                self.assertEqual(loop_out, batched_out)

    vmapvjp_fail = vjp_fail.union({
        # -------------------- ALLOWED FAILURES --------------------------------
        # The following are not bugs and are expected behavior
        xfail('masked_select'),  # Not possible due to dynamic shapes
        skip('bernoulli'),  # randomness
        skip('normal', ''),  # randomness
        skip('normal', 'number_mean'),  # randomness
        skip('nn.functional.rrelu'),  # randomness
        skip('nn.functional.feature_alpha_dropout', 'with_train'),  # randomness
        skip('nn.functional.feature_alpha_dropout', 'without_train'),  # randomness
        skip('nn.functional.dropout'),  # randomness
        skip('nn.functional.dropout2d'),  # randomness
        skip('nn.functional.dropout3d', ''),  # randomness
        skip('nn.functional._scaled_dot_product_attention'),  # randomness
        xfail('as_strided'),  # as_strided is too wild for us to support, wontfix
        xfail('index_put', ''),  # not possible due to dynamic shapes; we support a subset
        xfail('masked_scatter'),  # dynamic
        xfail('nn.functional.fractional_max_pool2d'),  # random
        xfail('nn.functional.fractional_max_pool3d'),  # random
        xfail('take'),  # dynamic
        xfail('pca_lowrank', ''),  # randomness
        xfail('svd_lowrank', ''),  # randomness
        xfail('to_sparse', ''),  # non-dense output
        skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        # ----------------------------------------------------------------------

        # ---------------------------- BUGS ------------------------------------
        # All of the following are bugs and need to be fixed
        skip('linalg.svdvals'),  # # really annoying thing where it passes correctness check but not has_batch_rule
        skip("native_batch_norm"),
        xfail('__getitem__', ''),  # dynamic error
        xfail('linalg.eig'),  # Uses aten::allclose
        xfail('linalg.householder_product'),  # needs select_scatter
        xfail('nanquantile', device_type='cpu'),  # checks q via a .item() call
        xfail('nn.functional.gaussian_nll_loss'),  # checks var for if any value < 0
        xfail('narrow'),  # .item() call
        xfail('quantile', device_type='cpu'),  # checks q via a .item() call
        xfail('view_as_complex'),  # Tensor must have a last dimension with stride 1

        # required rank 4 tensor to use channels_last format
        xfail('bfloat16'),
        xfail('double'),
        xfail('float'),
        xfail('half'),
        xfail('chalf', ''),

        xfail('scatter_reduce', 'prod'),  # item call

        # Batching rule not implemented for aten::_use_cudnn_ctc_loss.Tensor
        xfail('nn.functional.ctc_loss', device_type='cuda'),
        # NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format
        xfail('nn.functional.max_unpool2d'),
        xfail('nn.functional.max_unpool2d', 'grad'),

        xfail('sparse.sampled_addmm', ''),
        xfail('as_strided_scatter', ''),  # calls as_strided
        xfail('index_reduce', ''),  # .item() call
        # ---------------------------------------------------------------------
    })

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @opsToleranceOverride('TestOperators', 'test_vmapvjp', (
        tol1('linalg.svd',
             {torch.float32: tol(atol=1.5e-04, rtol=1e-04)}, device_type="cuda"),
        tol1('svd',
             {torch.float32: tol(atol=1.5e-04, rtol=1e-04)}, device_type="cuda"),
    ))
    @skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail)
    def test_vmapvjp(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        # TODO: test in-place
        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        for sample in samples:
            cotangents = get_sample_cotangents(op, sample)
            fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
            is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
            generator = get_fallback_and_vmap_exhaustive(
                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
            for loop_out, batched_out in generator:
                self.assertEqual(loop_out, batched_out)

    vmapjvpall_fail = {
        # -------------------- ALLOWED FAILURES --------------------------------
        # The following are expected (not a bug)
        skip('bernoulli', ''),  # randomness
        skip('nn.functional.dropout'),  # randomness
        skip('nn.functional.rrelu'),  # randomness
        skip('nn.functional.dropout2d', ''),
        skip('nn.functional.dropout3d', ''),
        skip('nn.functional._scaled_dot_product_attention'),  # randomness
        skip('nn.functional.feature_alpha_dropout', 'without_train'),
        skip('nn.functional.feature_alpha_dropout', 'with_train'),
        xfail('nn.functional.fractional_max_pool2d'),  # Cannot access data pointer of Tensor that doesn't have storage
        xfail('nn.functional.fractional_max_pool3d'),  # Cannot access data pointer of Tensor that doesn't have storage
        # Not actually a problem: embedding with max_norm mutates the weight
        # and causes different runs to produce different results.
        # skip because this is flaky depending on what the max_norm is!
        skip('nn.functional.embedding', ''),
        skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        # ----------------------------------------------------------------------

        # ---------------------------- BUGS ------------------------------------
        # The following are bugs that we should fix
        decorate('nn.functional.conv2d', decorator=unittest.skipIf(IS_ARM64, "Fails on M1")),
        skip('nn.functional.max_pool1d'),  # fails on cpu, runs on cuda
        xfail('masked.mean'),  # silent incorrectness (nan difference)

        xfail('nn.functional.soft_margin_loss', ''),  # soft_margin_loss_backward does not support forward-ad
        xfail('tensor_split'),  # data_ptr composite compliance
        xfail('quantile'),  # at::equal batching rule (cpu), also, in-place vmap (cuda)
        skip('as_strided'),  # Test runner cannot handle this
        xfail('nn.functional.gaussian_nll_loss'),  # .item or data-dependent control flow
        xfail('scatter'),  # forward-mode AD does not support at::scatter
        xfail('nanquantile'),  # at::equal batching rule (cpu), also, in-place vmap (cuda)
        xfail('view_as_complex'),  # Tensor must have a last dimension with stride 1

        skip('pca_lowrank', ''),  # randomness
        skip('svd_lowrank', ''),  # randomness

        xfail('double'),  # required rank 4 tensor to use channels_last format

        # potential silent incorrectness
        skip('nn.functional.max_unpool1d'),  # Flaky, seems to sometimes his max_unpool2d
        skip('nn.functional.max_unpool2d'),  # fails everywhere except on mac
        skip('nn.functional.max_unpool3d'),  # fails everywhere except on mac

        # erroring because running_mean and running_var aren't differentiable
        xfail('nn.functional.batch_norm'),
        xfail('nn.functional.batch_norm', 'without_cudnn'),
        xfail("native_batch_norm"),
        # ----------------------------------------------------------------------
    }

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @opsToleranceOverride('TestOperators', 'test_vmapjvpall', (
        tol1('nn.functional.conv_transpose3d',
             {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'),
        tol1('linalg.householder_product',
             {torch.float32: tol(atol=2e-04, rtol=9e-3)}, device_type='cuda'),
        tol1('linalg.householder_product',
             {torch.float32: tol(atol=2e-04, rtol=1e-4)}, device_type='cpu'),
    ))
    @skipOps('TestOperators', 'test_vmapjvpall', vmapjvpall_fail)
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    # This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp
    # or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact
    # because that coresponds to "batched forward-mode AD" testing in PyTorch core
    def test_vmapjvpall(self, device, dtype, op):
        if is_inplace(op, op.get_op()):
            # TODO: test in-place
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=False)

        if not op.supports_forward_ad:
            self.skipTest("Skipped! Forward AD not supported.")
            return

        for sample in samples:
            arg_values = [sample.input] + list(sample.args)
            kwarg_values = sample.kwargs
            args = tuple(arg_values) + tuple(kwarg_values)
            fn, args = get_jvp_variant_primals_tangents(op, sample)
            is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
            generator = get_fallback_and_vmap_exhaustive(
                fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
            for loop_out, batched_out in generator:
                self.assertEqual(loop_out, batched_out)

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_vmapjvpall_has_batch_rule', vmapjvpall_fail.union({
        skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        xfail('nn.functional.huber_loss'),
        xfail('lu'),
        xfail('cumprod'),
        xfail('masked_fill'),
        xfail('copysign'),
        xfail('complex'),
        skip('masked.mean'),  # ???
        xfail('masked_scatter'),
        xfail('index_fill'),
        xfail('put'),
        xfail('take'),
        xfail('nn.functional.max_pool3d'),
        xfail('vdot'),
        xfail('nanmean'),
        xfail('nansum'),
        xfail('nn.functional.feature_alpha_dropout', 'without_train'),
        xfail('linalg.lu_factor', ''),
        xfail('nn.functional.dropout2d', ''),
        xfail('pca_lowrank', ''),
        xfail('svd_lowrank', ''),
        xfail('linalg.lu_factor_ex', ''),
        xfail('nn.functional.feature_alpha_dropout', 'with_train'),
        xfail('special.log_ndtr', ''),
        xfail('fft.ihfft2'),  # conj_physical fallback
        xfail('fft.ihfftn'),  # conj_physical fallback
        xfail('polar'),  # complex fallback
        xfail('nn.functional.max_unpool3d', 'grad'),
        xfail('nn.functional.smooth_l1_loss', ''),
        xfail('nn.functional.max_unpool2d', 'grad'),
        xfail('nn.functional.soft_margin_loss', ''),
        xfail('nn.functional.max_unpool1d', 'grad'),
        xfail('nn.functional.embedding', ''),
        xfail('scatter_reduce', "sum"),   # aten::scatter_reduce.two hit the vmap fallback
        xfail('scatter_reduce', "mean"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('scatter_reduce', "amin"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('scatter_reduce', "amax"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('lu_unpack'),
        xfail('nn.functional.glu'),
        xfail('nn.functional.bilinear'),  # trilinear doesn't have batching rule
        xfail('linalg.lu', ''),
        xfail('linalg.lu_solve', ''),
        xfail('nn.functional.dropout3d', ''),
        xfail('as_strided_scatter', ''),
        xfail('masked.cumprod', ''),
        xfail('linalg.vecdot', ''),
    }))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
        if is_inplace(op, op.get_op()):
            # TODO: test in-place
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=False)

        if not op.supports_forward_ad:
            self.skipTest("Skipped! Forward AD not supported.")
            return

        def test():
            for sample in samples:
                arg_values = [sample.input] + list(sample.args)
                kwarg_values = sample.kwargs
                args = tuple(arg_values) + tuple(kwarg_values)
                fn, args = get_jvp_variant_primals_tangents(op, sample)
                is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
                for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                    pass
        check_vmap_fallback(self, test, op, dry_run=False)

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({
        skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        xfail('view_as_complex'),
        xfail('complex'),
        xfail('copysign'),
        xfail('cummax'),
        xfail('cummin'),
        xfail('cumprod'),
        xfail('nansum'),
        xfail('nanmean'),
        xfail('narrow'),  # Batching rule not implemented for `narrow.Tensor` (and view op)
        xfail('special.log_ndtr'),
        xfail('index_copy'),
        xfail('index_fill'),
        xfail('linalg.eig'),
        xfail('linalg.householder_product'),
        xfail('lu'),
        xfail('lu_solve'),
        xfail('lu_unpack'),
        xfail('masked_fill'),
        xfail('masked_scatter'),
        xfail('masked_select'),
        xfail('nanquantile'),
        xfail('put'),
        xfail('scatter_reduce', "sum"),   # aten::scatter_reduce.two hit the vmap fallback
        xfail('scatter_reduce', "mean"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('scatter_reduce', "amin"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('scatter_reduce', "amax"),  # aten::scatter_reduce.two hit the vmap fallback
        xfail('quantile'),
        xfail('renorm'),
        xfail('take'),
        xfail('tensor_split'),
        xfail('to_sparse'),
        xfail('unfold'),
        xfail('vdot'),
        xfail('nn.functional.dropout'),
        xfail('fft.ihfft2'),
        xfail('fft.ihfftn'),
        xfail('nn.functional.gaussian_nll_loss'),
        xfail('nn.functional.huber_loss'),
        xfail('nn.functional.bilinear'),
        xfail('nn.functional.fractional_max_pool3d'),
        xfail('nn.functional.ctc_loss'),
        xfail('as_strided'),
        xfail('stft'),
        xfail('nn.functional.rrelu'),
        xfail('nn.functional.embedding_bag'),
        xfail('nn.functional.max_pool3d'),
        xfail('nn.functional.fractional_max_pool2d'),
        xfail('linalg.lu_factor', ''),
        xfail('nn.functional.feature_alpha_dropout', 'with_train'),
        xfail('pca_lowrank', ''),
        xfail('nn.functional.dropout2d', ''),
        xfail('nn.functional.feature_alpha_dropout', 'without_train'),
        xfail('svd_lowrank', ''),
        xfail('linalg.lu_factor_ex', ''),

        xfail('nn.functional.max_unpool2d', ''),
        xfail('nn.functional.multi_margin_loss', ''),
        xfail('nn.functional.multilabel_margin_loss', ''),
        xfail('nn.functional.pdist', ''),
        xfail('nn.functional.smooth_l1_loss', ''),
        xfail('scatter_reduce', 'prod'),
        xfail('nn.functional.max_unpool1d', ''),
        xfail('nn.functional.max_unpool3d', ''),
        xfail('nn.functional.max_unpool3d', 'grad'),
        xfail('nn.functional.soft_margin_loss', ''),
        xfail('nn.functional.max_unpool1d', 'grad'),
        xfail('nn.functional.max_unpool2d', 'grad'),
        xfail('linalg.lu', ''),
        xfail('linalg.lu_solve', ''),
        xfail('chalf', ''),
        xfail('index_reduce', ''),
        xfail('linalg.vander', ''),
        xfail('nn.functional.dropout3d', ''),
        xfail('as_strided_scatter', ''),
        xfail('segment_reduce', 'offsets'),
        xfail('masked.cumprod', ''),
        xfail('linalg.vecdot', ''),
        xfail('segment_reduce', 'lengths'),
        xfail('sparse.sampled_addmm', ''),
        xfail("native_batch_norm"),
    }))
    def test_vmapvjp_has_batch_rule(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        # TODO: test in-place
        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        def test():
            for sample in samples:
                cotangents = get_sample_cotangents(op, sample)
                fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
                is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
                for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
                        fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                    pass
                for a_op in op.aliases:
                    fn, args = get_vjp_fn_and_args_with_cotangents(a_op, sample, cotangents)
                    for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
                            fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training, compute_loop_out=False):
                        pass

        check_vmap_fallback(self, test, op, dry_run=False)

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_vjpvmap', vjp_fail.union({
        skip('bernoulli', ''),  # vjpvmap testing can't handle randomness
        skip('normal', ''),  # vjpvmap testing can't handle randomness
        skip('normal', 'number_mean'),  # vjpvmap testing can't handle randomness
        skip('nn.functional.rrelu'),  # randomness
        skip('nn.functional.feature_alpha_dropout', 'with_train'),  # randomness
        skip('nn.functional.feature_alpha_dropout', 'without_train'),  # randomness
        skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        skip('to_sparse', ''),  # non-dense output

        # fallback path doesn't work
        # All of the following are bugs and need to be fixed
        xfail('__getitem__', ''),
        xfail('index_put', ''),
        xfail('view_as_complex'),
        xfail('nn.functional.gaussian_nll_loss'),
        xfail('masked_select'),
        xfail('narrow'),  # Batching rule not implemented for `narrow.Tensor` (and view op)
        skip('nn.functional.fractional_max_pool3d'),  # generator works on cpu, fails on cuda
        xfail('__rpow__'),  # https://github.com/pytorch/functorch/issues/617
        skip('nn.functional.fractional_max_pool2d'),  # generator works on cpu, fails on cuda
        xfail('column_stack', ''),
        xfail('nn.functional.dropout2d', ''),
        xfail('svd_lowrank', ''),
        xfail('pca_lowrank', ''),
        xfail('clamp'),
        xfail('cross'),  # The defaults of this op are *very* weird. No wonder it doesn't work
        # something weird happening with channels_last
        xfail('bfloat16'),
        xfail('double'),
        xfail('float'),
        xfail('half'),
        xfail('nn.functional.dropout3d', ''),
        xfail('as_strided_scatter', ''),
        xfail('sparse.sampled_addmm', ''),
        xfail("native_batch_norm"),
    }))
    def test_vjpvmap(self, device, dtype, op):
        # NB: there is no vjpvmap_has_batch_rule test because that is almost
        # certainly redundant with the vmap_has_batch_rule test in test_vmap.py

        # one-off skip
        if op.name == 'nn.functional.dropout':
            self.skipTest("Skipped!")

        if not op.supports_autograd:
            # If the op doesn't support autograd, vmap(op) won't either
            self.skipTest("Skipped! Autograd not supported.")
            return

        # TODO: test in-place
        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")  # instance norm calls batch norm
        is_batch_norm = op.name in batch_norm_fns

        for sample in samples:
            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs

            is_batch_norm_and_training = is_batch_norm and is_batch_norm_training(op.name, kwargs)
            generator = generate_vmap_inputs(args, kwargs,
                                             is_batch_norm_and_training=is_batch_norm_and_training)

            for batched_args, in_dims, kwargs in generator:
                vmapped_op = vmap(op, in_dims)
                fn, primals = normalize_op_input_output2(vmapped_op, batched_args, kwargs,
                                                         sample.output_process_fn_grad)
                result = fn(*primals)
                cotangents = tree_map(lambda x: torch.randn_like(x), result)

                _, vjp_fn = vjp(fn, *primals)
                result_vjps = vjp_fn(cotangents)

                _, vjp_fn = ref_vjp(fn, *primals)
                expected_vjps = vjp_fn(cotangents)

                self.assertEqual(result_vjps, expected_vjps)

    def _compare_jacobians_of_vjp(self, fn, cotangents_and_primals, argnums=None, atol_rtol=None):
        if argnums is None:
            argnums = tuple(range(len(cotangents_and_primals)))

        def get_vjp(cotangents, *primals):
            _, vjp_fn = vjp(fn, *primals)
            return vjp_fn(cotangents)

        jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals)
        jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals)

        # For dtype changing operations, the jacobians have different dtype.
        jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
        jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)

        if atol_rtol is not None:
            (atol, rtol) = atol_rtol
            self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
        else:
            self.assertEqual(jacobian_jvp, jacobian_vjp)

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_jvpvjp', vjp_fail.union({
        xfail('to_sparse', ''),  # NYI
        # RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
        # this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
        xfail('normal', ''),
        xfail('cdist', ''),  # NYI: forward-AD for _cdist_forward
        xfail('cholesky', ''),  # NYI: forward-AD for cholesky
        xfail('logcumsumexp', ''),  # NYI: forward-AD for logcumsumexp
        xfail('nn.functional.embedding_bag', ''),  # NYI: forward-AD for _embedding_bag
        xfail('nn.functional.grid_sample', ''),  # NYI: forward AD for grid_sampler_2d
        xfail('nn.functional.hardsigmoid', ''),  # NYI: forward AD for hardsigmoid_backward
        xfail('nn.functional.huber_loss', ''),  # NYI: forward AD for huber_loss_backward
        xfail('nn.functional.logsigmoid', ''),  # not differentiable w.r.t. buffer
        xfail('renorm', ''),  # NYI: forward AD for renorm
        xfail('symeig', ''),  # NYI: forward AD for symeig
        xfail('nn.functional.multilabel_margin_loss', ''),  # NYI: multilabel_margin_loss_forward
        xfail('nn.functional.multilabel_soft_margin_loss', ''),  # NYI: log_sigmoid_backward
        xfail('nn.functional.soft_margin_loss', ''),  # NYI: forward-AD for log_sigmoid_backward
        xfail('nn.functional.ctc_loss', ''),  # NYI: forward-AD for _ctc_loss
        xfail('nn.functional.pdist', ''),  # NYI: forward-AD with _pdist_forward
        xfail('nn.functional.multi_margin_loss', ''),  # NYI: forward AD with multi_margin_loss
        skip('linalg.householder_product', '', device_type='cuda'),  # flaky, I'm not sure why
        xfail('sparse.sampled_addmm', ''),  # Sparse tensors have no strides
        skip('as_strided_scatter', ''),  # seems flaky
        xfail('segment_reduce', 'offsets'),  # NYI: forward-AD for segment_reduce
        xfail('index_reduce', ''),  # NYI: forward-AD for index_reduce
        xfail('segment_reduce', 'lengths'),  # NYI: forward-AD for segment_reduce
    }))
    @opsToleranceOverride('TestOperators', 'test_jvpvjp', (
        tol1('masked.prod',
             {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}),
        tol1('masked.cumprod',
             {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
        tol1('cumprod',
             {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type='cuda'),
        tol1('linalg.vander',
             {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type='cuda'),
    ))
    def test_jvpvjp(self, device, dtype, op):
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        # TODO: test in-place
        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        for sample in samples:
            fn, primals = normalize_op_input_output(op, sample)
            result = fn(*primals)
            cotangents = tree_map(lambda x: torch.randn_like(x), result)

            primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
            cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)

            def push_vjp(primals, cotangents):
                _, vjp_fn = vjp(fn, *primals)
                return vjp_fn(cotangents)

            result = jvp(push_vjp, (primals, cotangents), (primals_tangents, cotangents_tangents))
            self.assertEqual(len(result), 2)

            def tree_map2(fn, first, second):
                flat_first, spec_first = tree_flatten(first)
                flat_second, spec_second = tree_flatten(second)
                assert spec_first == spec_second
                flat_result = [fn(f, s) for f, s in zip(flat_first, flat_second)]
                return tree_unflatten(flat_result, spec_first)

            def reference(primals, cotangents, primals_tangents, cotangents_tangents):
                with fwAD.dual_level():
                    primal_duals = tree_map2(fwAD.make_dual, primals, primals_tangents)
                    _, vjp_fn = ref_vjp(fn, *primal_duals)

                    cotangent_duals = tree_map2(fwAD.make_dual, cotangents, cotangents_tangents)
                    result = vjp_fn(cotangent_duals)

                    flat_result, spec = tree_flatten(result)
                    primals_out, tangents_out = zip(*[fwAD.unpack_dual(r) for r in flat_result])
                    tangents_out = [t if t is not None else torch.zeros_like(p)
                                    for p, t in zip(primals_out, tangents_out)]
                    expected = (tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec))
                return expected

            expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
            self.assertEqual(result, expected)

    @skipOps('TestOperators', 'test_vmapjvpvjp', vjp_fail.union({
        # Following operatos take too long, hence skipped
        skip('atleast_1d'),
        skip('atleast_2d'),
        skip('atleast_3d'),
        skip('meshgrid', 'list_of_tensors'),
        skip('meshgrid', 'variadic_tensors'),
        skip('broadcast_tensors'),
        skip('linalg.lstsq'),
        skip('nn.functional.bilinear'),
        skip('native_layer_norm'),

        # Potential bugs/errors
        xfail('as_strided'),  # AssertionError: Tensor-likes are not close!
        xfail('as_strided_scatter'),  # AssertionError: Tensor-likes are not close!
        xfail('bernoulli'),  # calls random op
        xfail('bfloat16'),  # required rank 4 tensor to use channels_last format
        xfail('cdist'),  # Forward AD not implemented and no decomposition
        xfail('chalf'),  # required rank 4 tensor to use channels_last format
        xfail('cholesky'),  # Forward AD not implemented and no decomposition
        xfail('double'),  # required rank 4 tensor to use channels_last format
        xfail('float'),  # required rank 4 tensor to use channels_last format
        xfail('half'),  # required rank 4 tensor to use channels_last format
        xfail('index_reduce'),  # Forward AD not implemented and no decomposition
        xfail('linalg.eig'),  # vmap over torch.allclose isn't supported yet.
        # AssertionError: Tensor-likes are not close!
        # Mismatched elements: 2 / 120 (1.7%)
        # Greatest absolute difference: 0.09438323974609375
        # Greatest relative difference: 0.00115722746596277
        xfail('linalg.householder_product', device_type='cuda'),
        xfail('logcumsumexp'),  # Forward AD not implemented and no decomposition
        xfail('mvlgamma', 'mvlgamma_p_1'),  # vmap: inplace into a regular tensor
        xfail('mvlgamma', 'mvlgamma_p_3'),  # vmap: inplace into a regular tensor
        xfail('mvlgamma', 'mvlgamma_p_5'),  # vmap: inplace into a regular tensor
        xfail('nanquantile'),  # Batching rule not implemented for aten::equal
        # RuntimeError: Batch norm got a batched tensor as input while the
        # running_mean or running_var, which will be updated in place,
        # were not batched.
        xfail('nn.functional.batch_norm'),
        xfail('nn.functional.batch_norm', 'without_cudnn'),
        xfail('nn.functional.binary_cross_entropy'),  # vmap: inplace into a regular tensor
        xfail("nn.functional.ctc_loss"),  # ForwardAD not implemented and no decomposition
        xfail('nn.functional.dropout2d'),  # calls random op
        xfail('nn.functional.dropout3d'),  # calls random op
        xfail('nn.functional.dropout'),  # calls random op
        skip('nn.functional._scaled_dot_product_attention'),  # randomness
        xfail('nn.functional.embedding_bag'),  # Forward AD not implemented and no decomposition
        xfail('nn.functional.feature_alpha_dropout', 'with_train'),  # calls random op
        xfail('nn.functional.fractional_max_pool2d'),  # calls random op
        xfail('nn.functional.fractional_max_pool3d'),  # calls random op
        xfail('nn.functional.gaussian_nll_loss'),  # data depenedant flow
        xfail('nn.functional.grid_sample'),  # Forward AD not implemented and no decomposition
        xfail('nn.functional.hardsigmoid'),  # Forward AD not implemented and no decomposition
        xfail('nn.functional.hinge_embedding_loss'),  # vmap: inplace into a regular tensor
        xfail('nn.functional.huber_loss'),  # Forward AD not implemented and no decomposition
        # RuntimeError: Batch norm got a batched tensor as input while the
        # running_mean or running_var, which will be updated in place,
        # were not batched.
        xfail('nn.functional.instance_norm'),
        xfail('nn.functional.logsigmoid'),  # Forward AD not implemented and no decomposition
        # NYI: Tensor.clone(memory_format) inside vmap is only supported with
        # memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
        xfail('nn.functional.max_unpool2d'),
        xfail('nn.functional.max_unpool2d', 'grad'),
        xfail('nn.functional.multi_margin_loss'),  # Forward AD not implemented and no decomposition
        xfail('nn.functional.multilabel_margin_loss'),  # Forward AD not implemented and no decomposition
        xfail('nn.functional.multilabel_soft_margin_loss'),  # Forward AD not implemented and no decomposition
        xfail('nn.functional.pdist'),  # Forward AD not implemented and no decomposition
        xfail('nn.functional.rrelu'),  # vmap: we do not yet support aten::rrelu_with_noise.
        xfail('nn.functional.soft_margin_loss'),  # Forward AD not implemented and no decomposition
        xfail('normal'),  # calls random op
        xfail('normal', 'number_mean'),  # calls random op
        xfail('pca_lowrank'),  # calls random op
        xfail('quantile'),  # Batching rule not implemented for aten::equal
        xfail('renorm'),  # Forward AD not implemented and no decomposition
        xfail('scatter_reduce', 'prod'),  # Forward AD not implemented and no decomposition
        xfail('segment_reduce', 'lengths'),  # Forward AD not implemented and no decomposition
        xfail('segment_reduce', 'offsets'),  # Forward AD not implemented and no decomposition
        xfail('sparse.sampled_addmm'),  # RuntimeError: Sparse CSR tensors do not have strides
        xfail('svd_lowrank'),  # calls random op
        xfail('symeig'),  # Forward AD not implemented and no decomposition
        xfail('take'),  # vmap: inplace into regular tensor
        xfail('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        xfail('to_sparse'),  # Forward AD not implemented and no decomposition
        xfail('view_as_complex'),  # RuntimeError: Tensor must have a last dimension with stride 1
        # RuntimeError: Batch norm got a batched tensor as
        # input while the running_mean or running_var, which will be updated in
        # place, were not batched.
        xfail("native_batch_norm"),
    }))
    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @opsToleranceOverride('TestOperators', 'test_vmapjvpvjp', (
        tol1('linalg.svd',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
        tol1('linalg.householder_product',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
        tol1('linalg.multi_dot',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
        tol1('svd',
             {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
    ))
    def test_vmapjvpvjp(self, device, dtype, op):
        # Since we test `jvpvjp` seperately,
        # in this we just check that vmap of `jvpvjp`
        # is correct.
        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        samples = op.sample_inputs(device, dtype, requires_grad=True)

        # TODO: test in-place
        if is_inplace(op, op.get_op()):
            self.skipTest("Skipped! NYI: inplace-testing not supported.")
            return

        for sample in samples:
            fn, primals = normalize_op_input_output(op, sample)
            result = fn(*primals)
            cotangents = tree_map(lambda x: torch.randn_like(x), result)

            primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
            cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)

            def push_vjp(primals, cotangents):
                _, vjp_fn = vjp(fn, *primals)
                return vjp_fn(cotangents)

            args, spec = tree_flatten(((primals, cotangents), (primals_tangents, cotangents_tangents)))

            def jvp_of_vjp(*args):
                (primals, tangents) = tree_unflatten(args, spec)
                primals_out, tangents_out = jvp(push_vjp, primals, tangents)

                flat_primals_out, _ = tree_flatten(primals_out)
                flat_tangents_out, _ = tree_flatten(tangents_out)
                return tuple(flat_primals_out + flat_tangents_out)

            is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs)
            generator = get_fallback_and_vmap_exhaustive(
                jvp_of_vjp, args, {}, is_batch_norm_and_training=is_batch_norm_and_training)
            for loop_out, batched_out in generator:
                self.assertEqual(loop_out, batched_out)


    def _make_extremal_inputs(self, shape, device):
        if shape is None:
            return (None,)
        return (
            torch.full(shape, -1000., device=device),
            torch.zeros(shape, device=device),
            torch.full(shape, 1000., device=device),
        )

    def _arg_and_kwarg_options(self, args_options, kwargs_options):
        return itertools.product(*args_options, kwargs_options)

    def test_extremal_numerics_nll_loss(self, device):
        N, C = 3, 4
        d1, d2, d3 = 5, 6, 7
        shapes = (
            ((N, C), (N,), (C,)),
            ((N, C), (N,), None),
            ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
            ((N, C, d1, d2, d3), (N, d1, d2, d3), None),
        )
        kwargs_options = ({'ignore_index': 0, 'reduction': 'mean'}, {'reduction': 'sum'}, {'reduction': 'none'}, {})
        for input_shape, target_shape, weight_shape in shapes:
            input_options = self._make_extremal_inputs(input_shape, device)
            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
                if weight_shape is None:
                    weight = None
                else:
                    weight = torch.randn(weight_shape, device=device)
                target = torch.randint(0, C, target_shape, device=device)
                target[0] = 1  # since we're ignoring index 0, at least one element must be non-zero

                fn = functools.partial(torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs)
                result = fn(input)
                cotangents = torch.randn_like(result, device=device)
                self._compare_jacobians_of_vjp(fn, (cotangents, input))

    def test_extremal_numerics_l1_loss(self, device):
        N, C, H, W = 3, 4, 5, 6
        shapes = ((N, C), (N, C, H), (N, C, H, W))
        kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {})
        for shape in shapes:
            input_options = self._make_extremal_inputs(shape, device)
            target_options = self._make_extremal_inputs(shape, device)
            for input, target, kwargs in self._arg_and_kwarg_options((input_options, target_options), kwargs_options):
                result = torch.nn.functional.l1_loss(input, target)
                cotangents = torch.randn_like(result, device=device)
                self._compare_jacobians_of_vjp(torch.nn.functional.l1_loss, (cotangents, input, target))

    def test_extremal_numerics_mse_loss(self, device):
        N, C, H, W = 3, 4, 5, 6
        shapes = ((N, C), (N, C, H), (N, C, H, W))
        kwargs_options = ({'reduction': 'sum'}, {'reduction': 'none'}, {})
        for shape in shapes:
            input_options = self._make_extremal_inputs(shape, device)
            target_options = self._make_extremal_inputs(shape, device)
            for input, target, kwargs in self._arg_and_kwarg_options((input_options, target_options), kwargs_options):
                result = torch.nn.functional.mse_loss(input, target)
                cotangents = torch.randn_like(result, device=device)
                self._compare_jacobians_of_vjp(torch.nn.functional.mse_loss, (cotangents, input, target))

    def test_extremal_numerics_softmax(self, device):
        N, C, H, W = 3, 4, 5, 6
        shapes = ((N, C), (N, C, H), (N, C, H, W))
        kwargs_options = ({'dim': 1}, {})
        for shape in shapes:
            input_options = self._make_extremal_inputs(shape, device)
            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
                result = torch.nn.functional.softmax(input)
                cotangents = torch.randn_like(result, device=device)
                self._compare_jacobians_of_vjp(torch.nn.functional.softmax, (cotangents, input))


    def test_extremal_numerics_log_softmax(self, device):
        N, C, H, W = 3, 4, 5, 6
        shapes = ((N, C), (N, C, H), (N, C, H, W))
        kwargs_options = ({'dim': 1}, {})
        for shape in shapes:
            input_options = self._make_extremal_inputs(shape, device)
            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
                result = torch.nn.functional.log_softmax(input)
                cotangents = torch.randn_like(result, device=device)
                self._compare_jacobians_of_vjp(torch.nn.functional.log_softmax, (cotangents, input))

    def test_extremal_numerics_cross_entropy(self, device):
        N, C = 3, 4
        d1, d2, d3 = 5, 6, 7
        shapes = (
            ((N, C), (N,), (C,)),
            ((N, C), (N,), None),
            ((N, C), (N, C), (C,)),
            ((N, C), (N, C), None),
            ((C,), (), (C,)),
            ((C,), (), None),
            ((C,), (C,), (C,)),
            ((C,), (C,), None),
            ((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
            ((N, C, d1, d2, d3), (N, d1, d2, d3), None),
            ((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)),
            ((N, C, d1, d2, d3), (N, C, d1, d2, d3), None),
        )
        for input_shape, target_shape, weight_shape in shapes:
            input_options = self._make_extremal_inputs(input_shape, device)
            kwargs_options = [{'reduction': 'sum'}, {'reduction': 'none'}, {}]
            if input_shape != target_shape:
                kwargs_options.append({'ignore_index': 0, 'reduction': 'mean'})

            for input, kwargs in self._arg_and_kwarg_options((input_options,), kwargs_options):
                if weight_shape is None:
                    weight = None
                else:
                    weight = torch.randn(weight_shape, device=device)

                if input_shape == target_shape:
                    target = torch.rand(target_shape, device=device)
                elif len(target_shape) == 0:
                    target = torch.tensor(1, device=device)  # must be non-zero since ignore_index may be 0
                else:
                    target = torch.randint(0, C, target_shape, device=device)

                fn = functools.partial(torch.nn.functional.cross_entropy, target=target, weight=weight, **kwargs)
                result = fn(input)
                cotangents = torch.randn_like(result, device=device)
                self._compare_jacobians_of_vjp(fn, (cotangents, input), atol_rtol=(1e-4, 1e-5))

    def test_extremal_numerics_binary_cross_entropy(self, device):
        N, C, H, W = 3, 4, 5, 6
        shapes = ((N, C), (N, C, H), (N, C, H, W))
        for shape in shapes:
            weight_options = self._make_extremal_inputs(shape, device)
            kwargs_options = [{'reduction': 'sum'}, {'reduction': 'none'}, {}]

            for weight, kwargs in self._arg_and_kwarg_options((weight_options,), kwargs_options):
                input = torch.rand(shape, device=device)
                target = torch.rand(shape, device=device)
                fn = functools.partial(torch.nn.functional.binary_cross_entropy, target=target, weight=weight, **kwargs)
                result = fn(input)
                cotangents = torch.randn_like(result, device=device)
                self._compare_jacobians_of_vjp(fn, (cotangents, input), atol_rtol=(1e-4, 2e-5))

    def test_extremal_numerics_layer_norm(self, device):
        N, C, H, W = 3, 4, 5, 6
        shapes = ((N, C), (N, C, H), (N, C, H, W))
        for shape in shapes:
            input_options = self._make_extremal_inputs(shape, device)
            normalized_shape = shape[1:]
            weight_options = self._make_extremal_inputs(normalized_shape, device)
            bias_options = self._make_extremal_inputs(normalized_shape, device)

            for input, bias, weight in self._arg_and_kwarg_options((input_options, bias_options, weight_options), ()):
                def fn(input, weight, bias):
                    return torch.nn.functional.layer_norm(input, normalized_shape, weight=weight, bias=bias)
                result = fn(input, weight, bias)
                cotangents = torch.randn_like(result, device=device)
                self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float32, torch.double))
    @skipOps('TestOperators', 'test_vmap_autograd_grad', {
        xfail('linalg.eig'),  # all close?
        # The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0
        xfail('masked_select'),
        xfail('nn.functional.max_unpool2d', 'grad'),  # contiguous call
        xfail('nn.functional.max_unpool2d'),  # contiguous call
        xfail('to_sparse'),  # dispatch key issue

        # numerical inconsistencies, look like bugs
        skip('matrix_exp', dtypes=(torch.float32,), device_type='cuda'),  # fails on linux, passes on windows
        skip('ldexp', dtypes=(torch.float32,), device_type='cpu'),  # fails on all but mac
        skip('__rmatmul__'),  # flaky needs investigation
        skip('matmul'),  # flaky needs investigation
        skip('nn.functional.conv_transpose3d'),  # flaky needs investigation
        skip('nn.functional.conv_transpose2d'),  # flaky needs investigation
        skip('nn.functional.conv_transpose1d'),  # flaky needs investigation
        skip('nn.functional.layer_norm', dtypes=(torch.float32,), device_type='cpu'),  # fails on windows
        skip('linalg.lu_factor', dtypes=(torch.float32,), device_type='cuda'),  # fails on all but windows
        skip('linalg.lu_factor_ex', dtypes=(torch.float32,), device_type='cuda'),  # fails on all but windows
        skip('linalg.multi_dot', '', device_type='cpu'),
        skip('sparse.sampled_addmm', ''),
        skip('native_layer_norm', '', device_type='cpu'),
        xfail('as_strided_scatter', ''),
    })
    @opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', (
        tol1('linalg.householder_product',
             {torch.float32: tol(atol=5e-04, rtol=9e-03)}, device_type='cuda'),
        tol1('linalg.householder_product',
             {torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cpu'),
    ))
    def test_vmap_autograd_grad(self, device, dtype, op):
        def is_differentiable(inp):
            return isinstance(inp, Tensor) and (inp.grad_fn is not None or inp.requires_grad)

        def get_flat_differentiable(pytree):
            flattened = tree_flatten(pytree)[0]
            return tuple(i for i in flattened if is_differentiable(i))

        def get_differentiable_linked(list1, list2):
            paired_list = zip(list1, list2)
            paired_list = tuple((first, second) for (first, second) in paired_list if is_differentiable(first))
            return zip(*paired_list)

        def filter_none(out):
            flattened = tree_flatten(out)[0]
            return tuple(o for o in flattened if o is not None)

        if not op.supports_autograd:
            self.skipTest("Skipped! Autograd not supported.")
            return

        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)

        for sample_input in sample_inputs:
            fn, primals = normalize_op_input_output(op, sample_input)
            out = fn(*primals)
            cotangents = tree_map(torch.randn_like, out)

            def compute_grad(cotangents):
                out_flattened = out
                cotangents_flattened = cotangents
                if not isinstance(out_flattened, torch.Tensor):
                    out_flattened = tree_flatten(out)[0]
                    cotangents_flattened = tree_flatten(cotangents)[0]
                    out_flattened, cotangents_flattened = get_differentiable_linked(out_flattened, cotangents_flattened)

                return filter_none(
                    torch.autograd.grad(out_flattened, get_flat_differentiable(primals), cotangents_flattened,
                                        retain_graph=True, allow_unused=True))

            is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs)
            generator = get_fallback_and_vmap_exhaustive(
                compute_grad, (cotangents,), {}, is_batch_norm_and_training=is_batch_norm_and_training)
            for loop_out, batched_out in generator:
                self.assertEqual(loop_out, batched_out)

    def test_vmapvmapjvp_linalg_solve(self):
        ops = [op for op in op_db if op.name == "linalg.solve"]
        assert len(ops) > 0

        # this specializes a lot of code from the get_fallback_and_vmap_exhaustive test. If we need this more
        # generally, this could go for a refactor

        B0 = 2
        B1 = 3

        # we want to check the case where A will be seen as contiguous by jvp but during the vmap calls will become
        # non-contiguous because vmap will expand. This will happen during both levels of vmap
        A = torch.randn(4, 4)
        k = torch.randn(4, 5, B1, B0)
        fn, args = get_jvp_variant_primals_tangents(torch.linalg.solve, SampleInput(A, args=(k,)))

        in_dims_all = (None, -1, None, -1)
        batched_out = vmap(vmap(fn, in_dims=in_dims_all), in_dims=in_dims_all)(*args)
        loop_out = loop2(fn, in_dims_all, in_dims_all, 0, 0, B0, B1, *args)
        self.assertEqual(loop_out, batched_out)

    @ops(filter(lambda op: op.name in aliasing_ops, op_db + additional_op_db), allowed_dtypes=(torch.float,))
    @parametrize("grad_op", ["jvp", "vjp"])
    def test_view_then_inplace(self, device, dtype, op, grad_op):
        for sample_input in op.sample_inputs(device, dtype):
            def f(x):
                op(sample_input.input, *sample_input.args, **sample_input.kwargs).copy_(x)
                return x

            without_grad = op(sample_input.input, *sample_input.args, **sample_input.kwargs)
            if grad_op == "jvp":
                with self.assertRaisesRegex(RuntimeError, "During a grad .* attempted to call in-place operation"):
                    jvp(f, (torch.randn_like(without_grad),), (torch.randn_like(without_grad),))
            else:
                assert grad_op == "vjp"
                with self.assertRaisesRegex(RuntimeError, "During a grad .* attempted to call in-place operation"):
                    vjp(f, torch.randn_like(without_grad))

    @ops(filter(lambda op: op.name in aliasing_ops_list_return, op_db + additional_op_db), allowed_dtypes=(torch.float,))
    @parametrize("grad_op", ["jvp", "vjp"])
    def test_view_then_inplace_list_return(self, device, dtype, op, grad_op):
        for sample_input in op.sample_inputs(device, dtype):
            def f(x):
                op(sample_input.input, *sample_input.args, **sample_input.kwargs)[0].copy_(x)
                return x

            without_grad = op(sample_input.input, *sample_input.args, **sample_input.kwargs)[0]
            with self.assertRaisesRegex(RuntimeError, "During a grad .* attempted to call in-place operation"):
                if grad_op == "jvp":
                    jvp(f, (torch.randn_like(without_grad),), (torch.randn_like(without_grad),))
                else:
                    assert grad_op == "vjp"
                    vjp(f, torch.randn_like(without_grad))

    @parametrize("grad_op", ["jvp", "vjp"])
    def test_view_then_inplace_special(self, grad_op):
        # some things in __getitem__ use at::index, which doesn't alias, so this tests a subset of them that do alias
        ops = [
            lambda x: x[0],
            lambda x: x[0, 0, 0],
            lambda x: x[:1],
            lambda x: x[:, :1],
            lambda x: x[:, :1, :],
        ]

        for op in ops:
            def f(x):
                op(captured).copy_(x)
                return x

            captured = torch.randn(4, 3, 3)
            without_grad = op(captured)
            if grad_op == "jvp":
                with self.assertRaisesRegex(RuntimeError, "During a grad .* attempted to call in-place operation"):
                    jvp(f, (torch.randn_like(without_grad),), (torch.randn_like(without_grad),))
            else:
                assert grad_op == "vjp"
                with self.assertRaisesRegex(RuntimeError, "During a grad .* attempted to call in-place operation"):
                    vjp(f, torch.randn_like(without_grad))

only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)

if __name__ == '__main__':
    run_tests()
