# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import OrderedDict
from unittest.case import skipIf
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch.nn.functional as F
from torch import Tensor
import functools
import itertools
import warnings
import unittest
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
    skipCUDAIfNoMagma
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_utils import (
    parametrize,
    instantiate_parametrized_tests,
    subtest
)
from torch.testing._internal.common_device_type import \
    toleranceOverride, tol
from functorch_additional_op_db import additional_op_db
from common_utils import (
    get_fallback_and_vmap_exhaustive,
    xfail,
    skip,
    skipOps,
    check_vmap_fallback,
    tol1,
    opsToleranceOverride,
    is_batch_norm_training,
    generate_vmap_inputs,
    compute_quantities_for_vmap_test,
    is_valid_inplace_sample_input,
)
import types
from collections import namedtuple

import functorch
from functorch import vmap, grad, grad_and_value, jvp, vjp, jacfwd
from functorch.experimental import chunk_vmap
from torch._C._functorch import reshape_dim_into, reshape_dim_outof
from functorch._src.make_functional import functional_init_with_buffers

FALLBACK_REGEX = 'There is a performance drop'


class EnableVmapFallbackWarnings:
    """Context manager that switches vmap fallback warnings on for its body.

    On entry the current value of the debug flag is read from ``torch._C``
    and stored in ``self.prev_state``; on exit that stored value is written
    back. ``__enter__`` returns ``None``.
    """

    def __enter__(self):
        # Capture the flag before overwriting it so __exit__ can restore it.
        was_enabled = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
        self.prev_state = was_enabled
        torch._C._debug_only_display_vmap_fallback_warnings(True)

    def __exit__(self, *ignored):
        # Exception details are not inspected; the flag is restored either way.
        torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)


class TestVmapAPI(TestCase):
    def test_non_tensor_output_raises(self):
        # vmap must reject mapped functions that return non-Tensor values.
        def returns_float(x):
            return 3.14

        def returns_tensor_and_int(x):
            return x, 3

        # The sole output is a Python float.
        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as a return"):
            vmap(returns_float)(torch.ones(3))

        # One of several outputs is a Python int.
        with self.assertRaisesRegex(ValueError, "got type <class 'int'> as a return"):
            vmap(returns_tensor_and_int)(torch.ones(3))

    def test_different_map_dim_size_raises(self):
        # Inputs whose mapped dimensions differ in size (2 vs 3) are rejected,
        # whether passed positionally, inside a tuple, or inside a dict.
        small = torch.randn(2)
        large = torch.randn(3)
        pattern = 'Expected all tensors to have the same size in the mapped dimension'

        def add_pair(z):
            return z[0] + z[1]

        def add_named(z):
            return z['x'] + z['y']

        with self.assertRaisesRegex(ValueError, pattern):
            vmap(torch.mul)(small, large)
        with self.assertRaisesRegex(ValueError, pattern):
            vmap(add_pair, in_dims=((0, 0),))((small, large))
        with self.assertRaisesRegex(ValueError, pattern):
            vmap(add_named, in_dims=({'x': 0, 'y': 0},))({'x': small, 'y': large})

    def test_func_with_no_inputs(self):
        # Calling a vmapped function with no arguments raises ValueError,
        # regardless of how many parameters the function itself declares.
        def takes_nothing():
            return torch.randn(3)

        def takes_one(x):
            return torch.randn(3)

        for fn in (takes_nothing, takes_one):
            with self.assertRaisesRegex(ValueError, 'got no inputs'):
                vmap(fn)()

    def test_func_with_no_tensors(self):
        # With in_dims=(None,) and an int argument there is no Tensor to map
        # over, so vmap must raise.
        def ignores_input(x):
            return torch.randn(3)

        no_mapped_dims = (None,)
        with self.assertRaisesRegex(ValueError, 'at least one Tensor'):
            vmap(ignores_input, no_mapped_dims)(1)

    def test_constant_function(self):
        # A function that ignores its input and returns a scalar tensor
        # yields that scalar repeated once per batch entry.
        def constant(x):
            return torch.tensor(3.14)

        batched = vmap(constant)(torch.ones(3))
        self.assertEqual(batched, torch.tensor([3.14, 3.14, 3.14]))

    def test_single_input(self):
        # vmap over an elementwise square matches the unbatched computation.
        data = torch.randn(2, 3)
        squared = vmap(lambda t: t * t)(data)
        self.assertEqual(squared, data * data)

    def test_multiple_inputs(self):
        # vmap over a binary op with both inputs mapped along dim 0.
        lhs = torch.randn(2, 3)
        rhs = torch.randn(2, 3)
        product = vmap(torch.mul)(lhs, rhs)
        self.assertEqual(product, lhs * rhs)

    def test_multiple_outputs(self):
        # A mapped function may return a tuple of tensors; each element is
        # batched independently.
        def square_and_cube(t):
            return t * t, t * t * t

        data = torch.randn(3)
        results = vmap(square_and_cube)(data)
        squared = results[0]
        cubed = results[1]
        self.assertEqual(squared, data * data)
        self.assertEqual(cubed, data * data * data)

    def test_multiple_outputs2(self):
        # Tuple and list return containers are both accepted.
        # Note that `return (x, x)` is the same thing as `return x, x`.
        def as_tuple(x):
            return (x, x)

        def as_two_element_list(x):
            return [x, x]

        def as_one_element_list(x):
            return [x]

        x = torch.randn(3)

        # None of these calls should throw.
        for fn in (as_tuple, as_two_element_list, as_one_element_list):
            vmap(fn)(x)

    def test_nested_with_same_map_dim(self):
        # Nesting vmap two and three levels deep over torch.mul reproduces
        # the plain elementwise product.
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)

        two_levels = vmap(vmap(torch.mul))
        self.assertEqual(two_levels(x, y), x * y)

        three_levels = vmap(vmap(vmap(torch.mul)))
        self.assertEqual(three_levels(x, y), x * y)

    def test_nested_with_diag_embed(self):
        # diag_embed requires special testing because it is registered with
        # conditional functionalization.
        batch = torch.randn(3, 3, 5)
        nested_diag = vmap(vmap(torch.diag_embed))
        self.assertEqual(nested_diag(batch), torch.diag_embed(batch))

    def test_nested_with_different_map_dim(self):
        # Each nesting level maps over a different tensor with its own batch
        # size; the outermost level becomes the leading output dimension.
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)

        def times_every_y(xi):
            return vmap(lambda yj: xi * yj)(y)

        output = vmap(times_every_y)(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)

        def times_every_y_and_z(xi):
            def times_every_z(yj):
                return vmap(lambda zk: xi * yj * zk)(z)
            return vmap(times_every_z)(y)

        output = vmap(times_every_y_and_z)(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)

    def test_noop_in_inner_vmap(self):
        # The inner vmap ignores its own mapped value and returns the outer
        # one, so the outer values get expanded along the inner batch dim.
        outer = torch.randn(3)
        inner = torch.randn(5)

        def repeat_over_inner(outer_elem):
            return vmap(lambda _: outer_elem)(inner)

        output = vmap(repeat_over_inner)(outer)
        self.assertEqual(output, outer.view(3, 1).expand(3, 5))

    def test_unsupported_op_err_msg(self):
        # Ops the fallback path cannot handle must fail with a RuntimeError
        # that mentions the missing batching rule.
        sample = torch.randn(2, 3)
        no_fallback_pattern = (
            r"Batching rule not implemented for aten::.+; the "
            r"fallback path doesn't work on out= or view ops"
        )
        # Unsupported view op.
        # TODO: find a view op; the check below (using torch.ravel) is disabled.
        # with self.assertRaisesRegex(RuntimeError, no_fallback_pattern):
        #     vmap(torch.ravel)(sample)

        # Unsupported out= variant.
        def abs_into(src, dst):
            return torch.abs(src, out=dst)

        with self.assertRaisesRegex(RuntimeError, no_fallback_pattern):
            vmap(abs_into)(sample, sample)

        # Don't support non-tensor returns. This is a limitation of vmap;
        # functions that don't return tensors must be special cased
        with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
            vmap(torch.equal)(sample, sample)

    def test_nonzero_out_dims(self):
        # out_dims selects where the batch dimension lands in the outputs.

        # Identity function: the result equals a permutation of the input and
        # reports the same data pointer. The three cases are a basic one,
        # moving the batch dim to position 2, and a negative out_dim.
        identity_cases = (
            ((2, 3), 1, (1, 0)),
            ((2, 3, 5, 7), 2, (1, 2, 0, 3)),
            ((2, 3, 5, 7), -1, (1, 2, 3, 0)),
        )
        for shape, out_dim, perm in identity_cases:
            tensor = torch.randn(*shape)
            result = vmap(lambda x: x, out_dims=out_dim)(tensor)
            self.assertEqual(result, tensor.permute(*perm))
            self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # check that out_dims works on ALL outputs
        first = torch.randn(2, 3, 5, 7)
        second = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x, y: (x, y), out_dims=2)(first, second)
        self.assertEqual(result, (first.permute(1, 2, 0, 3), second.permute(1, 2, 0, 3)))

        # use out_dims with the maximum vmap-able tensor dims (64 dims)
        ndims = 64
        trailing_ones = [1] * (ndims - 1)
        high_rank = torch.randn([2] + trailing_ones)
        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
        result = vmap(lambda x: x, out_dims=2)(high_rank)
        self.assertEqual(result.shape, expected_shape)

        # test something that is not the identity function
        def powers(a, b):
            return a, a * b, a * b * b

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(powers, out_dims=1)(x, y)
        expected = (
            x.permute(1, 0, 2),
            (x * y).permute(1, 0, 2),
            (x * y * y).permute(1, 0, 2),
        )
        self.assertEqual(result, expected)

    def test_multiple_out_dims(self):
        # A tuple of out_dims assigns one batch position per returned tensor.
        def twice(t):
            return t, t

        def three_copies_and_product(a, b):
            return a, a, a, a * b

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)

        pair = vmap(twice, out_dims=(0, 1))(x)
        self.assertEqual(pair, (x, x.permute(1, 0, 2)))

        # Negative and positive out_dims mixed: -1 and 2 both mean "last".
        quad = vmap(three_copies_and_product, out_dims=(-1, 0, 1, 2))(x, y)
        self.assertEqual(
            quad,
            (
                x.permute(1, 2, 0),
                x,
                x.permute(1, 0, 2),
                (x * y).permute(1, 2, 0),
            ),
        )

    def test_nested_out_dims(self):
        # out_dims at the inner and outer vmap levels compose.
        y = torch.randn(2, 3, 5, 7)

        def inner_identity(inner_out_dim):
            # Build a function that vmaps the identity with the given out_dims.
            def apply(t):
                return vmap(lambda x: x, out_dims=inner_out_dim)(t)
            return apply

        # Inner vmap has non-zero out_dim
        result = vmap(inner_identity(1))(y)
        self.assertEqual(result.shape, (2, 5, 3, 7))
        self.assertEqual(result, y.permute(0, 2, 1, 3))

        # all vmaps have non-zero out_dim
        result = vmap(inner_identity(1), out_dims=1)(y)
        self.assertEqual(result.shape, (5, 2, 3, 7))
        self.assertEqual(result, y.permute(2, 0, 1, 3))

        # throwing in some negative out_dims
        result = vmap(inner_identity(-1), out_dims=-1)(y)
        self.assertEqual(result.shape, (5, 7, 3, 2))
        self.assertEqual(result, y.permute(2, 3, 1, 0))

        # testing fn that isn't the identity
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)

        def times_every_x(y_row):
            return vmap(lambda x_row: x_row * y_row, out_dims=1)(x)

        result = vmap(times_every_x, out_dims=-1)(y)
        self.assertEqual(result.shape, (3, 2, 5))
        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))

    def test_out_dims_edge_case(self):
        # Test that we accept out_dims=(1,) for a function with one output:
        # it must behave like the plain int out_dims=1.
        def identity(x):
            return x

        tensor = torch.randn(2, 3)
        with_int = vmap(identity, out_dims=1)(tensor)
        with_tuple = vmap(identity, out_dims=(1,))(tensor)
        self.assertEqual(with_tuple, with_int)

    def test_pytree_returns(self):
        # Nested tuple/list outputs keep their structure, and every leaf
        # carries the same batched value.
        x = torch.randn(2, 3)

        def nested_sines(t):
            s = t.sin()
            return s, (s, s), [s, (s, s)]

        y0, (y1, y2), (y3, (y4, y5)) = vmap(nested_sines)(x)
        self.assertEqual(y0, x.sin())
        # Chain of pairwise comparisons linking all six leaves together.
        for left, right in ((y0, y1), (y2, y1), (y2, y3), (y4, y3), (y5, y4)):
            self.assertEqual(left, right)

    def test_pytree_odict_returns(self):
        # An OrderedDict returned by the mapped function must come back as an
        # OrderedDict whose entries match the unbatched computation.
        x = torch.randn(2, 3)

        def f(t):
            y = t.sin()
            return OrderedDict([("sin", y), ("cos", t.cos())])

        out = vmap(f)(x)
        # Use the unittest assertion rather than a bare `assert`: bare asserts
        # are stripped under `python -O` and give no diagnostic on failure.
        self.assertIsInstance(out, OrderedDict)
        expected = f(x)
        self.assertEqual(out["sin"], expected["sin"])
        self.assertEqual(out["cos"], expected["cos"])

    def test_pytree_returns_outdims(self):
        # out_dims may mirror the nested structure of the outputs exactly.
        x = torch.randn(2, 3)

        def nested_sines(t):
            s = t.sin()
            return s, (s, s)

        per_leaf_dims = (0, (0, 1))
        y0, (y1, y2) = vmap(nested_sines, out_dims=per_leaf_dims)(x)
        self.assertEqual(y0, x.sin())
        self.assertEqual(y1, x.sin())
        # Only the last leaf has its batch dim moved to position 1.
        self.assertEqual(y2, x.sin().t())

    def test_pytree_returns_broadcast_simple(self):
        # A single int out_dims is applied to every leaf of a nested output.
        x = torch.randn(2, 3)

        def nested_sines(t):
            s = t.sin()
            return s, (s, s)

        y0, (y1, y2) = vmap(nested_sines, out_dims=1)(x)
        self.assertEqual(y0, x.sin().t())
        self.assertEqual(y1, y0)
        self.assertEqual(y2, y0)

    def test_pytree_returns_broadcast_nested(self):
        # An out_dims entry given for a sub-tree is applied to all leaves of
        # that sub-tree: here the 1 covers both members of the inner tuple.
        x = torch.randn(2, 3)

        def nested_sines(t):
            s = t.sin()
            return s, (s, s)

        y0, (y1, y2) = vmap(nested_sines, out_dims=(0, 1))(x)
        self.assertEqual(y0, x.sin())
        self.assertEqual(y1, y0.t())
        self.assertEqual(y2, y0.t())

    def test_out_dims_must_be_int_or_collection_of_int_err_msg(self):
        # Strings and None, bare or wrapped in a tuple, are invalid out_dims.
        msg = 'must be an int or a python collection of ints'
        tensor = torch.randn(2, 3)
        invalid_out_dims = ('lol', ('lol',), None, (None,))
        for bad_out_dims in invalid_out_dims:
            with self.assertRaisesRegex(ValueError, msg):
                vmap(lambda x: x, out_dims=bad_out_dims)(tensor)

    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
        # The number of out_dims entries must match the number of outputs.
        msg = 'not compatible'
        x = torch.randn(2, 3, 5)

        def one(t):
            return t

        def two(t):
            return (t, t)

        def three(t):
            return (t, t, t)

        mismatches = (
            # Too many out_dims
            (one, (0, 0)),
            (three, (0, 0, 0, 0)),
            # Too few out_dims
            (two, (0,)),
            (three, (0, 0)),
        )
        for fn, out_dims in mismatches:
            with self.assertRaisesRegex(ValueError, msg):
                vmap(fn, out_dims=out_dims)(x)

    def test_out_dim_out_of_bounds_err_msg(self):
        # TODO(rzou): This error message isn't that great. It comes straight
        # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
        # the error message in the future in C++
        msg = 'Dimension out of range'
        x = torch.randn(2, 3, 5)
        # One past the largest valid out_dim, and one below the smallest.
        for out_of_range in (3, -4):
            with self.assertRaisesRegex(IndexError, msg):
                vmap(lambda x: x, out_dims=out_of_range)(x)

    def test_non_zero_in_dims(self):
        # in_dims may name any dimension of the input as the mapped one.
        tensor = torch.randn(2, 3, 5)

        # Implicit out_dims = 0; vmap will move the batch dim to the front.
        moved = vmap(lambda t: t, (1,))(tensor)
        self.assertEqual(moved, tensor.permute(1, 0, 2))
        self.assertEqual(moved.data_ptr(), tensor.data_ptr())

        # Two inputs mapped along different dimensions.
        x = torch.randn(2, 3)
        y = torch.randn(3, 2)
        rows_with_cols = vmap(torch.mul, (0, 1))(x, y)
        self.assertEqual(rows_with_cols, x * y.t())
        cols_with_rows = vmap(torch.mul, (1, 0))(x, y)
        self.assertEqual(cols_with_rows, x.t() * y)

    def test_none_in_dims(self):
        mapped = torch.randn(2, 3)
        unmapped = torch.randn(2, 3)

        # None in_dim for a Tensor means we don't map over it
        result = vmap(torch.mul, (0, None))(mapped, unmapped)
        self.assertEqual(result.shape, (2, 2, 3))
        self.assertEqual(result, mapped.view(2, 1, 3) * unmapped)

        # None in_dim for non-tensor arguments
        scaled = vmap(torch.mul, (0, None))(mapped, 2)
        self.assertEqual(scaled, mapped * 2)

    def test_nested_non_default_in_dims(self):
        # Three nesting levels, the outer two with non-default in_dims.
        x = torch.rand(5, 2, 3)
        y = torch.rand(3, 5, 2)
        innermost = vmap(torch.mul)
        middle = vmap(innermost, (1, 0))
        outermost = vmap(middle, (1, 2))
        result = outermost(x, y)
        self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))

    def test_nested_negative_in_dims(self):
        # Negative in_dims count from the end: -1 maps over the last dim.
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        last_dim_for_both = (-1, -1)
        output = vmap(torch.mul, last_dim_for_both)(x, y)
        self.assertEqual(output.shape, (3, 2))
        self.assertEqual(output, (x * y).permute(1, 0))

    def test_non_default_in_dims_out_dims(self):
        # Combinations of non-default in_dims and out_dims.
        x = torch.randn(2, 3, 5)

        def identity(t):
            return t

        def double(t):
            return t * 2

        # Same in_dim as out_dim, vmap over identity
        unchanged = vmap(identity, in_dims=1, out_dims=1)(x)
        self.assertEqual(unchanged, x)
        self.assertEqual(unchanged.data_ptr(), x.data_ptr())

        # Different in_dim from out_dim, vmap over identity
        swapped = vmap(identity, in_dims=2, out_dims=1)(x)
        self.assertEqual(swapped.shape, (2, 5, 3))
        self.assertEqual(swapped, x.transpose(1, 2))
        self.assertEqual(swapped.data_ptr(), x.data_ptr())

        # Same in_dim as out_dim, vmap over operation
        doubled = vmap(double, in_dims=1, out_dims=1)(x)
        self.assertEqual(doubled, x * 2)

        # Different in_dim as out_dim, vmap over operation
        doubled_swapped = vmap(double, in_dims=2, out_dims=1)(x)
        self.assertEqual(doubled_swapped.shape, (2, 5, 3))
        self.assertEqual(doubled_swapped, (x * 2).transpose(1, 2))

        # Basic nested test.
        nested = vmap(vmap(double, 1, 1), 1, 1)(x)
        self.assertEqual(nested, x * 2)

    def test_item_throws(self):
        # Calling .item() on a tensor inside vmap must raise.
        def to_python_scalar(t):
            return t.item()

        with self.assertRaisesRegex(RuntimeError, r'item\(\) on a Tensor'):
            vmap(to_python_scalar)(torch.randn(3))

    def test_data_dependent_control_flow_throws(self):
        # Branching on a tensor's truth value inside vmap must raise.
        def branch_on_value(t):
            return t if t else 0

        with self.assertRaisesRegex(RuntimeError, r'data-dependent control flow'):
            vmap(branch_on_value)(torch.randn(3))

    def test_accepts_nested_inputs(self):
        # Tuples, lists and dicts of tensors are accepted as inputs.
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        def add_by_index(z):
            return z[0] + z[1]

        def add_by_key(z):
            return z['x'] + z['y']

        # Single layer of nesting. Each container is tried with default
        # in_dims, with a single in_dim for the whole container, and with
        # in_dims that mirror the container's structure.
        single_level_cases = (
            (add_by_index, (x, y), ((0, 0),)),
            (add_by_index, [x, y], ([0, 0],)),
            (add_by_key, {'x': x, 'y': y}, ({'x': 0, 'y': 0},)),
        )
        for fn, container, mirrored_in_dims in single_level_cases:
            out = vmap(fn)(container)
            self.assertEqual(out, x + y)
            out = vmap(fn, in_dims=(0,))(container)
            self.assertEqual(out, x + y)
            out = vmap(fn, in_dims=mirrored_in_dims)(container)
            self.assertEqual(out, x + y)

        # Multiple layers of nesting
        def add_deep(z):
            return z['x'][0] + z['x'][1][0] + z['y'][0] + z['y'][1]

        out = vmap(add_deep)({'x': [x, (x,)], 'y': [y, y]})
        self.assertEqual(out, x + x + y + y)

    def test_in_dims_wrong_type_err_msg(self):
        # A top-level in_dims that is a list, set or str is rejected.
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r'expected `in_dims` to be int or a \(potentially nested\) tuple'

        # `{0}` is the same value the set literal {0, 0} evaluates to.
        for bad_in_dims in ([0, 0], {0}, 'lol'):
            with self.assertRaisesRegex(ValueError, msg):
                vmap(torch.mul, bad_in_dims)(x, y)

        # A top-level list is rejected even when the input itself is a list.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])

        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_not_enough_in_dims_err_msg(self):
        # in_dims must have the same structure as the inputs.
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r'in_dims is not compatible with the structure of `inputs`'

        def add_by_index(z):
            return z[0] + z[1]

        # Too few, then too many, entries for two positional inputs.
        for wrong_length in ((0,), (0, 0, 0)):
            with self.assertRaisesRegex(ValueError, msg):
                vmap(torch.mul, wrong_length)(x, y)

        # A one-element list, then a tuple, against a two-element list input.
        for wrong_nested in (([0],), ((0, 0),)):
            with self.assertRaisesRegex(ValueError, msg):
                vmap(add_by_index, in_dims=wrong_nested)([x, y])

        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_integer_in_dim_but_not_tensor_input_err_msg(self):
        # An integer in_dim paired with a non-Tensor input must raise
        # ValueError; in_dim=None is the way to pass such an input through.
        # (Two local helpers that were defined here but never called have
        # been removed.)
        x = torch.randn(2, 3)

        # the following are errors in jax (and will always be errors)
        msg = 'Got in_dim=0 for an input but the input is of type'
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum)(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum, (0, 0))(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
        # The following should not throw
        vmap(torch.sum, (0, None))(x, 0)

    def test_in_dim_not_in_tensor_err_msg(self):
        # An in_dim that does not index a dimension of its tensor is rejected.
        def square(t):
            return t * t

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        msg = r'Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w'

        # Zero-dimensional tensor with the default in_dims.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(square)(torch.randn([]))
        # Zero-dimensional tensor with an explicit in_dim of 0.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(square, in_dims=(0,))(torch.randn([]))
        # in_dim too negative for a 2-D tensor.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(square, in_dims=(-3,))(x)
        # in_dim too large for a 2-D tensor.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(square, in_dims=(2,))(y)
        # Out-of-range in_dim inside a nested in_dims spec.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])

        # the following should not throw
        for valid_in_dim in (0, 1):
            vmap(square, in_dims=(valid_in_dim,))(torch.randn(2, 3))

    def test_fallback_does_not_warn_by_default(self):
        # NB: The operator used here is torch.copysign (an earlier version of
        # this note named torch.atan2). If it gains a batching rule, this test
        # should be switched to another operator that still takes the
        # fallback path, to avoid bitrot.
        op = torch.copysign
        lhs = torch.randn(11)
        rhs = torch.randn(11)
        with warnings.catch_warnings(record=True) as caught:
            vmap(op)(lhs, rhs)
            # The single warning here is the "vmap is experimental"
            # warning, not a warning from the vmap fallback path.
            num_warnings = len(caught)
            self.assertEqual(num_warnings, 1)

    @unittest.expectedFailure
    def test_fallback_warns_when_warnings_are_enabled(self):
        # NB: The operator used here is torch.copysign (an earlier version of
        # this note named torch.atan2). If it gains a batching rule, this test
        # should be switched to another operator that still takes the
        # fallback path, to avoid bitrot.
        op = torch.copysign
        lhs = torch.randn(11)
        rhs = torch.randn(11)
        with warnings.catch_warnings(record=True) as caught:
            with EnableVmapFallbackWarnings():
                vmap(op)(lhs, rhs)
            # Two warnings are expected, the last being the fallback warning.
            self.assertEqual(len(caught), 2)
            last_message = str(caught[-1].message)
            self.assertRegex(last_message, FALLBACK_REGEX)

    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
        return
        # with warnings.catch_warnings(record=True) as wa:
        #     with EnableVmapFallbackWarnings():
        #         result = vmap(*vmap_args)(*inputs)
        #     self.assertEqual(len(wa), 2)
        #     self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)

    def test_fallback_zero_dim(self):
        # NB: The operator used here is torch.copysign (an earlier version of
        # this note named torch.atan2). If it gains a batching rule, this test
        # should be switched to another operator that still takes the
        # fallback path, to avoid bitrot.
        op = torch.copysign
        x = torch.randn(11)
        y = torch.randn(11)
        self._assert_uses_vmap_fallback((op,), (x, y))

        B0, B1 = 0, 3
        msg = 'The fallback path does not support vmap over dims of size 0'

        def assert_all_combinations_raise(empty_batched, unbatched):
            # Size-0 batch on the left, on the right, then on both sides.
            with self.assertRaisesRegex(RuntimeError, msg):
                vmap(op, (0, None))(empty_batched, unbatched)
            with self.assertRaisesRegex(RuntimeError, msg):
                vmap(op, (None, 0))(unbatched, empty_batched)
            with self.assertRaisesRegex(RuntimeError, msg):
                vmap(op)(empty_batched, empty_batched)

        x = torch.randn(B0, 11)
        y = torch.randn(11)
        assert_all_combinations_raise(x, y)

        # Same checks with an extra non-empty dimension behind the batch dim.
        x = torch.randn(B0, B1, 11)
        y = torch.randn(B1, 11)
        assert_all_combinations_raise(x, y)

    def test_fallback_atan2(self):
        # NB: Despite the test's name, the operator exercised is
        # torch.copysign. If it gains a batching rule, this test should be
        # switched to another operator that still takes the fallback path,
        # to avoid bitrot.
        op = torch.copysign

        x = torch.randn(5, 7, 11)
        y = torch.randn(5, 7, 11)

        self._assert_uses_vmap_fallback((op,), (x, y))

        # single vmap, inputs mapped along different dims
        lhs = torch.randn(7, 11, 5)
        rhs = torch.randn(5, 7, 11)
        result = vmap(op, (2, 0))(lhs, rhs)
        self.assertEqual(result, op(lhs.permute(2, 0, 1), rhs))

        # nested vmap
        lhs = torch.randn(7, 11, 5)
        rhs = torch.randn(5, 7, 11)
        result = vmap(vmap(op), (2, 0))(lhs, rhs)
        self.assertEqual(result, op(lhs.permute(2, 0, 1), rhs))

        # big batch size (total 10000)
        lhs = torch.randn(100, 10, 10, 5)
        rhs = torch.randn(100, 10, 10)
        triple = vmap(vmap(vmap(op)))
        result = triple(lhs, rhs)
        self.assertEqual(result, op(lhs, rhs.view(100, 10, 10, 1)))

    # TODO: No clue what is wrong here.
    @unittest.skip
    def test_fallback_masked_fill(self):
        # NB: Despite the test's name, the operator exercised is
        # torch.index_add. If it gains a batching rule, this test should be
        # switched to another operator that still takes the fallback path,
        # to avoid bitrot.
        in_dims = (0, None, None, 0)

        def check_with_batch_size(batch_size):
            B0 = batch_size
            x = torch.randn(B0, 7, 11, 13)
            dim = 0
            index = torch.tensor([0, 4, 2])
            values = torch.randn(B0, 3, 13)

            self._assert_uses_vmap_fallback((torch.index_add, in_dims), (x, dim, index, values))

            result = vmap(torch.index_add, in_dims)(x, dim, index, values)
            # Unbatched reference: shift `dim` past the batch dimension and
            # reshape `values` to (B0, 3, 1, 13).
            reshaped_values = values.view(B0, 3, 1, 13)
            expected = torch.index_add(x, dim + 1, index, reshaped_values)
            self.assertEqual(result, expected)

        for batch_size in (5, 1237):
            check_with_batch_size(batch_size=batch_size)

    def test_fallback_multiple_returns(self):
        # NB: One day we will implement a batching rule for torch.var_mean
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        B0, B1, B2 = 2, 3, 1237

        def check(sample):
            # Wrap var_mean in one vmap per leading batch dimension and
            # compare against a direct reduction over the last dimension.
            levels = sample.dim() - 1
            fn = torch.var_mean
            for _ in range(levels):
                fn = vmap(fn)
            result = fn(sample)
            expected = torch.var_mean(sample, dim=levels)
            self.assertEqual(result, expected)

        tensor = torch.randn(B0, 10)
        self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))

        # fallback correctness on torch.var_mean
        check(tensor)

        # nested vmap
        check(torch.randn(B0, B1, 10))

        # big batch size, nested vmap
        check(torch.randn(B0, B1, B2, 10))

    def test_inplace_fallback_unary(self):
        # Test the in-place fallback on an in-place method that takes no
        # additional Tensor arguments. This is the simplest case of the fallback.
        # NB: One day we will implement a batching rule for acos_.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.acos_
        B0, B1, B2 = 2, 3, 10000

        probe = torch.randn(B0, 5)
        self._assert_uses_vmap_fallback((op,), (probe,))

        # Single vmap: the very same tensor object is returned, mutated.
        original = torch.rand(B0, 5)
        target = original.clone()
        result = vmap(op)(target)
        self.assertIs(result, target)
        self.assertEqual(result, original.acos())

        # Single vmap + different out_dim produces a view(!)
        original = torch.rand(B0, 5)
        target = original.clone()
        result = vmap(op, out_dims=(1,))(target)
        self.assertIs(result._base, target)
        self.assertEqual(result, original.t().acos())

        # Nested vmap
        original = torch.randn(B0, B1, 5)
        target = original.clone()
        result = vmap(vmap(op))(target)
        self.assertIs(result, target)
        self.assertEqual(result, original.acos())

        # Nested vmap, large batch size
        original = torch.randn(B0, B1, B2, 5)
        target = original.clone()
        result = vmap(vmap(vmap(op)))(target)
        self.assertIs(result, target)
        self.assertEqual(result, original.acos())

    def test_inplace_fallback_nary_same_levels(self):
        # NB: One day we will implement a batching rule for atan2_
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.atan2_
        outplace_op = torch.atan2

        probe_x = torch.randn(5, 7, 11)
        probe_y = torch.randn(5, 7, 11)
        self._assert_uses_vmap_fallback((op,), (probe_x, probe_y))

        # Single vmap: the first argument is mutated in place, so compare it
        # with the out-of-place op applied to a pristine copy.
        B0 = 5
        pristine = torch.randn(7, 11, B0)
        mutated = pristine.clone()
        other = torch.randn(B0, 7, 11)
        vmap(op, (2, 0))(mutated, other)
        self.assertEqual(mutated, outplace_op(pristine, other.movedim(0, 2)))

        # Nested vmap
        B0, B1 = 5, 7
        pristine = torch.randn(B1, 11, B0)
        mutated = pristine.clone()
        other = torch.randn(B0, B1, 11)
        vmap(vmap(op), (2, 0))(mutated, other)
        self.assertEqual(mutated, outplace_op(pristine, other.movedim([0, 1], [2, 0])))

        # big batch size (total 10000)
        B0, B1, B2 = 100, 10, 10
        pristine = torch.randn(B0, B1, B2, 5)
        mutated = pristine.clone()
        other = torch.randn(B0, B1, B2)
        vmap(vmap(vmap(op)))(mutated, other)
        self.assertEqual(mutated, outplace_op(pristine, other.view(B0, B1, B2, 1)))

    # ("Fallback isInplaceVmapCompatible check is broken")
    @unittest.expectedFailure
    def test_inplace_fallback_nary_different_levels(self):
        # NB: One day we will implement a batching rule for atan2_
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = Tensor.atan2_
        outplace_op = torch.atan2
        B0, B1 = 2, 3

        probe_left = torch.rand(B0, 7)
        probe_right = torch.rand(7)
        self._assert_uses_vmap_fallback((op, (0, None)), (probe_left, probe_right))

        # op(left, right): All of the levels in right are found in left
        pristine = torch.rand(B0, 7)
        left = pristine.clone()
        right = torch.rand(7)
        vmap(op, in_dims=(0, None))(left, right)
        self.assertEqual(left, outplace_op(pristine, right))

        pristine = torch.rand(B0, B1, 7)
        left = pristine.clone()
        right = torch.rand(B0, 7)
        vmap(vmap(op, in_dims=(0, None)))(left, right)
        self.assertEqual(left, outplace_op(pristine, right.view(B0, 1, 7)))

        # op(left, right): Some of the levels in right are not found in left
        msg = r'vmap: aten::atan2_\(self, \*extra_args\) is not possible'

        # right is batched, left is not
        left = torch.rand(7)
        right = torch.rand(B0, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(left, right)

        # left and right are batched at different nesting levels
        left = torch.rand(B1, 7)
        right = torch.rand(B0, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(left, right)

        # as above, with right's batch dim in position 1
        left = torch.rand(B1, 7)
        right = torch.rand(7, B0)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(left, right)

        # right carries an inner level that left lacks
        left = torch.rand(B0, 7)
        right = torch.rand(B0, B1, 7)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(vmap(op, in_dims=(None, 0)))(left, right)

    def test_backward_unsupported_interaction(self):
        # Calling .backward() inside a vmapped function is meant to raise.
        # The test currently skips itself before making any assertion.
        x = torch.randn(3, requires_grad=True)
        y = torch.randn(5)
        cotangent = torch.randn_like(x)
        err_msg = r'backward\(\) called inside a functorch transform'

        def backward_on_vmapped_tensor(x):
            x.sum().backward()

        # FIXME
        # skipTest raises, so everything below is unreachable until the
        # FIXME is resolved; it is kept as the intended body of the test.
        self.skipTest("error: element 0 of tensors does not require grad and does not have a grad_fn")

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_on_vmapped_tensor)(x)

        def backward_with_vmapped_grad(x, grad):
            x.backward(grad)

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_with_vmapped_grad)(x, cotangent)

        def completely_unrelated_backward(y):
            x.sum().backward()
            return y

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(completely_unrelated_backward)(y)

    @unittest.expectedFailure
    def test_grad_unsupported_interaction(self):
        # Calls torch.autograd.grad inside vmap in two arrangements and
        # expects a RuntimeError each time; the decorator marks the test as
        # currently expected to fail.
        batched = torch.randn(3, requires_grad=True)
        pattern = 'autograd.grad.* called inside torch.vmap'

        captured = torch.randn(3, requires_grad=True)

        def output_to_grad_is_vmapped(t):
            # The differentiated output depends on the vmapped argument.
            loss = (captured * t).sum()
            return torch.autograd.grad([loss], [captured])[0]

        with self.assertRaisesRegex(RuntimeError, pattern):
            vmap(output_to_grad_is_vmapped)(batched)

        outer_loss = (batched ** 2).sum()

        def input_to_grad_is_vmapped(t):
            # The output was built outside vmap; the vmapped argument is the
            # tensor being differentiated with respect to.
            return torch.autograd.grad([outer_loss], [t])[0]

        with self.assertRaisesRegex(RuntimeError, pattern):
            vmap(input_to_grad_is_vmapped)(batched)

    def test_batched_gradient_basic(self):
        # vmap a vector-Jacobian product of x * y over the rows of the
        # identity and compare the stacked result against diagflat(y).
        N = 3
        x = torch.randn(N, requires_grad=True)
        y = torch.randn(N)

        def vjp_mul(cotangent):
            (grad_x,) = torch.autograd.grad([x * y], [x], grad_outputs=[cotangent])
            return grad_x

        basis = torch.eye(N)
        stacked = vmap(vjp_mul)(basis)
        self.assertEqual(stacked, torch.diagflat(y))

    def test_functools_partial(self):
        # vmap must accept a functools.partial object as the callable.
        lhs = torch.randn(3)
        rhs = torch.randn(2, 3)
        mul_by_lhs = functools.partial(torch.mul, lhs)
        batched_out = vmap(mul_by_lhs)(rhs)
        self.assertEqual(batched_out, lhs * rhs)

    def test_nn_module(self):
        # vmap must accept an nn.Module instance as the callable.
        batch = torch.randn(2, 3)
        linear = torch.nn.Linear(3, 3, bias=False)
        vmapped_out = vmap(linear)(batch)
        self.assertEqual(vmapped_out, linear(batch))

    def test_fallback_with_undefined_grad(self):
        B0 = 7
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        weight = torch.randn(3, 3, 1, 1)
        cotangents = torch.randn(B0, 2, 3, 4, 5)

        def get_vjp(v):
            # conv2d is called without a bias argument.
            out = torch.nn.functional.conv2d(x, weight)
            return torch.autograd.grad(out, x, v)[0]

        # Runs vmap(get_vjp)(v), which should not error out.
        # The backward formula for convolution returns an undefined
        # Tensor for grad_bias because the original bias does not exist.
        #
        # In the future we'll probably add a batching rule for convolution
        # backward. When this happens, we should modify this test to use a
        # different op (and/or create and use a dummy operator) to avoid bitrot.
        self._assert_uses_vmap_fallback([get_vjp], [cotangents])

    def test_reshape_dim_into(self):
        # Table of (src, dst, expected) triples for reshape_dim_into(src, dst, x);
        # negative indices are included alongside their positive equivalents.
        x = torch.randn(2, 3, 5, 7)

        cases = (
            (0, 0, x.reshape(6, 5, 7)),
            (0, 1, x.movedim(0, 1).reshape(3, 2 * 5, 7)),
            (0, 2, x.movedim(0, 2).reshape(3, 5, 2 * 7)),
            (1, 2, x.movedim(1, 2).reshape(2, 5, 3 * 7)),
            (0, -2, x.movedim(0, 1).reshape(3, 2 * 5, 7)),
            (0, -1, x.movedim(0, 2).reshape(3, 5, 2 * 7)),
            (-4, -1, x.movedim(0, 2).reshape(3, 5, 2 * 7)),
        )
        for src, dst, expected in cases:
            self.assertEqual(reshape_dim_into(src, dst, x), expected)

    def test_reshape_dim_outof(self):
        # The input is a permuted (hence non-contiguous) 12x12x12 tensor.
        # Each case is (dim, size1, shape passed to reshape for the expected value).
        x = torch.randn(12, 12, 12).permute(2, 1, 0)

        cases = (
            (0, 2, (2, 6, 12, 12)),
            (1, 4, (12, 4, 3, 12)),
            (2, 6, (12, 12, 6, 2)),
            (-1, 6, (12, 12, 6, 2)),
        )
        for dim, size1, expected_shape in cases:
            split = reshape_dim_outof(dim, size1, x)
            self.assertEqual(split, x.reshape(*expected_shape))

    def test_batch_rule_does_not_need_to_handle_no_batched_input(self):
        # Inside the inner vmap only `a` is batched, so torch.dot runs on
        # `b` while `b` is not batched at that level.
        def f(a, b):
            return a + torch.dot(b, torch.ones(2))

        x = torch.randn(7, 5)
        y = torch.randn(3, 2)
        inner = vmap(f, in_dims=(0, None))
        out = vmap(inner, in_dims=(None, 0))(x, y)
        expected = torch.mv(y, torch.ones(2)).view(3, 1, 1) + x
        self.assertEqual(out, expected)

    def _test_vmap_autocast(self, device):
        # Shared body for the autocast tests. Checks that vmap composes with
        # torch.autocast in three arrangements: autocast as a context manager
        # inside the vmapped function, autocast as a decorator on the vmapped
        # function, and autocast wrapped around the vmap call.
        #
        # `device` is a device string. It may carry an index (e.g. "cuda:0"),
        # whereas torch.autocast takes only the device type, so normalize it.
        device_type = torch.device(device).type
        amp_dtype = torch.bfloat16 if device_type == "cpu" else torch.float16

        a_float32 = torch.rand(4, 2, 3, device=device)
        b_float32 = torch.rand(4, 3, 2, device=device)
        c_float32 = torch.rand(4, 2, 2, device=device)
        d_float32 = torch.rand(4, 3, 2, device=device)
        inputs = (a_float32, b_float32, c_float32, d_float32)

        def matmul_chain(x, y, z, w):
            # Must be run with autocast enabled: both intermediate matmuls
            # are expected to produce the autocast dtype. unittest assertions
            # are used so the checks survive `python -O`.
            e_float16 = torch.matmul(x, y)
            self.assertIs(e_float16.dtype, amp_dtype)
            f_float16 = torch.matmul(z, e_float16)
            self.assertIs(f_float16.dtype, amp_dtype)
            return torch.matmul(w, f_float16)

        # Case 1, autocast inside vmapped function
        def func1(x, y, z, w):
            with torch.autocast(dtype=amp_dtype, device_type=device_type):
                e_float16 = torch.matmul(x, y)
                self.assertIs(e_float16.dtype, amp_dtype)
                f_float16 = torch.matmul(z, e_float16)
                self.assertIs(f_float16.dtype, amp_dtype)
            # The final matmul runs after the autocast block has exited.
            return torch.matmul(w, f_float16.float())

        expected = func1(*inputs)
        out = vmap(func1)(*inputs)
        self.assertTrue(expected.allclose(out))

        # Case 2, autocast decorator inside vmapped function
        func2 = torch.autocast(dtype=amp_dtype, device_type=device_type)(matmul_chain)

        expected = func2(*inputs)
        out = vmap(func2)(*inputs)
        self.assertTrue(expected.allclose(out))

        # Case 3, autocast is outside vmapped function
        with torch.autocast(dtype=amp_dtype, device_type=device_type):
            expected = matmul_chain(*inputs)
            out = vmap(matmul_chain)(*inputs)

        self.assertTrue(expected.allclose(out))

    @unittest.skip("Somehow, vmap and autocast do not work on CPU")
    def test_vmap_autocast_cpu(self):
        # Runs the shared autocast scenarios on CPU. Unconditionally skipped
        # by the decorator above.
        self._test_vmap_autocast("cpu")

    @skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_vmap_autocast_cuda(self):
        # Runs the shared autocast scenarios on CUDA; skipped when
        # torch.cuda.is_available() is False.
        self._test_vmap_autocast("cuda")


def slice_inputs(inputs, bdims, i):
    """Return the i-th per-example view of each input as a tuple.

    Inputs whose entry in ``bdims`` is None are passed through unchanged;
    every other input is indexed with ``select(bdim, i)``.
    """
    return tuple(
        tensor if dim is None else tensor.select(dim, i)
        for tensor, dim in zip(inputs, bdims)
    )


def reference_vmap(op, inputs, in_dims=0, out_dims=0):
    """Sequential reference for vmap: run ``op`` per example and stack.

    An integer ``in_dims`` is applied to every input. All mapped inputs must
    agree on the size of their mapped dimension. ``op`` may return a single
    tensor or a tuple of tensors; in the tuple case each position is stacked
    along its own entry of ``out_dims``.
    """
    if isinstance(in_dims, int):
        in_dims = (in_dims,) * len(inputs)
    batch_sizes = [tensor.size(dim) for tensor, dim in zip(inputs, in_dims)
                   if dim is not None]
    assert all(size == batch_sizes[0] for size in batch_sizes)
    num_examples = batch_sizes[0]
    per_example = tuple(op(*slice_inputs(inputs, in_dims, idx))
                        for idx in range(num_examples))

    assert len(per_example) > 0
    if not isinstance(per_example[0], tuple):
        # Single-return op: every per-example result is one tensor.
        assert all(isinstance(item, torch.Tensor) for item in per_example)
        if isinstance(out_dims, int):
            out_dims = (out_dims,)
        return torch.stack(per_example, dim=out_dims[0])

    # Multi-return op: transpose the per-example tuples so that each output
    # position can be stacked separately.
    assert all(isinstance(item, tuple) for item in per_example)
    num_returns = len(per_example[0])
    assert all(len(item) == num_returns for item in per_example)
    if isinstance(out_dims, int):
        out_dims = (out_dims,) * num_returns
    return tuple(torch.stack(shards, dim)
                 for shards, dim in zip(zip(*per_example), out_dims))


class TensorFactory:
    """Static constructors for random tensors used as test inputs."""

    @staticmethod
    def rand(size, device='cpu', dtype=torch.float):
        # Thin pass-through to torch.rand.
        options = {'device': device, 'dtype': dtype}
        return torch.rand(size, **options)

    @staticmethod
    def randn(size, device='cpu', dtype=torch.float):
        # Thin pass-through to torch.randn.
        options = {'device': device, 'dtype': dtype}
        return torch.randn(size, **options)

    @staticmethod
    def randp1(size, device='cpu', dtype=torch.float):
        # torch.rand samples shifted up by one.
        return TensorFactory.rand(size, device, dtype) + 1

# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
# (slow) sequential map+stack fallback.
#
# check_view: Test if every returned output is a view of the first input
# check_propagates_grad: Test if the operation propagates gradients.


def _vmap_test(self, op, inputs, in_dims=0, out_dims=0,
               check_view=False, check_propagates_grad=True):
    """Compare vmap(op, in_dims, out_dims)(*inputs) against reference_vmap.

    Args:
        self: the TestCase whose assertion methods are used.
        op: callable under test.
        inputs: sequence of tensors passed positionally to ``op``.
        in_dims, out_dims: forwarded to both vmap and reference_vmap.
        check_view: if True, assert that every output's ``_base`` is the
            first input (or the first input's own ``_base``).
        check_propagates_grad: if True, rerun with a requires_grad clone of
            the first input and assert the first output requires grad.
    """
    result = vmap(op, in_dims, out_dims)(*inputs)
    reference_result = reference_vmap(op, inputs, in_dims, out_dims)
    self.assertEqual(result, reference_result)
    op_has_single_return = not isinstance(result, tuple)

    if check_view:
        result_as_tuple = (result,) if op_has_single_return else result
        # The expected base does not depend on the output, so compute it once.
        input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
        for output in result_as_tuple:
            self.assertTrue(output._base is input0_base,
                            msg="result was not a view of the first input!")

    if not check_propagates_grad:
        return
    # Assuming input[0] is a floating-point tensor. Check if the vmap
    # operation propagates the requires_grad flag to the zeroth output.
    # Some vmap operators are implemented in a way that assumes that
    # they are composite with respect to autograd. If the operator ever is
    # changed to not be composite with respect to autograd, then the
    # following check should fail.
    inputs_clone = list(inputs)
    inputs_clone[0] = inputs[0].clone().requires_grad_()
    result = vmap(op, in_dims, out_dims)(*inputs_clone)
    result_as_tuple = (result,) if op_has_single_return else result
    # Index the normalized tuple: for a single-return op, `result[0]` would
    # select the first slice of the output tensor (and raise on a 0-dim
    # output) instead of the output itself.
    self.assertTrue(result_as_tuple[0].requires_grad)


def should_allow_vmap_fallback_usage(fn):
    """Return fn's ``_allow_vmap_fallback_usage`` attribute, or False if unset."""
    try:
        return fn._allow_vmap_fallback_usage
    except AttributeError:
        return False


def allowVmapFallbackUsage(fn):
    """Decorator: set ``_allow_vmap_fallback_usage = True`` on fn and return it."""
    setattr(fn, '_allow_vmap_fallback_usage', True)
    return fn

# All tests of TestVmapBase check that the slow vmap fallback is never invoked.
# This is so that we can incrementally add batching rules for operators to
# replace the slow vmap fallback path for said operators. To skip this check,
# please use the allowVmapFallbackUsage decorator.
#
# NB: Don't add tests to TestVmapBase directly, unless you want them to run
# on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators.
#
# NB: TestVmapBase is a nested class. This prevents test runners from picking
# it up and running it.


class Namespace:
    class TestVmapBase(TestCase):
        def __init__(self, method_name='runTest'):
            super().__init__(method_name)

            # Swap the selected test method (when it exists) for a wrapped
            # version, unless it opted out via allowVmapFallbackUsage.
            target = getattr(self, method_name, None)
            if target is not None and not should_allow_vmap_fallback_usage(target):
                guarded = self._wrap_method_with_vmap_fallback_check(target)
                setattr(self, method_name, guarded)

        def _wrap_method_with_vmap_fallback_check(self, method):
            # Returns `method` wrapped so that it runs with vmap fallback
            # warnings enabled and recorded. The assertion on the recorded
            # warnings is currently commented out, so the wrapper only
            # captures them. The wrapper does not return method's result.
            #
            # msg = (
            #     'Expected the test to not invoke the vmap fallback path, i.e., '
            #     'all of the operators being tested in this test should have batching '
            #     'rules implemented. If you are intentionally testing something to '
            #     'do with the fallback path, use allowVmapFallbackUsage. Otherwise, '
            #     'please make sure that batching rules are implemented for the '
            #     'operator(s) being tested.'
            # )

            @functools.wraps(method)
            def checked(self, *args, **kwargs):
                with warnings.catch_warnings(record=True):
                    warnings.simplefilter('always')
                    with EnableVmapFallbackWarnings():
                        method(*args, **kwargs)
                    # for captured_warning in wa:
                    #     self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg)

            # Bind to this instance so it can stand in for a bound method.
            return types.MethodType(checked, self)

        @allowVmapFallbackUsage
        def test_vmap_fallback_check_ok(self):
            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            fallback_op = torch.var_mean
            vmap(fallback_op)(torch.rand(3))

        @unittest.expectedFailure
        def test_vmap_fallback_check(self):
            # Expected to fail while the assertion inside the wrapper is
            # commented out: uses_fallback then raises no AssertionError.
            @self._wrap_method_with_vmap_fallback_check
            def no_fallback(self):
                pass

            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            fallback_op = torch.var_mean

            @self._wrap_method_with_vmap_fallback_check
            def uses_fallback(self):
                vmap(fallback_op)(torch.rand(3))

            no_fallback(self)

            with self.assertRaises(AssertionError):
                uses_fallback(self)


def _make_case(op, input_getter=TensorFactory.randn):
    """Bundle an op with the factory used to generate its inputs."""
    case = (op, input_getter)
    return case


class TestVmapOperators(Namespace.TestVmapBase):
    def _vmap_test(self, *args, **kwargs):
        # Method form of the module-level _vmap_test helper, with this test
        # case supplied as its `self` argument.
        return _vmap_test(self, *args, **kwargs)

    def _vmap_view_test(self, *args, **kwargs):
        # Same as _vmap_test but with check_view=True forced on. Passing
        # check_view in kwargs as well would raise a duplicate-keyword TypeError.
        self._vmap_test(*args, **kwargs, check_view=True)

    def _test_unary(self, op, getter, device, *args, **kwargs):
        # Runs a one-argument op through _vmap_test for several batch-dim
        # placements, singly and doubly vmapped. Extra args/kwargs are
        # forwarded to _vmap_test on every call.
        check = functools.partial(self._vmap_test, *args, **kwargs)
        B0, B1 = 7, 11

        def make(*shape):
            return getter(list(shape), device)

        # Single vmap, various in_dims / out_dims
        check(op, [make(B0, 3)])
        check(op, [make(2, 5, B0, 3)], in_dims=2)
        check(op, [make(2, 5, B0, 3)], in_dims=2, out_dims=2)

        # Doubly nested vmap
        check(vmap(op), [make(B0, B1)])
        check(vmap(op), [make(B1, 2, 5, B0, 3)], in_dims=2)
        check(vmap(op, in_dims=2), [make(2, 5, B0, B1, 3)],
              in_dims=2, out_dims=2)

    @parametrize("case", [
        (torch.abs, TensorFactory.randn),
        (torch.acos, TensorFactory.rand),
        (torch.asin, TensorFactory.rand),
        (torch.atan, TensorFactory.rand),
        (torch.ceil, TensorFactory.randn),
        (torch.cos, TensorFactory.rand),
        (torch.cosh, TensorFactory.rand),
        (torch.digamma, TensorFactory.rand),
        (torch.exp, TensorFactory.randn),
        (torch.expm1, TensorFactory.randn),
        (torch.floor, TensorFactory.randn),
        (torch.frac, TensorFactory.randn),
        (torch.lgamma, TensorFactory.rand),
        (torch.log, TensorFactory.randp1),
        (torch.log10, TensorFactory.randp1),
        (torch.log1p, TensorFactory.randp1),
        (torch.log2, TensorFactory.randp1),
        (torch.neg, TensorFactory.randn),
        (torch.reciprocal, TensorFactory.randp1),
        (torch.relu, TensorFactory.randn),
        (torch.round, TensorFactory.randn),
        (torch.rsqrt, TensorFactory.randp1),
        (torch.sigmoid, TensorFactory.randn),
        (torch.sign, TensorFactory.randn),
        (torch.sin, TensorFactory.rand),
        (torch.sinh, TensorFactory.rand),
        (torch.sqrt, TensorFactory.rand),
        (torch.tan, TensorFactory.rand),
        (torch.tanh, TensorFactory.rand),
        (torch.trunc, TensorFactory.randn),
    ], name_fn=lambda c: c[0].__name__)
    def test_unary_pointwise(self, case):
        # Each case pairs a torch unary function with the input factory used
        # for it. The out-of-place function is tested first, then the
        # Tensor method of the same name with a trailing underscore.
        fn, make_input = case
        self._test_unary(fn, make_input, 'cpu')

        # test in-place
        inplace_name = fn.__name__ + "_"
        inplace_method = getattr(Tensor, inplace_name)
        self._test_unary(inplace_method, make_input, 'cpu', check_propagates_grad=False)

    def test_clone(self):
        def cloner(**clone_kwargs):
            return lambda t: t.clone(**clone_kwargs)

        # Some basic tests
        basic_variants = (
            {},
            {'memory_format': torch.preserve_format},
            {'memory_format': torch.contiguous_format},
        )
        for clone_kwargs in basic_variants:
            self._test_unary(cloner(**clone_kwargs), TensorFactory.randn, 'cpu')

        # Test that the per-examples are contiguous when using torch.contiguous_format
        clone_contiguous = cloner(memory_format=torch.contiguous_format)

        B0, B1 = 3, 5
        inp = torch.randn(2, B0, 7)
        out = vmap(clone_contiguous, in_dims=1, out_dims=1)(inp)
        self.assertTrue(out.movedim(1, 0).is_contiguous())
        self.assertTrue(out[:, 0, :].is_contiguous())

        inp = torch.randn(2, B0, 7, B1)
        out = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(inp)
        self.assertTrue(out.is_contiguous())
        self.assertTrue(out[0][0].is_contiguous())

        # Other memory formats are expected to be rejected under vmap.
        msg = r'only supported with memory_format torch.preserve_format or torch.contiguous_format'
        for rejected_format in (torch.channels_last, torch.channels_last_3d):
            with self.assertRaisesRegex(RuntimeError, msg):
                vmap(cloner(memory_format=rejected_format))(torch.randn(B0))

    def test_weird_matmul_case(self):
        # Check that this doesn't crash.
        # https://github.com/pytorch/functorch/issues/417
        lhs = torch.randn(5, 2, 2, 2)
        rhs = torch.randn(5, 7, 2)

        # No assertion on the value: only the absence of an exception is tested.
        inner = vmap(torch.matmul, in_dims=(None, 0))
        vmap(inner)(lhs, rhs)

    @parametrize("case",
                 (
                     (torch.clamp_min_, TensorFactory.randn),
                     (torch.clamp_max_, TensorFactory.randn),
                 ), name_fn=lambda c: c[0].__name__)
    def test_clamp_inplace_variant(self, case):
        # In-place clamp variants: every check disables the requires_grad
        # propagation step of _vmap_test.
        op, getter = case
        device = 'cpu'
        B0, B1 = 7, 11
        check = functools.partial(self._vmap_test, check_propagates_grad=False)

        def pair(lhs_shape, rhs_shape):
            return getter(lhs_shape, device), getter(rhs_shape, device)

        # Single vmap: op(Tensor, Tensor)
        check(op, pair([B0, 3], [B0, 3]))
        check(op, pair([B0], [B0]))
        check(op, pair([2, B0, 3], [2, B0, 3]), in_dims=(1, 1))
        check(op, pair([B0, 2, 3], [2, B0, 3]), in_dims=(0, 1), out_dims=1)
        check(op, pair([B0, 2, 3], [1, 1]), in_dims=(0, None))
        check(op, pair([B0, 3], [B0, 3]), in_dims=(0, 0))

        # Nested vmap: op(Tensor, Tensor)
        check(vmap(op), pair([B0, B1, 2, 3], [B0, B1, 1, 3]))

        # Python number overload: op(Tensor, Number)
        number = getter([]).item()
        self._test_unary(lambda t: op(t, number), getter, device, check_propagates_grad=False)

    @parametrize('case', [
        subtest(_make_case(torch.clamp_min), name='clamp_min'),
        subtest(_make_case(torch.clamp_max), name='clamp_max'),
    ])
    def test_clamp_variant(self, case):
        # Out-of-place clamp_min / clamp_max with tensor and Python-number
        # bounds, under single and nested vmap.
        op, getter = case
        device = 'cpu'
        B0, B1 = 7, 11
        check = self._vmap_test

        def pair(lhs_shape, rhs_shape):
            return getter(lhs_shape, device), getter(rhs_shape, device)

        # Single vmap: op(Tensor, Tensor)
        check(op, pair([B0, 3], [B0, 3]))
        check(op, pair([B0], [B0, 2, 3]))
        check(op, pair([B0], [2, B0, 3]), in_dims=(0, 1))
        check(op, pair([B0], [2, B0, 3]), in_dims=(0, 1), out_dims=1)
        check(op, pair([B0], [2, 3]), in_dims=(0, None))
        check(op, pair([2, 3], [B0, 3]), in_dims=(None, 0))

        # Nested vmap: op(Tensor, Tensor)
        check(vmap(op), pair([B0, B1, 2, 3], [B0, B1, 3]))
        check(vmap(op, in_dims=(None, 0)), pair([B0, 2, 3], [B1, 3]),
              in_dims=(0, None))

        # Python number overload: op(Tensor, Number)
        number = getter([]).item()
        self._test_unary(lambda t: op(t, number), getter, device)

    def test_copy_(self):
        copy_ = Tensor.copy_

        # Both operands batched along dim 0.
        dst = torch.randn(3)
        src = torch.randn(3)
        vmap(copy_)(dst, src)
        self.assertEqual(dst, src)

        # Destination batched along dim 1, source not batched.
        src = torch.randn(3)
        dst = torch.randn(3, 2)
        vmap(copy_, in_dims=(1, None))(dst, src)
        self.assertEqual(dst, src.expand(2, 3).t())

        # Destination not batched, source batched: expected to be rejected.
        dst = torch.randn(3)
        src = torch.randn(2, 3)
        with self.assertRaisesRegex(RuntimeError, 'inplace'):
            vmap(copy_, in_dims=(None, 0))(dst, src)

    def test_silu_backward(self):
        # Calls the aten silu_backward op directly with both, only the
        # second, and only the first argument batched.
        check = self._vmap_test
        device = 'cpu'
        make = TensorFactory.randp1
        B0 = 7
        silu_backward = torch.ops.aten.silu_backward

        # Single vmap: op(Tensor, Tensor)
        check(silu_backward, (make([B0, 3], device), make([B0, 3], device)))
        check(silu_backward, (make([], device), make([B0], device)), in_dims=(None, 0))
        check(silu_backward, (make([2, B0], device), make([2], device)), in_dims=(1, None))

    @parametrize('case', [
        subtest(_make_case(torch.add), name='add'),
        subtest(_make_case(lambda x, y: x + y), name='add_dunder'),
        subtest(_make_case(torch.sub), name='sub'),
        subtest(_make_case(lambda x, y: x - y), name='sub_dunder'),
        subtest(_make_case(torch.mul), name='mul'),
        subtest(_make_case(lambda x, y: x * y), name='mul_dunder'),
        subtest(_make_case(torch.div, input_getter=TensorFactory.randp1), name='div'),
        subtest(_make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), name='div_dunder'),
        subtest(_make_case(torch.pow, input_getter=TensorFactory.randp1), name='pow'),
        subtest(_make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1), name='pow_dunder'),
    ])
    def test_arithmetic(self, case):
        # Binary arithmetic ops (function and operator spellings) with
        # tensor/tensor and tensor/number operands, under single and nested
        # vmap, including mixed float/double operands.
        test = self._vmap_test

        def get_number(getter):
            # A Python scalar drawn from the same factory as the tensors.
            return getter([]).item()

        op, getter = case
        device = 'cpu'
        B0, B1 = 7, 11

        # Single vmap: op(Tensor, Tensor)
        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
        test(op, (getter([B0], device), getter([B0, 2, 3], device)))
        test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
        test(op, (getter([B0], device), getter([2, B0, 3], device)),
             in_dims=(0, 1), out_dims=1)
        test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
        # Only the right operand is batched (mirrors test_clamp_variant).
        # This used to pass in_dims=(0, None), which mapped over the size-2
        # dim of the left operand and left this case untested.
        test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(None, 0))

        # Nested vmap: op(Tensor, Tensor)
        test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
        test(vmap(op, in_dims=(None, 0)),
             (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))

        # Python number overload: op(Tensor, Number) (and vice-versa)
        number = get_number(getter)
        self._test_unary(lambda t: op(t, number), getter, device)
        number = get_number(getter)
        self._test_unary(lambda t: op(number, t), getter, device)

        # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
        test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
        test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
        test(op, (getter([B0], device), getter([B0], device)))

        # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
        test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
        test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))

        # Nothing executable follows this guard yet; it is kept for the
        # disabled cross-device cases below.
        if not torch.cuda.is_available():
            return

        # TODO(rzou): fix the following
        # # Test cross-device scalars
        # number = get_number(getter)
        # self._test_unary(lambda t: op(t, number), getter, device='cuda')
        # self._test_unary(lambda t: op(number, t), getter, device='cuda')
        # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')

    def test_as_strided(self):
        # Checks vmap over Tensor.as_strided: for each input tensor, several
        # as_strided calls are compared against an equivalent view op
        # (expand / transpose / select / diagonal / strided slice), followed
        # by a set of calls that are expected to raise.
        def _test(sizes, strides, offset, tensor, lambd):
            # Compares vmap(as_strided(sizes, strides, offset)) with
            # vmap(lambd): both results must share the same _base and be equal.
            # bdim at dim 0 test
            result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor)
            expected = vmap(lambd)(tensor)
            self.assertTrue(result._base is expected._base)
            self.assertEqual(result, expected)

            # bdim at dim -1 test
            tensor = tensor.movedim(0, -1)
            result = vmap(lambda t: t.as_strided(sizes, strides, offset), -1)(tensor)
            expected = vmap(lambd, -1)(tensor)
            self.assertTrue(result._base is expected._base)
            self.assertEqual(result, expected)

        # single vmap test
        B0 = 5
        # Each Tensor has shape [B0, 2, 3]; the expressions below
        # are just to get tensors of different strides that have shape [B0, 2, 3]
        tensors = [
            # contiguous
            torch.randn(B0, 2, 3),
            # non-contiguous
            torch.randn(B0, 3, 2).transpose(1, 2),
            torch.randn(3, 2, B0).movedim(-1, 0).transpose(1, 2),
            # non-zero storage offset
            torch.randn(2, B0, 2, 3)[1],
            torch.randn(2, 2, B0, 3)[1].movedim(1, 0),
            # non-contiguous strides, zero storage offset
            torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0],
            torch.randn(2, 4, B0, 3, 7).movedim(2, 0)[:, :, 0, :, 0],
            # non-contiguous strides, non-zero storage offset
            torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1],
            torch.randn(2, 4, 3, 7, B0).movedim(-1, 0)[:, :, 2, :, 1],
        ]

        for x in tensors:
            # Strides of the two per-example dims (the batch dim's stride is dropped).
            S0, S1 = x.stride()[1:]
            offset = x.storage_offset()

            # Broadcast
            _test([5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3))
            # transpose
            _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1))
            # select
            _test([2], [S0], offset + S1, x, lambda x: x[:, 1])
            # diagonal
            _test([2], [S0 + S1], offset, x, lambda x: x.diagonal())
            # strided slice
            _test([2], [S1 * 2], offset, x, lambda x: x[0, ::2])

        # Nested vmap test
        B1 = 7
        x = torch.randn(B1, B0, 2, 3)
        S0, S1 = x.stride()[2:]
        result = vmap(vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1)(x)
        expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x)
        self.assertTrue(result._base is expected._base)
        self.assertEqual(result, expected)

        # Check that mal-formatted size/strides doesn't crash
        with self.assertRaisesRegex(RuntimeError, 'size and stride must have the same length'):
            x = torch.randn(B0, 2, 3).transpose(0, 1)
            vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x)

        # All the Sanity check #1{a,b,c} cases check that
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # doesn't index memory that is out of bounds of xs[i]. This condition
        # is important to the correctness of the as_strided batching rule
        # (see NOTE: [When will the as_strided_batching_rule fail?])

        # Sanity check #1a: The maximum indexable location of
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # is less than or equal to the maximum indexable location of xs[i].
        msg = 'This is not supported inside of vmap'
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 3)
            vmap(lambda x: x.as_strided([3], [1], 1))(x)
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 3, 5)
            vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x)
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, B1, 3, 5)
            vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x)

        # Sanity check #1b: The min indexable location of
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # is greater than or equal to the min indexable location of xs[i].
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(2, B0, 3)[1]
            vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x)

        # Sanity check #1c:
        # xs[i] is a zero-element tensor (per-example shape [0, 3]), but
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # is not
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 0, 3)
            vmap(lambda x: x.as_strided([3], [1]))(x)

    def test_nll_loss(self):
        # F.nll_loss with the default reduction plus 'sum' and 'none', first
        # with both arguments batched and then with an unbatched target.
        check = self._vmap_test
        nll = F.nll_loss
        B = 3
        variants = (
            nll,
            functools.partial(nll, reduction='sum'),
            functools.partial(nll, reduction='none'),
        )

        inp = torch.randn(B, 2, 5)
        target = torch.randint(0, 5, (B, 2))
        for loss_fn in variants:
            check(loss_fn, (inp, target))

        inp = torch.randn(B, 2, 5)
        target = torch.randint(0, 5, (2,))
        for loss_fn in variants:
            check(loss_fn, (inp, target), in_dims=(0, None))

    def test_adaptive_avg_pool2d(self):
        # Same 5-d input, batched along the default dim 0, then dim 1, then dim 4.
        check = self._vmap_test
        pool = functools.partial(F.adaptive_avg_pool2d, output_size=(3, 3))

        x = torch.randn(3, 5, 7, 9, 11)
        for extra in ({}, {'in_dims': (1,)}, {'in_dims': (4,)}):
            check(pool, (x,), **extra)

    def test_bmm(self):
        bmm = torch.bmm
        check = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        # NB: the empty pattern matches any message, so only the exception
        # type is checked here.
        msg = ""
        mismatched = (
            (0, (B0, 2, 2, 2), (B0, 2)),
            ((0, None), (B0, 3, 3, 2), (2, 2)),
            ((None, 0), (2, 2), (B0, 2, 2, 2)),
        )
        for in_dims, lhs_shape, rhs_shape in mismatched:
            with self.assertRaisesRegex(RuntimeError, msg):
                vmap(bmm, in_dims=in_dims)(torch.randn(*lhs_shape), torch.randn(*rhs_shape))

        # left arg is vmapped
        check(bmm, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
        left_only = vmap(bmm, in_dims=(0, None))
        check(left_only, (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
              in_dims=(1, None))

        # right arg is vmapped
        check(bmm, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
        right_only = vmap(bmm, in_dims=(None, 0))
        check(right_only, (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
              in_dims=(None, 1))

        # both args are vmapped
        check(bmm, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
        check(vmap(bmm), (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), in_dims=(1, 0))
        check(vmap(bmm, in_dims=(0, None)),
              (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), in_dims=(None, 0))

    def test_cat(self):
        # torch.cat under vmap with operands batched at different positions,
        # with one operand unbatched, and under nested vmaps.
        check = self._vmap_test
        B0, B1 = 5, 7

        # vmap can't accept a list of tensors as an argument, so wrap cat in
        # a variadic function that collects its positional tensors.
        def cat_along(dim):
            return lambda *tensors: torch.cat(tensors, dim=dim)

        check(cat_along(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
        check(cat_along(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
        check(cat_along(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
        check(cat_along(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))

        # Nested vmaps.
        inner = vmap(cat_along(0), in_dims=(0, None))
        check(inner, (torch.rand(B1, 2), torch.rand(B0, 3)), in_dims=(None, 0))
        inner = vmap(cat_along(0), in_dims=(0, 0))
        check(inner, (torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0))

    def test_unsafe_view(self):
        # Unsafe view isn't exposed directly, so it is reached through
        # vmap(grad(matmul)).
        check = functools.partial(self._vmap_test, check_propagates_grad=False)
        batch = 2
        lhs = torch.randn(batch, 2, 3, 3)
        rhs = torch.randn(batch, 3, 3)

        def summed_matmul(a, b):
            # Reduce to a scalar so that grad can be applied.
            product = a @ b
            return product.sum()

        check(functorch.grad(summed_matmul), (lhs, rhs))

    def test_conj(self):
        # torch.conj under vmap for a real and a complex dtype, followed by a
        # check that conj of a real tensor shares storage with its input.
        conj = torch.conj
        check = self._vmap_test
        B0, B1 = 7, 11

        # correctness tests, once per dtype
        for dtype in (torch.float, torch.cfloat):
            make = functools.partial(torch.randn, dtype=dtype)

            # Single vmap, various in_dims / out_dims
            check(conj, [make([B0, 3])])
            check(conj, [make([2, 5, B0, 3])], in_dims=2)
            check(conj, [make([2, 5, B0, 3])], in_dims=2, out_dims=2)

            # Doubly nested vmap
            check(vmap(conj), [make([B0, B1])])
            check(vmap(conj), [make([B1, 2, 5, B0, 3])], in_dims=2)
            check(vmap(conj, in_dims=2), [make([2, 5, B0, B1, 3])],
                  in_dims=2, out_dims=2)

        # torch.conj on a non-complex tensor should hand back the same data:
        # compare the data pointers of input and output.
        real_input = torch.randn(3)
        conj_output = vmap(conj)(real_input)
        self.assertEqual(conj_output.data_ptr(), real_input.data_ptr())

    def test_contiguous(self):
        # Tensor.contiguous under vmap: the generic unary checks, the
        # "already contiguous per-example" identity fast path, and the
        # unsupported memory_format error.
        op = Tensor.contiguous

        self._test_unary(op, TensorFactory.randn, 'cpu')

        # check that contiguous returns the original tensor if the per-examples
        # are already contiguous
        B0 = 3
        x = torch.randn(B0, 2, 5, 7)
        x = x.movedim(0, 2)
        result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x)
        # assertIs reports both objects on failure, unlike assertTrue(a is b)
        # which only says "False is not true".
        self.assertIs(result, x)

        # Non-default memory formats are expected to be rejected under vmap.
        msg = 'NYI: querying is_contiguous inside of vmap for memory_format'
        tensor = torch.randn(B0, 3)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(functools.partial(op, memory_format=torch.channels_last))(tensor)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor)

    def test_stride(self):
        # Tensor.stride() queried inside vmap must report the per-example
        # strides, both for a leading batch dim and for one moved to the
        # front of a differently laid out tensor.
        B0 = 3

        def expect_leading_batch_strides(t):
            assert t.stride() == (7 * 5, 7, 1)
            return t

        vmap(expect_leading_batch_strides)(torch.randn(B0, 2, 5, 7))

        def expect_moved_batch_strides(t):
            # The batch dim sits between dim 0 and dim 1 in memory, so the
            # stride of the first per-example dim is scaled by B0.
            assert t.stride() == (7 * 5 * B0, 7, 1)
            return t

        moved = torch.randn(2, B0, 5, 7).movedim(1, 0)
        vmap(expect_moved_batch_strides)(moved)

    def test_chunk(self):
        # torch.chunk(tensor, chunks, dim) through self._vmap_view_test with
        # single, double, and triple levels of vmap.
        check = self._vmap_view_test
        chunk = torch.chunk
        B0, B1, B2 = 7, 11, 13

        # Single vmap; chunk count and dim are passed through unbatched.
        check(chunk, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
        check(chunk, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))

        # Doubly nested vmap.
        nested = vmap(chunk, in_dims=(0, None, None))
        check(nested, (torch.rand(B1, 1023, B0, 5), 4, 0), in_dims=(2, None, None))

        # Triply nested vmap with the chunk arguments baked into a lambda.
        nested = vmap(vmap(lambda t: chunk(t, 4, 1), in_dims=2))
        check(nested, (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    def test_clamp(self):
        # Each clamp variant is passed to self._test_unary on CPU with
        # TensorFactory.randn as the tensor factory.
        clamp_variants = [
            lambda t: t.clamp(min=-0.5),
            lambda t: t.clamp(max=0.5),
            lambda t: t.clamp(min=-0.5, max=0.5),
            lambda t: t.clamp_min(min=-0.5),
            lambda t: t.clamp_max(max=0.5),
        ]
        for clamp_fn in clamp_variants:
            self._test_unary(clamp_fn, TensorFactory.randn, 'cpu')

    def test_comparison_ops(self):
        # Comparison ops (function form and operator form) under vmap with
        # tensor/tensor and tensor/number operands. Gradient propagation is
        # not checked for these ops.
        test = functools.partial(self._vmap_test, check_propagates_grad=False)

        getter = TensorFactory.randn
        B0, B1 = 7, 11

        # Named comparison_ops to avoid shadowing the module-level `ops`
        # decorator imported at the top of the file.
        comparison_ops = (
            torch.eq, lambda x, y: x == y,
            torch.gt, lambda x, y: x > y,
            torch.ge, lambda x, y: x >= y,
            torch.le, lambda x, y: x <= y,
            torch.lt, lambda x, y: x < y,
            torch.ne, lambda x, y: x != y,
        )

        for op in comparison_ops:
            # Single vmap: op(Tensor, Tensor)
            test(op, (getter([B0, 3]), getter([B0, 3])))
            test(op, (getter([B0]), getter([B0, 2, 3])))
            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1))
            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1)
            test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None))
            test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None))
            # The case above batches the left operand again (over its size-2
            # dim); also cover the configuration where only the right operand
            # is batched.
            test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(None, 0))

            # Nested vmap: op(Tensor, Tensor)
            test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3])))
            test(vmap(op, in_dims=(None, 0)),
                 (getter([B0, 2, 3]), getter([B1, 3])), in_dims=(0, None))

            # test number as inputs
            number = getter([]).item()
            self._test_unary(lambda t: op(t, number), getter, 'cpu', check_propagates_grad=False)

    def test_cross_batch_size_three(self):
        # Corner case: the batch size is 3 and torch.cross is called without
        # an explicit dim. Per the cross API, dim defaults to the first dim of
        # size 3, so this makes sure the batch dim is not the one picked.
        cross = torch.cross
        check = self._vmap_test
        B0 = B1 = 3

        lhs, rhs = torch.rand(B0, 2, 3), torch.rand(B0, 2, 3)
        check(cross, (lhs, rhs))

        lhs, rhs = torch.rand(B0, B1, 2, 3), torch.rand(B0, B1, 2, 3)
        check(vmap(cross, in_dims=(0, None)), (lhs, rhs), in_dims=(None, 1))

    def test_diagonal(self):
        # torch.diagonal(tensor, offset, dim1, dim2) through
        # self._vmap_view_test, reusing one 5-D tensor with different batch
        # dims and nesting depths.
        check = self._vmap_view_test
        diagonal = torch.diagonal
        inp = torch.randn(3, 5, 7, 11, 13)
        unbatched_rest = (None, None, None)

        # Single vmap; offset and dims are passed through unbatched.
        check(diagonal, (inp, 1, 0, 1), in_dims=(0,) + unbatched_rest)
        check(diagonal, (inp, 0, 2, -1), in_dims=(0,) + unbatched_rest)
        check(diagonal, (inp, 2, 1, 2), in_dims=(1,) + unbatched_rest)
        check(diagonal, (inp, 0, -2, -1), in_dims=(1,) + unbatched_rest, out_dims=1)

        # Nested vmaps with the diagonal arguments baked into a lambda.
        check(vmap(lambda t: diagonal(t, 0, 0, -1)), (inp,), in_dims=1, out_dims=1)
        doubly_nested = vmap(vmap(lambda t: diagonal(t, 0, 0, 1), in_dims=1), in_dims=3)
        check(doubly_nested, (inp,), in_dims=1, out_dims=1)

    def test_dot(self):
        # torch.dot under vmap: malformed operands must raise RuntimeError,
        # then each combination of batched / unbatched operands is passed to
        # self._vmap_test, including nested vmaps.
        dot = torch.dot
        check = self._vmap_test
        B0, B1 = 7, 11

        # Shape mismatch. The pattern is empty, so it matches any message and
        # only the exception type is effectively checked.
        any_message = ""
        with self.assertRaisesRegex(RuntimeError, any_message):
            vmap(dot)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, any_message):
            vmap(dot, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, any_message):
            vmap(dot, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))

        # Only the left operand is vmapped.
        lhs, rhs = torch.rand(B0, 5), torch.rand(5)
        check(dot, (lhs, rhs), in_dims=(0, None))
        lhs, rhs = torch.rand(B1, B0, 5), torch.rand(5)
        check(vmap(dot, in_dims=(0, None)), (lhs, rhs), in_dims=(1, None))

        # Only the right operand is vmapped.
        lhs, rhs = torch.rand(5), torch.rand(B0, 5)
        check(dot, (lhs, rhs), in_dims=(None, 0))
        lhs, rhs = torch.rand(5), torch.rand(B1, B0, 5)
        check(vmap(dot, in_dims=(None, 0)), (lhs, rhs), in_dims=(None, 1))

        # Both operands are vmapped.
        lhs, rhs = torch.rand(B0, 5), torch.rand(B0, 5)
        check(dot, (lhs, rhs))
        lhs, rhs = torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)
        check(vmap(dot), (lhs, rhs), in_dims=(1, 0))
        lhs, rhs = torch.rand(B1, 5), torch.rand(B0, 5)
        check(vmap(dot, in_dims=(0, None)), (lhs, rhs), in_dims=(None, 0))

    def test_expand_as(self):
        # Tensor.expand_as through self._vmap_view_test with either or both
        # arguments batched, at one, two, and three levels of vmap.
        expand_as = torch.Tensor.expand_as
        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        # Single vmap.
        check(expand_as, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
        check(expand_as, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
        check(expand_as, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))

        # Doubly nested vmap.
        nested = vmap(expand_as)
        check(nested, (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
        check(nested, (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1))
        check(nested, (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))

        # Triply nested vmap.
        check(vmap(vmap(expand_as)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))

    def test_fill_and_zero_inplace(self):
        # In-place fill_ / zero_ under vmap, with scalar and tensor fill
        # values, plus the error raised when the fill value is batched but
        # the tensor being written to is not.
        check = functools.partial(self._vmap_test, check_propagates_grad=False)
        make = TensorFactory.randn
        B0, B1 = 7, 11
        inplace_ops = (
            lambda t: t.fill_(0.1),
            lambda t: t.fill_(torch.tensor(0.2)),
            lambda t: t.zero_(),
        )

        for inplace_op in inplace_ops:
            # Single vmap, various in_dims / out_dims
            check(inplace_op, [make([B0, 3])])
            check(inplace_op, [make([2, 5, B0, 3])], in_dims=2)
            check(inplace_op, [make([2, 5, B0, 3])], in_dims=2, out_dims=2)

            # Doubly nested vmap
            check(vmap(inplace_op), [make([B0, B1])])
            check(vmap(inplace_op), [make([B1, 2, 5, B0, 3])], in_dims=2)
            check(vmap(inplace_op, in_dims=2), [make([2, 5, B0, B1, 3])],
                  in_dims=2, out_dims=2)

        # fill_ where the fill value itself is a batched tensor
        B0, B1 = 3, 5
        check(Tensor.fill_, [make([B0, B1]), make(B0)])

        # A RuntimeError is expected when the tensor being written to isn't
        # being vmapped over. The empty pattern matches any message.
        with self.assertRaisesRegex(RuntimeError, ""):
            vmap(Tensor.fill_, (None, 0))(make([B0, B1]), make([B0]))

    def _test_complex_views(self, op, dtypes):
        # Shared driver for view ops on complex tensors: for every dtype in
        # `dtypes`, pass `op` to self._vmap_view_test with random inputs of
        # that dtype under single and doubly nested vmaps.
        check = self._vmap_view_test
        B0, B1 = 7, 11

        for dtype in dtypes:
            make = functools.partial(torch.randn, dtype=dtype)

            # Single vmap, various in_dims / out_dims
            check(op, [make([B0, 3])])
            check(op, [make([3, B0])], in_dims=1)
            check(op, [make([2, 5, B0, 3])], in_dims=2)
            check(op, [make([2, 5, B0, 3])], in_dims=2, out_dims=2)

            # Doubly nested vmap
            check(vmap(op), [make([B0, B1])])
            check(vmap(op), [make([B1, 2, 5, 3, B0])], in_dims=4)
            check(vmap(op, in_dims=2), [make([2, 5, B0, B1, 3])],
                  in_dims=2, out_dims=2)

    def test_real(self):
        # torch.real on both complex dtypes via the shared complex-view driver.
        complex_dtypes = [torch.cfloat, torch.cdouble]
        self._test_complex_views(torch.real, dtypes=complex_dtypes)

    def test_imag(self):
        # torch.imag on both complex dtypes via the shared complex-view driver.
        complex_dtypes = [torch.cfloat, torch.cdouble]
        self._test_complex_views(torch.imag, dtypes=complex_dtypes)

    def test_view_as_real(self):
        # torch.view_as_real on both complex dtypes via the shared
        # complex-view driver.
        complex_dtypes = [torch.cfloat, torch.cdouble]
        self._test_complex_views(torch.view_as_real, dtypes=complex_dtypes)

    def test_view_as_complex(self):
        # torch.view_as_complex through self._vmap_view_test for float and
        # double inputs, covering regular batch-dim placements, batch dims
        # adjacent to / at the trailing size-2 dim, and the error messages for
        # invalid inputs.
        view_as_complex = torch.view_as_complex
        check = self._vmap_view_test
        B0, B1 = 7, 11

        for dtype in [torch.float, torch.double]:
            make = functools.partial(torch.randn, dtype=dtype)

            # Single vmap, various in_dims / out_dims
            check(view_as_complex, [make([B0, 3, 2])])
            check(view_as_complex, [make([2, 5, B0, 3, 2])], in_dims=2)
            check(view_as_complex, [make([2, 5, B0, 3, 2])], in_dims=2, out_dims=2)

            # Doubly nested vmap
            check(vmap(view_as_complex), [make([B0, B1, 2])])
            check(vmap(view_as_complex), [make([B1, 2, 5, B0, 3, 2])], in_dims=2)
            check(vmap(view_as_complex, in_dims=2), [make([2, 5, B0, B1, 3, 2])],
                  in_dims=2, out_dims=2)

            # Interesting case #1: Batch dim directly before dim of size 2
            check(view_as_complex, [make([3, B0, 2])], in_dims=1)
            check(vmap(view_as_complex, in_dims=1), [make([3, B1, B0, 2])], in_dims=2)

            # Interesting case #2: Batch dim at end of tensor, success cases.
            # view_as_complex requires that the dim with size 2 have stride 1
            # in order for the view to function properly.
            check(view_as_complex, [make([B0, 2]).transpose(0, 1)], in_dims=1)
            check(vmap(view_as_complex, in_dims=1), [make([B0, B1, 2]).movedim(1, 2)])
            check(vmap(view_as_complex, in_dims=2), [make([B0, 3, B1, 2]).movedim(2, 3)])

            # Interesting case #3: Batch dim at end of tensor, failure cases
            expected = "Tensor must have a last dimension with stride 1"
            with self.assertRaisesRegex(RuntimeError, expected):
                vmap(view_as_complex, in_dims=1)(make([2, B0]))
            with self.assertRaisesRegex(RuntimeError, expected):
                vmap(vmap(view_as_complex, in_dims=1), in_dims=1)(make([2, B0, B1]))

            # Invalid input: no dimension of size 2
            expected = 'Input tensor must have one or more dimensions'
            with self.assertRaisesRegex(RuntimeError, expected):
                vmap(view_as_complex)(make([B0]))
            with self.assertRaisesRegex(RuntimeError, expected):
                vmap(vmap(view_as_complex))(make([B0, B1]))

            # Invalid input: Batch dim has size 2, but the logical last dim
            # does not have size 2
            expected = 'Tensor must have a last dimension of size 2'
            with self.assertRaisesRegex(RuntimeError, expected):
                vmap(view_as_complex, in_dims=1)(make([3, 2]))

    def test_is_complex(self):
        # Tensor.is_complex() queried inside vmap must reflect the dtype of
        # the input: the probe returns 1 for complex input and 0 otherwise.
        def complex_flag(x):
            return torch.tensor(1) if x.is_complex() else torch.tensor(0)

        complex_input = torch.randn(3, dtype=torch.cfloat)
        real_input = torch.randn(3)

        self.assertEqual(vmap(complex_flag)(complex_input), torch.tensor([1, 1, 1]))
        self.assertEqual(vmap(complex_flag)(real_input), torch.tensor([0, 0, 0]))

    def test_is_floating_point(self):
        # Tensor.is_floating_point() queried inside vmap must reflect the
        # dtype of the input: the probe returns 1 for float input and 0 for
        # integer input.
        def float_flag(x):
            return torch.tensor(1) if x.is_floating_point() else torch.tensor(0)

        float_input = torch.tensor([1., 2., 3.])
        long_input = torch.tensor([1, 2, 3])

        self.assertEqual(vmap(float_flag)(float_input), torch.tensor([1, 1, 1]))
        self.assertEqual(vmap(float_flag)(long_input), torch.tensor([0, 0, 0]))

    def test_is_contiguous(self):
        # Tensor.is_contiguous() queried inside vmap: per-example contiguity
        # for one and two batch dims, empty tensors, and the error raised for
        # non-default memory formats.
        def contiguity_flag(x):
            return torch.tensor(1.) if x.is_contiguous() else torch.tensor(0.)

        B0, B1 = 3, 5

        # Single batch dim
        inp = torch.randn(B0, 2, 7)
        self.assertEqual(vmap(contiguity_flag)(inp), torch.ones(B0))

        inp = torch.randn(2, B0, 7)
        self.assertEqual(vmap(contiguity_flag, in_dims=1)(inp), torch.zeros(B0))

        inp = torch.randn(2, B0, 7).movedim(1, 0)
        self.assertEqual(vmap(contiguity_flag)(inp), torch.zeros(B0))

        inp = torch.randn(2, 7, B0)
        self.assertEqual(vmap(contiguity_flag, in_dims=2)(inp), torch.zeros(B0))

        # Multiple batch dims
        doubly_batched = vmap(contiguity_flag)

        inp = torch.randn(B0, B1, 3)
        self.assertEqual(vmap(doubly_batched)(inp), torch.ones(B0, B1))

        inp = torch.randn(B1, B0, 3)
        self.assertEqual(vmap(doubly_batched, in_dims=1)(inp), torch.ones(B0, B1))

        inp = torch.randn(B1, B0, 3).movedim(0, 1)
        self.assertEqual(vmap(doubly_batched)(inp), torch.ones(B0, B1))

        inp = torch.randn(B0, 3, B1)
        self.assertEqual(vmap(vmap(contiguity_flag, in_dims=1))(inp), torch.zeros(B0, B1))

        # is_contiguous on empty tensor is True
        def require_contiguous(x):
            assert x.is_contiguous()
            return x

        vmap(require_contiguous)(torch.randn(B0, 0, 3))
        vmap(require_contiguous, in_dims=1)(torch.randn(0, B0, 3))
        vmap(require_contiguous)(torch.randn(B0, 0, 3).transpose(-1, -2))

        # is_contiguous with other memory formats
        def query_with_format(x, memory_format):
            # The result is discarded; only the query itself matters here.
            x.is_contiguous(memory_format=memory_format)
            return x

        expected = 'NYI: querying is_contiguous inside of vmap for memory_format'
        inp = torch.randn(B0, 2, 7, 3)
        with self.assertRaisesRegex(RuntimeError, expected):
            vmap(functools.partial(query_with_format, memory_format=torch.channels_last))(inp)
        with self.assertRaisesRegex(RuntimeError, expected):
            vmap(functools.partial(query_with_format, memory_format=torch.channels_last_3d))(inp)

    def test_unsqueeze(self):
        # torch.unsqueeze through self._vmap_view_test: inserting at dim 0 and
        # at the last dim (positive and negative index), with the batch dim in
        # front or in the middle, plus nested vmaps.
        unsqueeze = torch.unsqueeze
        check = self._vmap_view_test
        B0, B1 = 7, 11

        # dim 0, last dim (positive), last dim (negative)
        for dim in (0, 2, -1):
            check(unsqueeze, (torch.rand(B0, 2, 5), dim), in_dims=(0, None))
            check(unsqueeze, (torch.rand(2, B0, 5), dim), in_dims=(1, None))

        # Helpers for the nested vmaps, with the dim argument fixed.
        def at_front(t):
            return torch.unsqueeze(t, 0)

        def at_back(t):
            return torch.unsqueeze(t, -1)

        # bdims in canonical order
        check(vmap(at_front), (torch.rand(B0, B1, 2), ))
        check(vmap(at_back), (torch.rand(B0, B1, 2),))

        # wild bdims
        check(vmap(at_front), (torch.rand(B1, 2, B0),), in_dims=2)
        check(vmap(at_front, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
        check(vmap(at_back), (torch.rand(B1, 2, B0),), in_dims=2)
        check(vmap(at_back, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)

    def test_movedim(self):
        # torch.movedim through self._vmap_view_test for both the int and the
        # int-list overloads, at one, two, and three levels of vmap.
        movedim = torch.movedim
        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        only_tensor = (0, None, None)
        doubly_nested = vmap(movedim, in_dims=only_tensor)
        triply_nested = vmap(vmap(movedim, in_dims=(2, None, None)), in_dims=only_tensor)

        # movedim(tensor, int, int) variant
        check(movedim, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
        check(movedim, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
        check(doubly_nested, (torch.rand(B1, 2, B0, 5), 0, 1), in_dims=(2, None, None))
        check(triply_nested, (torch.rand(B1, 2, B0, 5, B2), 0, 1), in_dims=(2, None, None))

        # movedim(tensor, intlist, intlist) variant
        check(movedim, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
        check(movedim, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
        check(doubly_nested,
              (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), in_dims=(2, None, None))
        check(triply_nested,
              (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), in_dims=(2, None, None))

    def test_mm(self):
        # torch.mm under vmap: malformed operands must raise a "Shape
        # mismatch" RuntimeError, then each combination of batched /
        # unbatched operands is passed to self._vmap_test, including nested
        # vmaps.
        mm = torch.mm
        check = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        expected = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, expected):
            vmap(mm)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, expected):
            vmap(mm, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, expected):
            vmap(mm, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))

        # Only the left operand is vmapped.
        lhs, rhs = torch.rand(B0, 2, 5), torch.rand(5, 2)
        check(mm, (lhs, rhs), in_dims=(0, None))
        lhs, rhs = torch.rand(B1, B0, 2, 5), torch.rand(5, 2)
        check(vmap(mm, in_dims=(0, None)), (lhs, rhs), in_dims=(1, None))

        # Only the right operand is vmapped.
        lhs, rhs = torch.rand(2, 5), torch.rand(B0, 5, 2)
        check(mm, (lhs, rhs), in_dims=(None, 0))
        lhs, rhs = torch.rand(2, 5), torch.rand(B1, B0, 5, 2)
        check(vmap(mm, in_dims=(None, 0)), (lhs, rhs), in_dims=(None, 1))

        # Both operands are vmapped.
        lhs, rhs = torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)
        check(mm, (lhs, rhs))
        lhs, rhs = torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)
        check(vmap(mm), (lhs, rhs), in_dims=(1, 0))
        lhs, rhs = torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)
        check(vmap(mm, in_dims=(0, None)), (lhs, rhs), in_dims=(None, 0))

    def test_mv(self):
        # torch.mv under vmap: malformed operands must raise RuntimeError,
        # then each combination of batched / unbatched operands is passed to
        # self._vmap_test, including nested vmaps.
        mv = torch.mv
        check = self._vmap_test
        B0, B1 = 7, 11

        # Shape mismatch. The pattern is empty, so it matches any message and
        # only the exception type is effectively checked.
        any_message = ""
        with self.assertRaisesRegex(RuntimeError, any_message):
            vmap(mv)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, any_message):
            vmap(mv, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, any_message):
            vmap(mv, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))

        # Only the matrix (left) operand is vmapped.
        mat, vec = torch.rand(B0, 2, 5), torch.rand(5)
        check(mv, (mat, vec), in_dims=(0, None))
        mat, vec = torch.rand(B1, B0, 2, 5), torch.rand(5)
        check(vmap(mv, in_dims=(0, None)), (mat, vec), in_dims=(1, None))

        # Only the vector (right) operand is vmapped.
        mat, vec = torch.rand(2, 5), torch.rand(B0, 5)
        check(mv, (mat, vec), in_dims=(None, 0))
        mat, vec = torch.rand(2, 5), torch.rand(B1, B0, 5)
        check(vmap(mv, in_dims=(None, 0)), (mat, vec), in_dims=(None, 1))

        # Both operands are vmapped.
        mat, vec = torch.rand(B0, 2, 5), torch.rand(B0, 5)
        check(mv, (mat, vec))
        mat, vec = torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)
        check(vmap(mv), (mat, vec), in_dims=(1, 0))
        mat, vec = torch.rand(B1, 2, 5), torch.rand(B0, 5)
        check(vmap(mv, in_dims=(0, None)), (mat, vec), in_dims=(None, 0))

    def test_narrow(self):
        # torch.narrow(tensor, dim, start, length) through
        # self._vmap_view_test at one, two, and three levels of vmap. Only the
        # tensor argument is ever batched.
        narrow = torch.narrow
        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        rest = (None, None, None)

        # Single vmap.
        check(narrow, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0,) + rest)
        check(narrow, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1,) + rest)

        # Doubly nested vmap.
        doubly_nested = vmap(narrow, in_dims=(0,) + rest)
        check(doubly_nested,
              (torch.rand(B1, 2, B0, 5), 1, 0, 0), in_dims=(2,) + rest)

        # Triply nested vmap.
        triply_nested = vmap(vmap(narrow, in_dims=(2,) + rest), in_dims=(0,) + rest)
        check(triply_nested,
              (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2,) + rest)

    def test_new_empty(self):
        # The contents of an empty tensor are non-deterministic, so only the
        # shape of the vmapped output is asserted.
        new_empty = Tensor.new_empty
        B0, B1 = 7, 11

        # Single vmap, 2-D per-example output.
        out = vmap(lambda t: new_empty(t, [2, 3]))(torch.randn(B0))
        self.assertEqual(out.shape, [B0, 2, 3])

        # Single vmap, scalar per-example output.
        out = vmap(lambda t: new_empty(t, []))(torch.randn(B0))
        self.assertEqual(out.shape, [B0])

        # Doubly nested vmap.
        out = vmap(vmap(lambda t: new_empty(t, [2, 3])))(torch.randn(B0, B1))
        self.assertEqual(out.shape, [B0, B1, 2, 3])

    def test_new_empty_strided(self):
        # The contents of an empty tensor are non-deterministic, so only the
        # shape and strides of the vmapped output are asserted. The expected
        # batch strides are derived from the storage size of one unbatched
        # empty_strided tensor.
        B0, B1 = 7, 11

        def check_single(size, stride, B0):
            inp = torch.randn(B0)
            out = vmap(lambda t: t.new_empty_strided(size, stride))(inp)
            per_example = torch.empty_strided(size, stride).storage().size()
            self.assertEqual(out.shape, [B0] + size)
            self.assertEqual(out.stride(), [per_example] + stride)

        def check_double(size, stride, B0, B1):
            # Batch dims in canonical order.
            inp = torch.randn(B0, B1)
            out = vmap(vmap(lambda t: t.new_empty_strided(size, stride)))(inp)
            per_example = torch.empty_strided(size, stride).storage().size()
            self.assertEqual(out.shape, [B0, B1] + size)
            self.assertEqual(out.stride(), [B1 * per_example, per_example] + stride)

            # Outer vmap over dim 1 of the input.
            inp = torch.randn(B1, B0)
            out = vmap(vmap(lambda t: t.new_empty_strided(size, stride)), in_dims=1)(inp)
            per_example = inp.new_empty_strided(size, stride).storage().size()
            self.assertEqual(out.shape, [B0, B1] + size)
            self.assertEqual(out.stride(), [B1 * per_example, per_example] + stride)

        def check_both(size, stride):
            check_single(size, stride, B0)
            check_double(size, stride, B0, B1)

        # contiguous case
        check_both([2, 3, 5], [3 * 5, 5, 1])

        # expanded
        check_both([2, 3, 5], [0, 5, 1])

        # some of these cases are pretty strange, just verifying that if
        # empty_strided allows them then BatchedTensor.new_empty_strided
        # can as well
        for shape, strides in itertools.product(
                [[2, 3, 4], [0, 2, 0]],
                [[12, 4, 1], [2, 4, 6], [0, 0, 0]]):
            check_both(shape, strides)

    def test_new_zeros(self):
        # Tensor.new_zeros under vmap with varargs sizes, an empty size list,
        # and a nested vmap. Gradient propagation is not checked.
        new_zeros = Tensor.new_zeros
        check = functools.partial(self._vmap_test, check_propagates_grad=False)
        B0, B1 = 7, 11

        check(lambda t: new_zeros(t, 2, 3), (torch.rand(B0),))
        check(lambda t: new_zeros(t, []), (torch.rand(B0),))
        check(vmap(lambda t: new_zeros(t, 3, 5)), (torch.rand(B0, B1),))

    def test_select(self):
        # torch.select(tensor, dim, index) through self._vmap_view_test at
        # one, two, and three levels of vmap.
        select = torch.select
        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        # Single vmap; dim and index are passed through unbatched.
        check(select, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
        check(select, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))

        # Nested vmaps with dim and index baked into a lambda.
        doubly_nested = vmap(lambda t: select(t, 1, 1))
        check(doubly_nested, (torch.rand(B1, 2, B0, 5),), in_dims=2)
        triply_nested = vmap(vmap(lambda t: select(t, 1, 1), in_dims=1))
        check(triply_nested, (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)

    def test_roll_no_dims(self):
        # torch.roll called with only a shift (no dims argument) under one,
        # two, and three levels of vmap.
        roll = torch.roll
        check = self._vmap_test
        B0, B1, B2 = 7, 11, 13

        # Single vmap; the shift is passed through unbatched.
        check(roll, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
        check(roll, (torch.rand(2, B0, 5), 3), in_dims=(1, None))

        # Nested vmaps with the shift baked into a lambda.
        doubly_nested = vmap(lambda t: roll(t, 3))
        check(doubly_nested, (torch.rand(B1, 2, B0, 5),), in_dims=2)
        triply_nested = vmap(vmap(lambda t: roll(t, 3), in_dims=1))
        check(triply_nested, (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)

    def test_stack(self):
        # torch.stack under vmap with operands batched at different positions,
        # with one operand unbatched, and under nested vmaps.
        check = self._vmap_test
        B0, B1 = 5, 7

        # vmap can't accept a list of tensors as an argument, so wrap stack in
        # a variadic function that collects its positional tensors.
        def stack_along(dim):
            return lambda *tensors: torch.stack(tensors, dim=dim)

        check(stack_along(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
        check(stack_along(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
        check(stack_along(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
        check(stack_along(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))

        # Nested vmaps.
        inner = vmap(stack_along(0), in_dims=(0, None))
        check(inner, (torch.rand(B1, 2), torch.rand(B0, 2)), in_dims=(None, 0))
        inner = vmap(stack_along(0), in_dims=(0, 0))
        check(inner, (torch.rand(B1, 2), torch.rand(B0, B1, 2)), in_dims=(None, 0))

    def test_slice(self):
        # Basic slicing via __getitem__ through self._vmap_view_test at one,
        # two, and three levels of vmap.
        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        # Single vmap.
        check(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
        check(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)

        # Doubly nested vmap.
        doubly_nested = vmap(lambda t: t[:, 0:1], in_dims=2)
        check(doubly_nested, (torch.rand(3, 5, B0, B1),), in_dims=2)

        # Triply nested vmap.
        triply_nested = vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2)
        check(triply_nested, (torch.rand(3, 5, B0, B1, B2),), in_dims=2)

    def test_squeeze(self):
        # torch.squeeze (with and without an explicit dim) through
        # self._vmap_view_test, followed by an out-of-range dim that must
        # raise under vmap.
        def run_cases(squeeze_fn, min_ndim=1):
            check = self._vmap_view_test
            B0, B1 = 1, 11
            # The low-rank inputs cannot be used with an operator that
            # requires more than 1 dimension after batching.
            if not min_ndim > 1:
                check(squeeze_fn, (torch.rand(B0),))
                check(squeeze_fn, (torch.rand(B1),))
                check(vmap(squeeze_fn), (torch.rand(B0, B1, 1),))
                check(vmap(squeeze_fn), (torch.rand(B1, 1, B0),), in_dims=2)
            check(squeeze_fn, (torch.rand(B0, 3, 5),))
            check(squeeze_fn, (torch.rand(1, B0, 5),), in_dims=1)
            check(squeeze_fn, (torch.rand(B0, 0, 1, 5, 1),))
            check(squeeze_fn, (torch.rand(B0, 1, 1, 1, 1),))
            check(vmap(squeeze_fn), (torch.rand(B0, B1, 1, 3, 4),))
            check(vmap(squeeze_fn), (torch.rand(B1, 1, B0, 4, 5),), in_dims=2)

        run_cases(torch.squeeze)
        run_cases(lambda t: torch.squeeze(t, dim=0), min_ndim=1)
        run_cases(lambda t: torch.squeeze(t, dim=1), min_ndim=2)
        run_cases(lambda t: torch.squeeze(t, dim=-1), min_ndim=2)
        run_cases(lambda t: torch.squeeze(t, dim=-2), min_ndim=3)

        # Capture the eager-mode error text for an out-of-range dim, if any.
        eager_error = ""
        try:
            torch.squeeze(torch.rand(10), dim=1)
        except IndexError as exc:
            eager_error = str(exc)

        # NOTE: msg= on assertRaises is only the message reported when the
        # assertion fails; it is not matched against the raised exception.
        with self.assertRaises(RuntimeError, msg=eager_error):
            vmap(lambda t: torch.squeeze(t, dim=1))(torch.rand(10))

    def _test_mean_sum_dim(self, op):
        # Shared driver for reductions called as op(tensor, dim): passes them
        # to self._vmap_test under single and doubly nested vmaps with
        # positive and negative reduction dims.
        check = self._vmap_test
        B0, B1 = 5, 7

        def reduce_over(dim):
            # Bind the reduction dim so the resulting callable is unary.
            return lambda t: op(t, dim)

        # Single vmap, various in_dims / out_dims
        check(reduce_over(0), [torch.randn([B0])])
        check(reduce_over(-1), [torch.randn([B0])])
        check(reduce_over(0), [torch.randn([B0, 3])])
        check(reduce_over(-1), [torch.randn([2, 5, B0, 3])], in_dims=2)
        check(reduce_over(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)

        # Doubly nested vmap
        check(vmap(reduce_over(0)), [torch.randn([B0, B1])])
        check(vmap(reduce_over(-1)), [torch.randn([B0, B1])])
        check(vmap(reduce_over(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
        check(vmap(reduce_over(2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])],
              in_dims=2, out_dims=2)

    def test_sum_dim(self):
        # torch.sum with an explicit dim, via the shared dim-reduction driver.
        reduction = torch.sum
        self._test_mean_sum_dim(reduction)

    def test_mean_dim(self):
        # torch.mean with an explicit dim, via the shared dim-reduction driver.
        reduction = torch.mean
        self._test_mean_sum_dim(reduction)

    def test_argmax_dim(self):
        # torch.argmax with and without a dim: every (loop, batched) output
        # pair produced by get_fallback_and_vmap_exhaustive must match.
        def check(fn, args):
            pairs = get_fallback_and_vmap_exhaustive(fn, args, {})
            for expected, actual in pairs:
                self.assertEqual(expected, actual)

        B0 = 5
        # Full reduction.
        check(lambda t: torch.argmax(t), [torch.randn(B0)])
        check(lambda t: torch.argmax(t), [torch.randn(B0, 2, 3)])
        # Reduction over an explicit dim.
        check(lambda t: torch.argmax(t, 0), [torch.randn(B0, 2, 3)])
        check(lambda t: torch.argmax(t, -1), [torch.randn(B0, 2, 3)])
        check(lambda t: torch.argmax(t, 2), [torch.randn(B0, 2, 3)])

    def _test_sum_mean(self, op):
        # Shared driver for reductions called as op(tensor) with no dim:
        # passes them to self._vmap_test under single and doubly nested vmaps.
        check = self._vmap_test
        B0, B1 = 5, 7

        # Single vmap, various in_dims. The last configuration is run twice,
        # each time on a fresh random input.
        check(op, [torch.randn([B0])])
        check(op, [torch.randn([B0, 3])])
        check(op, [torch.randn([2, 5, B0, 3])], in_dims=2)
        check(op, [torch.randn([2, 5, B0, 3])], in_dims=2)

        # Doubly nested vmap
        nested = vmap(op)
        check(nested, [torch.randn([B0, B1])])
        check(nested, [torch.randn([B1, 2, 5, B0, 3])])
        check(nested, [torch.randn([2, 5, B0, B1, 3])], in_dims=2)

    def test_sum(self):
        # torch.sum without a dim, via the shared full-reduction driver.
        reduction = torch.sum
        self._test_sum_mean(reduction)

    def test_mean(self):
        # torch.mean without a dim, via the shared full-reduction driver.
        reduction = torch.mean
        self._test_sum_mean(reduction)

    def test_repeat(self):
        # Tensor.repeat with a (2, 3) repeat spec under vmap, with the batch
        # dim at position 0 and at position 1.
        check = self._vmap_test
        B0 = 7

        def repeat_2x3(t):
            return Tensor.repeat(t, (2, 3))

        check(repeat_2x3, (torch.rand(B0, 1, 1),))
        check(repeat_2x3, (torch.rand(1, B0, 1),), in_dims=1)

    def test_slogdet(self):
        # torch.linalg.slogdet under vmap for several matrix shapes and batch
        # dim positions. Gradient propagation is not checked.
        check = functools.partial(self._vmap_test, check_propagates_grad=False)
        slogdet = torch.linalg.slogdet
        B0 = 7

        # Batch dim in front.
        check(slogdet, (torch.rand(B0, 1, 1),))
        check(slogdet, (torch.rand(B0, 2, 2),))
        check(slogdet, (torch.rand(B0, 3, 2, 2),))
        # Batch dim last.
        check(slogdet, (torch.rand(3, 2, 2, B0),), in_dims=3)

    def test_reshape(self):
        # torch.reshape under vmap. check_view is passed as True for the cases
        # with leading batch dims and False for the cases whose batch dims sit
        # elsewhere.
        check = self._vmap_test
        reshape = torch.reshape
        B0, B1, B2 = 7, 11, 13

        # Single vmap; the target shape is passed through unbatched.
        check(reshape, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
        check(reshape, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False)

        # Nested vmaps flattening each example.
        def flatten(t):
            return t.reshape([-1])

        check(vmap(flatten), (torch.rand(B0, B1, 2, 5),), check_view=True)
        check(vmap(vmap(flatten, in_dims=2), in_dims=1),
              (torch.rand(3, B1, 2, B2, 5, B0),), in_dims=5, check_view=False)

    def test_reshape_as(self):
        # Tensor.reshape_as under single and nested vmap, with either or both
        # operands batched.
        check = self._vmap_test
        B0, B1, B2 = 7, 11, 13
        reshape_as = torch.Tensor.reshape_as

        # (self shape, other shape, extra kwargs); all run with check_view=True.
        view_cases = [
            ((B0, 2 * 5), (B0, 2, 5), {}),
            ((2 * 5,), (B0, 2, 5), {'in_dims': (None, 0)}),
            ((B0, 2 * 5), (2, 5), {'in_dims': (0, None)}),
        ]
        for self_shape, other_shape, extra in view_cases:
            check(reshape_as, (torch.rand(*self_shape), torch.rand(*other_shape)),
                  check_view=True, **extra)

        check(reshape_as, (torch.rand(2, B0, 5), torch.rand(1, 1, 10)),
              in_dims=(1, None), check_view=False)

        check(vmap(reshape_as), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)),
              check_view=True)

        triple_nested = vmap(vmap(reshape_as, in_dims=(2, None)), in_dims=(1, None))
        check(triple_nested,
              (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
              in_dims=(5, 0), check_view=False)

    def test_result_type(self):
        # torch.result_type returns a dtype, not a tensor, so it is wrapped to
        # return a 0-dim tensor of ones carrying that dtype before being handed
        # to the vmap checker.
        def scalar_tensor_with_dtype(fn):
            def wrapped(*args, **kwargs):
                return torch.ones([], dtype=fn(*args, **kwargs))
            return wrapped

        check = functools.partial(self._vmap_test, check_propagates_grad=False)
        op = scalar_tensor_with_dtype(torch.result_type)

        B0 = 2

        # The same battery is run for 1-D and 2-D per-batch inputs.
        for shape in ([B0], [B0, 2]):
            # tensor / tensor
            check(op, (torch.randn(*shape), torch.randn(*shape, dtype=torch.float64)))
            check(op, (torch.randn(*shape), torch.randint(10, shape, dtype=torch.int64)))

            # tensor / Python scalar
            check(lambda x: op(x, 1), (torch.randn(*shape),))
            check(lambda x: op(x, 1.6), (torch.randn(*shape),))

            # tensor / unbatched 0-dim tensor
            check(lambda x: op(x, torch.tensor(1)), (torch.randn(*shape),))
            check(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
                  (torch.randn(*shape),))

        # Operands whose per-example ranks differ.
        check(op, (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)))
        check(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)))

    def test_tensor_split(self):
        # torch.tensor_split through the view checker, once with integer
        # `indices_or_sections` and once with index lists.
        check = self._vmap_view_test
        split = torch.tensor_split
        B0, B1, B2 = 7, 11, 13
        unbatched = (None, None)

        # One row per overload; the four entries feed the four calls below in order.
        spec_rows = (
            # torch.tensor_split(self, indices_or_sections: int, dim)
            (5, 150, 256, 4),
            # torch.tensor_split(self, indices_or_sections: List[int], dim)
            ([50, 100, 378, 890],
             [50, 100, 212, 345, 0, 378, 890],
             [50, 100, 212, 345, 0, 378, 890],
             [4, 8, 9, 34, 29]),
        )
        for spec_a, spec_b, spec_c, spec_d in spec_rows:
            check(split, (torch.rand(B0, 2, 1024), spec_a, -1), in_dims=(0, *unbatched))
            check(split, (torch.rand(2, B0, 1024), spec_b, 1), in_dims=(1, *unbatched))
            check(vmap(split, in_dims=(0, *unbatched)),
                  (torch.rand(B1, 1023, B0, 5), spec_c, 0),
                  in_dims=(2, *unbatched))
            check(vmap(vmap(lambda t: split(t, spec_d, 1), in_dims=2)),
                  (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    def test_split(self):
        # torch.split through the view checker, once with an integer split
        # size and once with explicit lists of section sizes.
        check = self._vmap_view_test
        split = torch.split
        B0, B1, B2 = 7, 11, 13
        unbatched = (None, None)

        # One row per overload; the four entries feed the four calls below in order.
        spec_rows = (
            # torch.split(self, split_size: int, dim)
            (101, 130, 256, 4),
            # torch.split(self, split_size: List[int], dim)
            ([1, 1020, 3],
             [100] * 10 + [24],
             [256] * 3 + [255],
             [4] * 8 + [8] * 4),
        )
        for spec_a, spec_b, spec_c, spec_d in spec_rows:
            check(split, (torch.rand(B0, 2, 1024), spec_a, -1), in_dims=(0, *unbatched))
            check(split, (torch.rand(2, B0, 1024), spec_b, 1), in_dims=(1, *unbatched))
            check(vmap(split, in_dims=(0, *unbatched)),
                  (torch.rand(B1, 1023, B0, 5), spec_c, 0),
                  in_dims=(2, *unbatched))
            check(vmap(vmap(lambda t: split(t, spec_d, 1), in_dims=2)),
                  (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    def test_trace(self):
        # torch.trace under one, two and three levels of vmap.
        trace = torch.trace
        check = self._vmap_test
        B0, B1, B2 = 7, 11, 13

        # (callable, input shape, extra keyword arguments for the checker)
        cases = [
            (trace, (B0, 2, 5), {}),
            (trace, (2, B0, 5), {'in_dims': 1}),
            (vmap(trace), (B1, 2, B0, 5), {'in_dims': 2}),
            (vmap(vmap(trace, in_dims=2)), (B1, 2, B0, 5, B2), {'in_dims': 2}),
        ]
        for fn, shape, extra in cases:
            check(fn, (torch.rand(*shape),), **extra)

    def test_transpose(self):
        # torch.transpose through the view checker, plus the degenerate case
        # where each example is a 0-dim tensor.
        transpose = torch.transpose
        check = self._vmap_view_test

        B0, B1, B2 = 7, 11, 13

        def swap01(x):
            return transpose(x, 0, 1)

        check(swap01, (torch.rand(B0, 2, 5),))
        check(lambda x: transpose(x, -1, -2), (torch.rand(B0, 2, 5),))
        check(lambda x: transpose(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),))
        check(lambda x: transpose(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1)
        check(vmap(swap01), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        check(vmap(vmap(swap01, in_dims=2)),
              (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)

        # Special case: scalar tensor. The vmapped call is expected to hand
        # back the very same tensor object that was passed in.
        for dim1 in [0, -1]:
            for dim2 in [0, -1]:
                x = torch.rand(B0)
                result = vmap(lambda t: transpose(t, dim1, dim2))(x)
                self.assertTrue(result is x)

    def test_t(self):
        # torch.t through the view checker under one, two and three levels of vmap.
        t_op = torch.t
        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        # (callable, input shape, extra keyword arguments for the checker)
        cases = [
            (t_op, (B0, 2, 5), {}),
            (t_op, (2, B0, 5), {'in_dims': 1}),
            (vmap(t_op), (B1, 2, B0, 5), {'in_dims': 2}),
            (vmap(vmap(t_op, in_dims=2)), (B1, 2, B0, 5, B2), {'in_dims': 2}),
        ]
        for fn, shape, extra in cases:
            check(fn, (torch.rand(*shape),), **extra)

    def test_T_numpy(self):
        # The `.T` attribute through the view checker.
        def numpy_T(t):
            return t.T

        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        # (callable, input shape, extra keyword arguments for the checker)
        cases = [
            (numpy_T, (B0, 2, 3, 5), {}),
            (numpy_T, (B0,), {}),
            (numpy_T, (2, B0, 3, 5), {'in_dims': 1}),
            (vmap(numpy_T), (B1, 2, B0, 5), {'in_dims': 2}),
            (vmap(numpy_T), (B1, 2, B0, 3, 5), {'in_dims': 2}),
            (vmap(vmap(numpy_T, in_dims=2)), (B1, 2, B0, 3, B2, 5), {'in_dims': 2}),
        ]
        for fn, shape, extra in cases:
            check(fn, (torch.rand(*shape),), **extra)

    def test_to(self):
        # Tensor.to (device string, dtype, other tensor) and the dtype-casting
        # shorthand methods under vmap.
        check = self._vmap_test
        B0, B1 = 7, 11

        def to_other(t, o):
            return t.to(o)

        def to_double(t):
            return t.to(torch.double)

        check(lambda t: t.to('cpu'), (torch.rand(B0),))
        check(to_double, (torch.rand(B0),))
        check(to_other, (torch.rand(B0), torch.randn(B0, dtype=torch.float64)))
        check(to_other,
              (torch.rand(B0), torch.randn(B0, dtype=torch.float64)),
              in_dims=(0, None))
        check(vmap(to_double), (torch.rand(B0, B1, 3),))

        # also test some casting methods
        check(lambda t: t.double(), (torch.rand(B0),))
        check(lambda t: t.float(), (torch.rand(B0),))
        # Integer casts are run with gradient-propagation checking turned off.
        check(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False)
        check(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False)

    def test_unfold(self):
        # Tensor.unfold through the view checker; only the tensor argument is
        # batched, the three integer arguments are passed with in_dim None.
        unfold = torch.Tensor.unfold
        check = self._vmap_view_test
        B0, B1, B2 = 3, 2, 5
        ints_unbatched = (None, None, None)

        check(unfold, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, *ints_unbatched))
        check(unfold, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, *ints_unbatched))

        double_nested = vmap(unfold, in_dims=(0, *ints_unbatched))
        check(double_nested,
              (torch.rand(B1, 7, B0, 11), 1, 5, 1), in_dims=(2, *ints_unbatched))

        triple_nested = vmap(vmap(unfold, in_dims=(2, *ints_unbatched)),
                             in_dims=(0, *ints_unbatched))
        check(triple_nested,
              (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), in_dims=(2, *ints_unbatched))

    def test_unbind(self):
        # torch.unbind through the view checker, including an input with a
        # zero-sized trailing dimension.
        check = self._vmap_view_test
        unbind = torch.unbind
        B0, B1, B2 = 7, 11, 13

        def unbind_dim1(t):
            return unbind(t, dim=1)

        check(unbind, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
        check(unbind, (torch.rand(B0, 2, 0),))
        check(unbind, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
        check(vmap(unbind, in_dims=(0, None)), (torch.rand(B1, 1023, B0, 5), 1),
              in_dims=(2, None))
        check(vmap(vmap(unbind_dim1, in_dims=2)),
              (torch.rand(B1, 2, B0, 32, B2),), in_dims=2)

    def test_view(self):
        # Tensor.view through the view checker, preceded by a case that must
        # be rejected with a RuntimeError.
        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        view = torch.Tensor.view

        # We should error out if the view would produce an incorrect result
        bad_view = vmap(view, in_dims=(1, None))
        with self.assertRaises(RuntimeError):
            bad_view(torch.rand(2, B0, 5), [10])

        check(view, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
        check(view, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
        check(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
        # NB: the triple-nested case flattens with reshape, not view.
        check(vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
              (torch.rand(B2, B0, B1, 3, 2, 5),), in_dims=1)

    def test_view_as(self):
        # Tensor.view_as through the view checker, preceded by a case that
        # must be rejected with a RuntimeError.
        check = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        view_as = torch.Tensor.view_as

        # We should error out if the view would produce an incorrect result
        bad_view_as = vmap(view_as, in_dims=(1, 0))
        with self.assertRaises(RuntimeError):
            bad_view_as(torch.rand(2, B0, 5), torch.rand(B0, 10))

        # (self shape, other shape, extra keyword arguments for the checker)
        single_level = [
            ((B0, 2 * 5), (B0, 2, 5), {}),
            ((2 * 5,), (B0, 2, 5), {'in_dims': (None, 0)}),
            ((B0, 2 * 5), (2, 5), {'in_dims': (0, None)}),
            ((B0, 4, 5), (2, 1, 1, 10), {'in_dims': (0, None)}),
        ]
        for self_shape, other_shape, extra in single_level:
            check(view_as, (torch.rand(*self_shape), torch.rand(*other_shape)), **extra)

        check(vmap(view_as), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)))

        triple_nested = vmap(vmap(view_as, in_dims=(0, None)), in_dims=(0, None))
        check(triple_nested,
              (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)),
              in_dims=(2, 0))

    def test_conv2d(self):
        # For 1-D, 2-D and 3-D convolutions: build an nn module to obtain
        # weight/bias, then compare every output pair yielded by
        # get_fallback_and_vmap_exhaustive for the functional conv, with and
        # without bias, for default and non-default conv options.
        conv_setups = [
            (torch.nn.Conv1d, torch.conv1d, [2, 4, 15]),
            (torch.nn.Conv2d, torch.conv2d, [2, 4, 15, 20]),
            (torch.nn.Conv3d, torch.conv3d, [2, 4, 15, 20, 25]),
            # (torch.nn.ConvTranspose2d, torch.conv_transpose2d, [2, 4, 15, 20])
        ]
        for conv_mod, conv_fn, inp_shape in conv_setups:
            # The same options go to the module constructor and to the
            # functional call as keyword arguments.
            option_sets = ({}, dict(groups=2, stride=3, padding=1, dilation=2))
            for conv_options in option_sets:
                module = conv_mod(4, 8, kernel_size=3, **conv_options)
                for bias in (module.bias, None):
                    arg_values = [torch.randn(inp_shape), module.weight, bias]
                    pairs = get_fallback_and_vmap_exhaustive(conv_fn, arg_values, conv_options)
                    for loop_out, batched_out in pairs:
                        self.assertEqual(loop_out, batched_out)

    def test_one_hot(self):
        # F.one_hot on a 0-dim and a 3-D integer tensor; each sample is
        # (indices, num_classes) passed positionally.
        scalar_sample = (torch.randint(0, 3, []), 3)
        tensor_sample = (torch.randint(0, 3, [2, 3, 4]), 4)

        for sample in (scalar_sample, tensor_sample):
            pairs = get_fallback_and_vmap_exhaustive(F.one_hot, sample, {})
            for loop_out, batched_out in pairs:
                self.assertEqual(loop_out, batched_out)

    def test_conj_bit(self):
        # Inside vmap, the input must not report is_conj() and the result of
        # .conj() must; the batched result has to equal the plain conj().
        values = torch.tensor([1 + 1j, 2 + 1j])

        def conj_with_checks(t):
            assert not t.is_conj()
            conjugated = t.conj()
            assert conjugated.is_conj()
            return conjugated

        batched = vmap(conj_with_checks)(values)
        self.assertEqual(batched, values.conj())

    def test_mode_key(self):
        # With the same seed, a doubly nested vmap using randomness='different'
        # that adds a 0-dim randn per example must match adding a single randn
        # of the given (broadcastable) shape to the whole input.
        def vmap_f(x):
            return x + torch.randn(())

        def naive_f(x, shape):
            return x + torch.randn(shape)

        # (input shape, shape of the noise drawn by the unbatched reference)
        for input_shape, noise_shape in (((2, 3), (2, 3)), ((2, 3, 4), (2, 3, 1))):
            torch.manual_seed(0)
            nested = vmap(vmap(vmap_f, randomness='different'), randomness='different')
            out1 = nested(torch.ones(*input_shape))

            torch.manual_seed(0)
            out2 = naive_f(torch.ones(*input_shape), noise_shape)
            self.assertEqual(out1, out2)

        # Outside of vmap, randn(()) must still be 0-dimensional.
        self.assertTrue(torch.randn(()).dim() == 0)

    @parametrize('in_dim', [0, 1, 2])
    @parametrize('out_dim', [0, 1, 2])
    @parametrize('randomness', ['error', 'same'])
    def test_chunk_vmap(self, in_dim, out_dim, randomness):
        # chunk_vmap must give the same output as a plain vmap for every chunk
        # count, when started from the same RNG state.
        x = torch.randn(4, 5, 6)
        uses_random = randomness != "error"

        def f(x):
            y = x.sin()
            # Random ops are only exercised when the randomness mode permits them.
            if uses_random:
                y = y + torch.rand_like(x)
            return y

        vmap_options = dict(in_dims=in_dim, out_dims=out_dim, randomness=randomness)

        saved_rng = torch.get_rng_state()
        expected = vmap(f, **vmap_options)(x)

        for chunks in [1, 2, 3, 4, 7, 10, 16]:
            # Rewind the generator so the chunked run starts from the same state.
            torch.set_rng_state(saved_rng)
            chunked = chunk_vmap(f, chunks=chunks, **vmap_options)
            output = chunked(x)
            self.assertEqual(output, expected)


# Instantiate the @parametrize-decorated tests on TestVmapOperators
# (e.g. test_chunk_vmap).
instantiate_parametrized_tests(TestVmapOperators)


def construct_v(output, batch_size, contig=False):
    """Return a random tensor of shape ``(batch_size, *output.shape)``.

    The result has the dtype and device of ``output``. With ``contig=True`` it
    is allocated directly in that shape; otherwise it is allocated with the
    batch dimension last and that dimension is then moved to the front.
    """
    tensor_opts = dict(dtype=output.dtype, device=output.device)
    if not contig:
        batch_last = torch.randn(*output.shape, batch_size, **tensor_opts)
        return batch_last.movedim(-1, 0)
    return torch.randn(batch_size, *output.shape, **tensor_opts)

def as_tuple(x):
    """Coerce ``x`` to a tuple.

    Tuples are returned unchanged, lists are converted element-wise, and any
    other object is wrapped in a one-element tuple.
    """
    if isinstance(x, tuple):
        return x
    return tuple(x) if isinstance(x, list) else (x,)


def differentiable(args):
    """Return, as a tuple, the tensors in ``args`` that have ``requires_grad`` set.

    ``args`` is first normalized with ``as_tuple``; non-tensor entries are dropped.
    """
    def wants_grad(candidate):
        return isinstance(candidate, torch.Tensor) and candidate.requires_grad

    return tuple(filter(wants_grad, as_tuple(args)))


def _get_rand_no_zeros(*args, **kwargs):
    requires_grad = kwargs.get('requires_grad', False)
    kwargs_without_requires_grad = kwargs.copy()
    kwargs_without_requires_grad['requires_grad'] = False
    result = torch.rand(*args, **kwargs_without_requires_grad)
    return result.clamp_min_(0.1).requires_grad_(requires_grad)


class TestVmapBatchedGradient(Namespace.TestVmapBase):
    # Tests that vmap over torch.autograd.grad calls (batched cotangent
    # vectors) gives the results expected by the shared _vmap_test checker.
    # Every test receives a `device` argument and creates its tensors there.

    def _vmap_test(self, *args, **kwargs):
        # Forward to the module-level _vmap_test helper with this test case.
        return _vmap_test(self, *args, **kwargs)

    # Tests batched gradient computation of outputs = op(*args, **kwargs)
    # by comparing it to a sequential map+stack fallback.
    #
    # output_process_fn: a function that maps the outputs to the part
    #       that should be differentiated.
    # batch_size: the batch dim size for the batched grad
    def _batched_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3):
        if kwargs is None:
            kwargs = {}
        outputs = op(*args, **kwargs)
        outputs = differentiable(output_process_fn(outputs))
        inputs = differentiable(args)

        # Does not depend on the layout of the vectors, so define it once.
        # retain_graph=True because the same graph is differentiated repeatedly.
        def vector_jacobian_product(*vectors):
            return torch.autograd.grad(outputs, inputs, vectors,
                                       retain_graph=True)

        # Exercise both a contiguous and a non-contiguous batch of vectors.
        for contig in [True, False]:
            batched_vectors = tuple(construct_v(out, batch_size, contig)
                                    for out in outputs)
            self._vmap_test(vector_jacobian_product, batched_vectors,
                            check_propagates_grad=False)

    # Tests batched second grad computation of outputs = op(*args, **kwargs).
    # by comparing it to a sequential map+stack fallback.
    #
    # output_process_fn: a function that maps the outputs to the part
    #       that should be differentiated.
    # batch_size: the batch dim size for the batched grad
    #
    # NB: we only test computing batched gradients in the second gradient
    # computation. One specific use case that does this is computing the hessian
    # matrix of a scalar-valued function; this is useful in Bayesian Logistic
    # Regression.
    # It might be useful to have a test that computes batched first gradients and
    # then uses those to compute batched second gradients in the future.
    def _batched_grad_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3):
        if kwargs is None:
            kwargs = {}
        outputs = op(*args, **kwargs)
        outputs = differentiable(output_process_fn(outputs))
        inputs = differentiable(args)
        ones = tuple(torch.ones_like(out) for out in outputs)
        # Same thing as summing together all of the outputs and calling .backward()
        first_grads = torch.autograd.grad(outputs, inputs, ones,
                                          create_graph=True)
        first_grads = differentiable(first_grads)
        self.assertNotEqual(
            len(first_grads), 0, "None of the first grads depend on the input!")

        # Does not depend on the layout of the vectors, so define it once.
        def vector_hessian_product(*vectors):
            second_grads = torch.autograd.grad(first_grads, inputs, vectors,
                                               retain_graph=True, allow_unused=True)
            # allow_unused=True yields None for inputs a first grad does not
            # depend on; drop those entries.
            second_grads = tuple(g for g in second_grads if g is not None)
            assert len(second_grads) > 0
            return second_grads

        # Exercise both a contiguous and a non-contiguous batch of vectors.
        for contig in [True, False]:
            batched_vectors = tuple(construct_v(g, batch_size, contig)
                                    for g in first_grads)
            self._vmap_test(vector_hessian_product, batched_vectors,
                            check_propagates_grad=False)

    def _test_arithmetic(self, op, device, test_grad_grad=True):
        # Shared driver for binary arithmetic ops: tensor/tensor, scalar/tensor
        # and tensor/scalar first-order checks, plus an optional second-order
        # check. `y` is kept away from zero (see _get_rand_no_zeros).
        x = torch.randn(2, 3, requires_grad=True, device=device)
        y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        scalar = 3.14
        self._batched_grad_test(op, (x, y))
        self._batched_grad_test(op, (scalar, y))
        self._batched_grad_test(op, (x, scalar))

        if test_grad_grad:
            self._batched_grad_grad_test(op, (x, y))

    def test_add(self, device):
        self._test_arithmetic(torch.add, device, test_grad_grad=False)
        self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False)

    def test_sub(self, device):
        self._test_arithmetic(torch.sub, device, test_grad_grad=False)
        self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False)

    def test_mul(self, device):
        self._test_arithmetic(torch.mul, device)
        self._test_arithmetic(lambda x, y: x * y, device)

    def test_div(self, device):
        self._test_arithmetic(torch.div, device)
        self._test_arithmetic(lambda x, y: x / y, device)

    def test_binary_cross_entropy(self, device):
        # The input is passed through sigmoid first, so it lies in (0, 1).
        x = F.sigmoid(torch.randn(3, 2, device=device, requires_grad=True))
        target = torch.rand(3, 2, device=device)

        op = functools.partial(F.binary_cross_entropy, target=target)

        self._batched_grad_test(op, (x,), {})
        self._batched_grad_grad_test(op, (x,), {})

    def test_log_softmax(self, device):
        op = functools.partial(torch.log_softmax, dim=-1)
        x = torch.randn(3, 2, device=device, requires_grad=True)

        self._batched_grad_test(op, (x,), {})
        self._batched_grad_grad_test(op, (x,), {})

    def test_expand(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)

        def op(x):
            return x.expand(5, 5, 2, 3)
        self._batched_grad_test(op, (x,))

    @allowVmapFallbackUsage
    def test_index(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        index = torch.tensor([[0, 0], [1, 1]], device=device)

        def op(x):
            y = x * x
            return y[index]

        self._batched_grad_test(op, (x,))
        self._batched_grad_grad_test(op, (x,))

    def test_lgamma(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(Tensor.lgamma, (x,))
        self._batched_grad_grad_test(Tensor.lgamma, (x,))

    def test_log(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(torch.log, (x,))
        self._batched_grad_grad_test(torch.log, (x,))

    def test_logsumexp(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)

        def op(x):
            return torch.logsumexp(x, -1)

        self._batched_grad_test(op, (x,))
        self._batched_grad_grad_test(op, (x,))

    def test_log1p(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(torch.log1p, (x,))
        self._batched_grad_grad_test(torch.log1p, (x,))

    @allowVmapFallbackUsage
    def test_max(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(torch.max, (x,))

    @allowVmapFallbackUsage
    def test_median(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(torch.median, (x,))

    @allowVmapFallbackUsage
    def test_min(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(torch.min, (x,))

    def test_permute(self, device):
        x = torch.randn(2, 3, 5, requires_grad=True, device=device)

        def op(x):
            return x.permute(2, 0, 1)

        self._batched_grad_test(op, (x,))

    def test_reshape(self, device):
        x = torch.randn(2, 3, 5, requires_grad=True, device=device)

        def op(x):
            return x.reshape([2 * 3, 5])

        self._batched_grad_test(op, (x,))

    def test_sigmoid(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(Tensor.sigmoid, (x,))
        self._batched_grad_grad_test(Tensor.sigmoid, (x,))

    def test_stack(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        y = torch.randn(2, 3, device=device, requires_grad=True)

        def op(x, y):
            return torch.stack([x, y])
        self._batched_grad_test(op, (x, y))

    def test_select(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x[1], (x,))
        self._batched_grad_test(lambda x: x.select(1, 2), (x,))
        self._batched_grad_test(lambda x: x.select(-1, 0), (x,))

    def test_slice(self, device):
        x = torch.randn(2, 3, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x[0:1], (x,))
        self._batched_grad_test(lambda x: x[:, 1:3], (x,))
        self._batched_grad_test(lambda x: x[..., 1:3], (x,))

    def test_trace(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(Tensor.trace, (x,))

        x = torch.randn(3, 2, 2, device=device)

        def sum_grad_trace(x):
            return grad(torch.trace)(x).sum()

        # The test expects the vmapped second-order gradient to be all zeros.
        output = vmap(grad(sum_grad_trace))(x)
        self.assertEqual(output, torch.zeros_like(output))

    def test_where(self, device):
        x = torch.randn(3, 2, device=device)
        y = torch.ones(3, 2, device=device)

        def f(x, y):
            return torch.where(x > 0, x, y)

        # Check that there is no runtime error, exactness tests are done with opinfo
        vmap(f)(x, y)

        x = torch.randint(0, 2, size=(4, 3), dtype=torch.float, device=device)

        def f(t):
            return torch.where(t)

        # The single-argument overload is expected to be rejected under vmap.
        with self.assertRaisesRegex(RuntimeError, r"Attempted to vmap over aten::where"):
            vmap(f)(x)

    @skipCUDAIfNoMagma
    @allowVmapFallbackUsage
    def test_symeig(self, device):
        def op(x):
            return torch.symeig(x, eigenvectors=True)[0]

        x = torch.randn(3, 3, device=device, requires_grad=True)
        self._batched_grad_test(op, (x,), {})
        self._batched_grad_grad_test(op, (x,), {})

    def test_threshold(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))

    @allowVmapFallbackUsage
    def test_inplace_view(self, device):
        leaf = torch.randn(4, 5, requires_grad=True, device=device)

        def func(leaf):
            # Make sure the function is non-trivially twice differentiable
            base = leaf * leaf
            view = base[0]
            view.cos_()
            return view

        self._batched_grad_test(func, (leaf,), {})
        self._batched_grad_grad_test(func, (leaf,), {})

    @allowVmapFallbackUsage
    def test_inplace_manyview(self, device):
        leaf = torch.randn(4, 4, 5, requires_grad=True, device=device)

        def func(leaf):
            # Make sure the function is non-trivially twice differentiable
            base = leaf * leaf
            view = base.transpose(0, 2)
            view = view[1]
            view = view.diagonal()
            view = view[::2]
            view.cos_()
            return view

        self._batched_grad_test(func, (leaf,), {})
        self._batched_grad_grad_test(func, (leaf,), {})

    def test_diagonal(self, device):
        x = torch.randn(4, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,))

        x = torch.randn(3, 4, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))

    @allowVmapFallbackUsage
    def test_unrelated_output(self, device):
        B0 = 3
        x = torch.randn([], requires_grad=True, device=device)
        y = torch.randn([], requires_grad=True, device=device)
        gy = torch.randn(B0, requires_grad=True, device=device)

        def vjp(v):
            # y does not depend on x, so autograd returns None (allow_unused);
            # report that as a zero gradient.
            res, = torch.autograd.grad(y, x, v, allow_unused=True)
            return torch.zeros_like(x) if res is None else res

        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))

    @allowVmapFallbackUsage
    def test_unrelated_output_multiple_grad(self, device):
        B0 = 3
        x = torch.randn([], requires_grad=True, device=device)
        y = torch.randn([], requires_grad=True, device=device)
        gy = torch.randn(B0, requires_grad=True, device=device)

        def vjp(v):
            # y does not depend on x, so autograd returns None (allow_unused);
            # report that as a zero gradient.
            res, = torch.autograd.grad(y, x, v, allow_unused=True)
            return torch.zeros_like(x) if res is None else res

        # Run the unbatched computation once before the vmapped one.
        _ = vjp(gy[0])
        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))


def discover_variants(opinfo):
    """Collect the callables exposed by ``opinfo`` and its aliases.

    Returns a pair ``(aliases, inplace_variants)``: the first list holds
    ``opinfo.op`` followed by the ``op`` of every alias; the second holds each
    truthy ``inplace_variant``, starting with the one on ``opinfo`` itself.
    """
    candidates = [opinfo, *opinfo.aliases]
    aliases = [candidate.op for candidate in candidates]
    inplace_variants = [candidate.inplace_variant
                        for candidate in candidates
                        if candidate.inplace_variant]
    return aliases, inplace_variants


class TestVmapOperatorsOpInfo(TestCase):

    def vmap_outplace_test(self, func, args, kwargs, in_dims, check_shape_only=False,
                           postprocess_fn=None):
        """Assert that every (loop, vmap) output pair for ``func`` agrees.

        Pairs come from ``compute_quantities_for_vmap_test``. When
        ``postprocess_fn`` is given it is applied to both outputs before the
        comparison. With ``check_shape_only`` only the ``.shape`` attributes are
        compared.
        """
        quantities = compute_quantities_for_vmap_test(func, args, kwargs, in_dims)
        for expected, actual in quantities:
            if postprocess_fn is not None:
                expected = postprocess_fn(expected)
                actual = postprocess_fn(actual)
            if check_shape_only:
                self.assertEqual(actual.shape, expected.shape)
            else:
                self.assertEqual(actual, expected)

    def vmap_inplace_test(self, func, args, kwargs, in_dims, postprocess_fn=None):
        """Check an in-place ``func`` under vmap against the loop result.

        NB: This test assumes that the first argument is being modified.
        This is OK because it's what every other OpInfo-based test assumes,
        but it is going to need a more robust solution eventually.
        """
        first_arg_is_batched = in_dims[0] is not None

        if first_arg_is_batched:
            pairs = compute_quantities_for_vmap_test(
                func, args, kwargs, in_dims, clone_inputs=True)
            for expected, actual in pairs:
                if postprocess_fn is not None:
                    expected = postprocess_fn(expected)
                    actual = postprocess_fn(actual)
                self.assertEqual(actual, expected)
            return

        # The mutated argument is unbatched: vmap cannot perform the in-place
        # operation, so draining the generator must raise.
        with self.assertRaises(RuntimeError):
            for _ in compute_quantities_for_vmap_test(
                    func, args, kwargs, in_dims, compute_loop_out=False, clone_inputs=True):
                pass

    def opinfo_vmap_test(self, device, dtype, op, check_has_batch_rule,
                         skip_inplace=(), postprocess_fn=None):
        """Run the error-input and sample-input vmap checks for one OpInfo.

        Args:
            device, dtype: forwarded to ``op.error_inputs`` / ``op.sample_inputs``.
            op: the OpInfo under test.
            check_has_batch_rule: when true the checks run through
                ``check_vmap_fallback``; otherwise they run directly.
            skip_inplace: op names whose in-place variants are not tested.
            postprocess_fn: optional callable forwarded to the outplace and
                in-place comparison helpers.
        """
        def check_error_inputs():
            # Every error input must still raise when the op is vmapped.
            if op.error_inputs_func is None:
                return
            for error_input in op.error_inputs(device):
                sample = error_input.sample_input
                positional = (sample.input,) + tuple(sample.args)
                kwargs = sample.kwargs
                for batched_args, in_dims, _ in generate_vmap_inputs(positional, {}):
                    with self.assertRaises(Exception):
                        vmap(op, in_dims)(*batched_args, **kwargs)

        def check_sample_inputs():
            samples = op.sample_inputs(device, dtype, requires_grad=False)
            aliases, inplace_aliases = discover_variants(op)
            shape_only = op.name in ('empty_like', 'new_empty')
            for sample in samples:
                positional = (sample.input,) + sample.args
                kwargs = sample.kwargs
                bn_training = is_batch_norm_training(op.name, kwargs)
                for batched_args, in_dims, _ in generate_vmap_inputs(
                        positional, {}, is_batch_norm_and_training=bn_training):
                    for variant in aliases:
                        self.vmap_outplace_test(
                            variant, batched_args, kwargs, in_dims, shape_only, postprocess_fn)
                    # Short-circuit keeps the validity check from running for skipped ops.
                    run_inplace = (
                        op.name not in skip_inplace
                        and is_valid_inplace_sample_input(sample, op, op.inplace_variant)
                    )
                    if not run_inplace:
                        continue
                    for variant in inplace_aliases:
                        self.vmap_inplace_test(variant, batched_args, kwargs, in_dims, postprocess_fn)

        def test():
            check_error_inputs()
            check_sample_inputs()

        if not check_has_batch_rule:
            test()
        else:
            check_vmap_fallback(self, test, op)

    # Expected failures / skips shared by the OpInfo-driven vmap tests in this
    # class; individual tests extend it via ``vmap_fail.union({...})`` before
    # handing the result to ``skipOps``.
    vmap_fail = {
        # -------------------- ALLOWED FAILURES --------------------------------
        # These are things that we either cannot fix or are not actually problems
        xfail('resize_'),
        xfail('resize_as_'),
        xfail('to_sparse'),
        xfail('__getitem__'),  # dynamic mask
        xfail('index_put'),  # dynamic mask
        xfail('nn.functional.dropout'),  # works, can't check against for loop because of randomness inconsistency
        xfail('nn.functional._scaled_dot_product_attention'),  # randomness
        xfail('masked_select'),  # dynamic op
        xfail('nonzero'),  # dynamic op
        xfail('unique', ''),  # dynamic op
        xfail('unique_consecutive', ''),  # dynamic op
        xfail('allclose'),  # returns a boolean
        xfail('uniform'),  # randomness is tested separately
        xfail('rand_like'),  # randomness is tested separately
        xfail('randint_like'),  # randomness is tested separately
        xfail('randn_like'),  # randomness is tested separately
        xfail('bernoulli', ''),  # randomness is tested separately
        xfail('normal', ''),  # randomness is tested separately
        xfail('normal', 'number_mean'),  # randomness is tested separately
        xfail('multinomial', ''),  # randomness
        xfail('nn.functional.embedding', ''),  # we only support some cases
        xfail('nn.functional.rrelu'),  # randomness
        xfail('nn.functional.dropout2d', ''),  # randomness
        xfail('nn.functional.dropout3d', ''),  # randomness
        xfail('nn.functional.feature_alpha_dropout', 'with_train'),  # randomness
        xfail('as_strided'),  # Our test runner can't handle this; manual test exists
        xfail('new_empty_strided'),  # empty tensor data is garbage so it's hard to make comparisons with it
        xfail('nn.functional.fractional_max_pool3d'),  # randomness
        xfail('nn.functional.fractional_max_pool2d'),  # randomness
        xfail('pca_lowrank', ''),  # random operation
        xfail('svd_lowrank', ''),  # random operation
        xfail('linspace', ''),  # test runner can't handle factory functions
        xfail('arange', ''),  # test runner can't handle factory functions
        xfail('logspace', ''),  # test runner can't handle factory functions
        xfail('empty', ''),  # test runner can't handle factory functions
        xfail('ones', ''),  # test runner can't handle factory functions
        xfail('zeros', ''),  # test runner can't handle factory functions
        xfail('eye', ''),  # non-tensor input
        xfail('broadcast_shapes', ''),  # test runner can't handle non-Tensor ops
        xfail('sparse.sampled_addmm'),  # sparse
        xfail('cross'),  # The default value of dim in op is *very* weird. No wonder it doesn't work
        xfail('svd', device_type='cuda'),  # not unique, see test_linalg_svd for manual test
        xfail('linalg.svd', device_type='cuda'),  # not unique, see test_linalg_svd for manual test
        skip('linalg.eigh', ''),  # not unique, see test_linalg_eigh for manual test
        skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        # ----------------------------------------------------------------------

        # ---------------------------- BUGS ------------------------------------
        # entries in here don't work and need to be fixed.
        # Each one of these is a bug
        xfail('clamp_min', ''),  # Exception not raised on error input
        xfail('clamp_max', ''),  # Exception not raised on error input

        xfail('view_as_complex'),  # RuntimeError: Tensor must have a last dimension with stride 1
        xfail('tensor_split'),  # data_ptr
        xfail('histogramdd'),  # expected Tensor as element 0 in argument 0, but got tuple
        xfail('nn.functional.gaussian_nll_loss'),  # data-dependent control flow error
        xfail('nn.functional.embedding_bag'),  # embedding renorm vmap inplace incompatible
        xfail('__rpow__'),  # https://github.com/pytorch/functorch/issues/617
        xfail('column_stack', ''),  # Batching rule not implemented for aten::column_stack
        xfail('narrow'),  # Batching rule not implemented for aten::narrow.Tensor

        # required rank 4 tensor to use channels_last format
        xfail('bfloat16'),
        xfail('bool'),
        xfail('byte'),
        xfail('char'),
        xfail('double'),
        xfail('float'),
        xfail('half'),
        xfail('int'),
        xfail('long'),
        xfail('short'),

        xfail('jiterator_binary', device_type='cuda'),  # NYI: querying is_contiguous inside of vmap
        xfail('jiterator_binary_return_by_ref', device_type='cuda'),  # NYI: querying is_contiguous inside of vmap
        xfail('jiterator_4inputs_with_extra_args', device_type='cuda'),  # NYI: querying is_contiguous inside of vmap
        xfail('equal', ''),  # TypeError: object of type 'bool' has no len(); likely testrunner problem
        xfail('jiterator_unary', device_type='cuda'),  # NYI: querying is_contiguous inside of vmap
        xfail('jiterator_2inputs_2outputs', device_type='cuda'),  # NYI: querying is_contiguous inside of vmap
        # ---------------------------------------------------------------------
    }

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', (
        tol1('linalg.det',
             {torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
        # The following is often flaky, but just on windows.
        # We should investigate if it's actually a problem or not.
        tol1('nn.functional.conv_transpose3d',
             {torch.float32: tol(atol=1e-04, rtol=1e-02)}, device_type='cuda'),
    ))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
        xfail('cat'),
        xfail('native_batch_norm'),
    }))
    def test_vmap_exhaustive(self, device, dtype, op):
        """Run ``opinfo_vmap_test`` for ``op`` without the vmap-fallback check."""
        # needs to be fixed: names of ops whose in-place variants are skipped
        # (currently none)
        no_inplace_skips = ()
        self.opinfo_vmap_test(
            device, dtype, op,
            check_has_batch_rule=False,
            skip_inplace=no_inplace_skips,
        )

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', (
        tol1('linalg.det',
             {torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
    ))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
        skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        xfail('cat'),
        xfail('complex'),
        xfail('copysign'),
        xfail('native_batch_norm'),
        xfail('histogram'),
        xfail('index_fill'),
        xfail('nansum'),
        xfail('nanmean'),
        xfail('scatter_reduce', 'sum'),
        xfail('scatter_reduce', 'mean'),
        xfail('scatter_reduce', 'amax'),
        xfail('scatter_reduce', 'amin'),
        # `index_put` OpInfo in pytorch/pytorch has
        # masked index as input which is not supported
        xfail('index_put', ''),
        xfail('isin'),
        xfail('lu_unpack'),
        xfail('masked_fill'),
        xfail('masked_scatter'),
        xfail('masked_select'),
        xfail('nanquantile'),
        xfail('narrow_copy'),  # hit the vmap fallback which is currently disabled
        xfail('ormqr'),
        xfail('put'),
        xfail('quantile'),
        xfail('renorm'),
        xfail('resize_as_'),
        xfail('take'),
        xfail('tensor_split'),
        xfail('to_sparse'),
        xfail('vdot'),
        xfail('__getitem__', ''),
        xfail('all'),
        xfail('any'),
        xfail('count_nonzero'),
        xfail('nn.functional.dropout'),  # works, can't check against for loop because of randomness inconsistency
        xfail('nn.functional._scaled_dot_product_attention'),  # randomness
        xfail('resize_'),
        xfail('view_as_complex'),
        xfail('matrix_exp'),
        xfail('bucketize'),
        xfail('fft.ihfft2'),
        xfail('fft.ihfftn'),
        xfail('allclose'),
        xfail('argwhere'),
        xfail('unique_consecutive'),
        xfail('unique'),
        xfail('nn.functional.ctc_loss'),
        xfail('nn.functional.gaussian_nll_loss'),
        xfail('nn.functional.huber_loss'),
        # We can get this to work on CUDA through decomposition,
        # but fails on CPU due to max_pool1d_cpu not having a derivative
        xfail('nn.functional.max_pool1d'),
        xfail('nn.functional.max_pool3d'),
        xfail('histc'),
        xfail('as_strided'),
        xfail('istft'),
        xfail('nonzero'),
        xfail('nn.functional.fractional_max_pool2d'),
        xfail('stft'),
        xfail('isclose'),
        xfail('nn.functional.fractional_max_pool3d'),
        xfail('nn.functional.bilinear'),
        xfail('nn.functional.embedding_bag'),
        xfail('linalg.tensorsolve'),
        xfail('bernoulli', ''),
        xfail('linalg.lu_factor', ''),
        xfail('nn.functional.feature_alpha_dropout', 'with_train'),
        xfail('nn.functional.kl_div', ''),
        xfail('multinomial', ''),
        xfail('column_stack', ''),
        xfail('pca_lowrank', ''),
        xfail('normal', ''),
        xfail('nn.functional.dropout2d', ''),
        xfail('normal', 'number_mean'),
        xfail('svd_lowrank', ''),
        xfail('diagflat', ''),
        xfail('special.log_ndtr'),
        xfail('narrow'),  # Batching rule not implemented for aten::narrow.Tensor
        xfail('nn.functional.triplet_margin_loss', ''),
        xfail('nn.functional.pdist', ''),
        xfail('nn.functional.smooth_l1_loss', ''),
        xfail('nn.functional.max_unpool1d', 'grad'),
        xfail('nn.functional.multi_margin_loss', ''),
        xfail('scatter_reduce', 'prod'),
        xfail('nn.functional.multilabel_margin_loss', ''),
        xfail('nn.functional.max_unpool3d', 'grad'),
        xfail('nn.functional.max_unpool2d', ''),
        xfail('nn.functional.max_unpool2d', 'grad'),
        xfail('nn.functional.margin_ranking_loss', ''),
        xfail('nn.functional.max_unpool1d', ''),
        xfail('nn.functional.soft_margin_loss', ''),
        xfail('nn.functional.max_unpool3d', ''),
        xfail('linalg.ldl_solve', '', device_type='cpu'),
        xfail('chalf', ''),
        xfail('arange', ''),
        xfail('clamp_max', ''),
        xfail('jiterator_binary_return_by_ref', device_type='cuda'),
        xfail('special.spherical_bessel_j0'),
        xfail('jiterator_unary', device_type='cuda'),
        xfail('jiterator_2inputs_2outputs', device_type='cuda'),
        xfail('special.airy_ai'),
        xfail('clamp_min', ''),
        xfail('special.bessel_j0'),
        xfail('sparse.sampled_addmm'),
        xfail('special.bessel_y0'),
        xfail('special.chebyshev_polynomial_u'),
        xfail('special.modified_bessel_k1'),
        xfail('segment_reduce', 'offsets'),
        xfail('special.bessel_j1'),
        xfail('logspace', ''),
        xfail('empty', ''),
        xfail('index_reduce', ''),
        xfail('linspace', ''),
        xfail('special.laguerre_polynomial_l'),
        xfail('special.hermite_polynomial_h'),
        xfail('jiterator_binary', device_type='cuda'),
        xfail('special.modified_bessel_i0'),
        xfail('jiterator_4inputs_with_extra_args', device_type='cuda'),
        xfail('linalg.vander', ''),
        xfail('segment_reduce', 'lengths'),
        xfail('lu_solve', ''),
        xfail('special.bessel_y1'),
        xfail('special.hermite_polynomial_he'),
        xfail('special.scaled_modified_bessel_k0'),
        xfail('nn.functional.dropout3d', ''),
        xfail('special.scaled_modified_bessel_k1'),
        xfail('broadcast_shapes', ''),
        xfail('special.modified_bessel_k0'),
        xfail('linalg.vecdot', ''),
        xfail('linalg.ldl_factor', ''),
        xfail('special.modified_bessel_i1'),
        xfail('special.chebyshev_polynomial_t'),
        xfail('as_strided_scatter', ''),
        xfail('equal', ''),
        xfail('linalg.lu', ''),
        skip('linalg.ldl_solve', ''),
    }))
    def test_op_has_batch_rule(self, device, dtype, op):
        """Run ``opinfo_vmap_test`` for ``op`` through ``check_vmap_fallback``.

        In-place variants of the ops named in ``inplace_failures`` are skipped.
        """
        # needs to be fixed
        inplace_failures = (
            'abs',
            'acos',
            'acosh',
            'addbmm',
            'addcdiv',
            'addcmul',
            'addmm',
            'addmv',
            'addr',
            'asin',
            'asinh',
            'atan2',
            'atan',
            'atanh',
            'baddbmm',
            'clamp',
            'conj_physical',
            'cumprod',
            'cumsum',
            'div',
            'floor_divide',
            'fmod',
            'heaviside',
            'hypot',
            'igamma',
            'igammac',
            'index_add',
            'index_copy',
            'ldexp',
            'lerp',
            'neg',
            'nextafter',
            'polygamma',
            'pow',
            'remainder',
            'scatter_add',
            'scatter',
            'square',
            'sub',
            'tril',
            'triu',
            'trunc',
            'xlogy',
        )
        self.opinfo_vmap_test(device, dtype, op, check_has_batch_rule=True,
                              skip_inplace=inplace_failures)

    def test_linalg_svd(self, device):
        """Check ``linalg.svd`` under vmap by comparing reconstructed matrices."""
        # linalg_svd returns a tuple of three tensors, (U, S, Vh).
        # The decomposition is not unique, so the same input may produce
        # different factors. Instead of comparing factors directly, rebuild
        # U @ diag(S) @ Vh and check that the vmap result matches the
        # for-loop result.
        def compute_A(out):
            U, S, Vh = out
            rows, cols = U.shape[-1], Vh.shape[-2]
            sigma = S.new_zeros(*S.shape[:-1], rows, cols)
            # Write the singular values onto the main diagonal of the last two dims.
            sigma.diagonal(offset=0, dim1=-2, dim2=-1).copy_(S)
            return U @ sigma @ Vh

        svd_opinfos = [candidate for candidate in op_db if candidate.name == 'linalg.svd']
        assert len(svd_opinfos) > 0

        for svd_op in svd_opinfos:
            self.opinfo_vmap_test(device, torch.float, svd_op, check_has_batch_rule=True,
                                  postprocess_fn=compute_A)

    def test_linalg_eigh(self, device):
        """Check ``linalg.eigh`` under vmap by comparing reconstructed matrices."""
        # The op output is unpacked below as two tensors, (L, Q).
        # The eigendecomposition is not unique, so the same input may produce
        # different factors. Instead of comparing factors directly, rebuild
        # Q @ diag(L) @ Qh and check that the vmap result matches the
        # for-loop result.
        def compute_A(out):
            eigenvalues, eigenvectors = out
            size = eigenvectors.shape[-1]
            lam = eigenvalues.new_zeros(*eigenvalues.shape[:-1], size, size)
            # Write the eigenvalues onto the main diagonal of the last two dims.
            lam.diagonal(offset=0, dim1=-2, dim2=-1).copy_(eigenvalues)
            eigenvectors_h = eigenvectors.transpose(-2, -1).conj()
            return eigenvectors @ lam @ eigenvectors_h

        eigh_opinfos = [candidate for candidate in op_db if candidate.name == 'linalg.eigh']
        assert len(eigh_opinfos) > 0

        for eigh_op in eigh_opinfos:
            self.opinfo_vmap_test(device, torch.float, eigh_op, check_has_batch_rule=False,
                                  postprocess_fn=compute_A)

    def test_slogdet(self, device):
        """Compare ``torch.slogdet`` under vmap with the loop result, via ``check_vmap_fallback``."""
        # There's no OpInfo for this
        def test():
            B = 2
            # Batch of B square 5x5 matrices, vmapped over dim 0.
            x = torch.randn(B, 5, 5, device=device)
            self.vmap_outplace_test(torch.slogdet, (x,), {}, (0,))

        check_vmap_fallback(self, test, torch.slogdet)

    def test_fill__Tensor(self, device):
        """Exercise ``Tensor.fill_`` with a tensor value under several in_dims layouts."""
        # There's no OpInfo for fill_.Tensor, so here's an extra test for it.
        def test():
            B = 2
            # (shape of the filled tensor, shape of the fill value, in_dims)
            cases = (
                ([B, 3], [B], (0, 0)),
                ([3, B], [B], (-1, 0)),
                ([3], [B], (None, 0)),
                ([3, B], [], (1, None)),
            )
            for target_shape, value_shape, in_dims in cases:
                target = torch.randn(target_shape, device=device)
                value = torch.randn(value_shape)
                self.vmap_inplace_test(Tensor.fill_, (target, value), {}, in_dims)

        check_vmap_fallback(self, test, Tensor.fill_)

    def test_conv_double_backward(self, device):
        """Compare ``aten._convolution_double_backward`` loop vs. vmap outputs."""
        # Convolution hyper-parameters, in the order the aten op takes them.
        stride, padding, dilation = (1, 1), (0, 0), (1, 1)
        transposed, output_padding, groups = False, (0, 0), 1
        output_mask = (True, True, True)

        images = torch.randn(2, 1, 5, 5, device=device)
        weight = torch.randn(2, 1, 2, 2, device=device)
        bias = torch.randn(2, device=device)
        ggI = torch.randn_like(images)
        ggW = torch.randn_like(weight)
        ggb = torch.randn_like(bias)
        # gO takes the shape of the forward convolution output.
        forward_out = F.conv2d(images, weight, bias, stride, padding, dilation, groups)
        gO = torch.randn_like(forward_out)

        op = torch.ops.aten._convolution_double_backward
        op_args = (
            ggI, ggW, ggb, gO, weight, images, stride, padding, dilation,
            transposed, output_padding, groups, output_mask,
        )
        # Created outside ``test`` and consumed inside it.
        output_pairs = get_fallback_and_vmap_exhaustive(op, op_args, {})

        def test():
            for expected, actual in output_pairs:
                self.assertEqual(expected, actual, atol=1e-4, rtol=1e-4)

        check_vmap_fallback(self, test, op)

    def test_isnan(self, device):
        """Check ``torch.isnan`` under vmap on an input containing NaNs.

        The input is created on ``device`` so the test exercises the device it
        was instantiated for.
        """
        test = functools.partial(_vmap_test, check_propagates_grad=False)

        B, N, C, H, W = 2, 3, 24, 5, 7
        op = torch.isnan

        x = torch.randn(B, N, C, H, W, device=device)
        # Turn every positive entry into NaN so both outcomes occur.
        x[x > 0] = float('nan')
        test(self, op, (x,), in_dims=0)

    def test_isinf(self, device):
        """Check ``torch.isinf`` under vmap on an input containing infinities.

        The input is created on ``device`` so the test exercises the device it
        was instantiated for.
        """
        test = functools.partial(_vmap_test, check_propagates_grad=False)

        B, N, C, H, W = 2, 3, 24, 5, 7
        op = torch.isinf

        x = torch.randn(B, N, C, H, W, device=device)
        # Turn every positive entry into +inf so both outcomes occur.
        x[x > 0] = float('inf')
        test(self, op, (x,), in_dims=0)

    def test_foo_like(self, device):
        """Smoke-test that ``ones_like`` and ``zeros_like`` run under vmap.

        The input is created on ``device`` so the test exercises the device it
        was instantiated for.
        """
        B, N, C, H, W = 2, 3, 24, 5, 7
        for op in [torch.ones_like, torch.zeros_like]:
            x = torch.randn(B, N, C, H, W, device=device)
            # todo(chilli): test these better
            # Not testing correctness, just that they run
            vmap(op, in_dims=(0,))(x,)

    def test_flatten(self, device):
        """Check ``torch.flatten`` under vmap with unbatched start/end dims.

        The input is created on ``device`` so the test exercises the device it
        was instantiated for.
        """
        test = functools.partial(_vmap_test, check_propagates_grad=False)

        op = torch.flatten

        x = torch.randn(2, 3, 4, 5, device=device)
        test(self, op, (x, 1, 2), in_dims=(0, None, None))

    def test_group_norm(self, device):
        """Check ``F.group_norm`` under vmap with shared and with batched affine params."""
        check = functools.partial(_vmap_test, check_propagates_grad=False)
        B, N, C, H, W = 2, 3, 24, 5, 7

        # Case 1: weight/bias are not batched (in_dims None), 3 groups.
        inp = torch.randn(B, N, C, H, W)
        shared_weight = torch.randn(C)
        shared_bias = torch.randn(C)
        check(self, F.group_norm, (inp, 3, shared_weight, shared_bias),
              in_dims=(0, None, None, None))

        # Case 2: weight/bias carry a leading batch dim, 4 groups.
        inp = torch.randn(B, N, C, H, W)
        batched_weight = torch.randn(B, C)
        batched_bias = torch.randn(B, C)
        check(self, F.group_norm, (inp, 4, batched_weight, batched_bias),
              in_dims=(0, None, 0, 0))

    def test_index_put(self, device):
        """Check index assignment (``x[idx] = v`` and ``index_put_``) under vmap.

        NOTE(review): the functions under test assign into their first argument,
        and ``test`` passes ``t[0]`` and then ``t`` itself, so later calls see
        earlier writes. Statement order matters here; confirm before reordering.
        """
        def test(f, t, idx, values):
            # Reference: apply f to the first batch element only.
            base = f(t[0], idx[0], values[0])
            # Each in_dims combination is compared on batch element 0 only.
            self.assertEqual(vmap(f, in_dims=(0, 0, 0))(t, idx, values)[0], base)
            self.assertEqual(vmap(f, in_dims=(0, None, None))(t, idx[0], values[0])[0], base)
            self.assertEqual(vmap(f, in_dims=(0, None, 0))(t, idx[0], values)[0], base)
            self.assertEqual(vmap(f, in_dims=(0, 0, None))(t, idx, values[0])[0], base)

        # indexing the leading dim
        def f(x, y, z):
            x[y] = z
            return x

        x = torch.randn(3, 4, 5, device=device)
        y = torch.zeros((3, 2), device=device).long()
        z = torch.randn(3, 2, 5, device=device)
        test(f, x, y, z)

        # indexing innermost dim
        def f(t, idx, values):
            t[:, idx] = values
            return t

        # NOTE(review): from here on the tensors are created without
        # ``device=device`` — confirm whether that is intentional.
        t = torch.zeros((3, 2, 3))
        values = torch.ones((3, 1, 2))
        idx = torch.tensor([[1, 2]]).expand((3, 2))
        test(f, t, idx, values)

        # indexing middle dim
        def f(t, idx, values):
            t[:, idx, :] = values
            return t

        t = torch.zeros((3, 2, 3, 3))
        values = torch.ones((3, 1, 2, 3))
        idx = torch.tensor([[0, 2]]).expand((3, 2))
        test(f, t, idx, values)

        # indexing with slices
        def f(t, values):
            t[:, :2, :] = values
            return t

        # Reuses ``t`` and ``values`` from the middle-dim case above.
        base = f(t[0], values[0])
        self.assertEqual(vmap(f, in_dims=(0, 0))(t, values)[0], base)
        self.assertEqual(vmap(f, in_dims=(0, None))(t, values[0])[0], base)

        # index_put_
        tensor = torch.zeros(3, 3, 4)
        value = torch.ones(3, 2)
        idxs = (torch.tensor([[0], [1], [2]]), torch.tensor([[0]]), torch.tensor([1, 2]))
        # Expected result is computed on a clone so ``tensor`` is untouched here.
        expected = torch.index_put_(tensor.clone(), idxs, value)

        def f(t, idx, v):
            torch.index_put_(t, idx, v)
            return t

        # Only the last two index tensors are passed; ``tensor`` is vmapped over dim 0.
        self.assertEqual(vmap(f, in_dims=(0, (None, None), 0))(tensor, idxs[1:], value), expected)
        self.assertEqual(vmap(f, in_dims=(0, (None, None), None))(tensor, idxs[1:], value[0]), expected)

    @parametrize('training', [True, False])
    @parametrize('track_running_stats', [True, False])
    @parametrize('affine', [True, False])
    def test_batch_norm(self, device, affine, track_running_stats, training):
        """Check ``F.batch_norm`` under vmap over an ensemble of BatchNorm2d parameters."""
        # Neither running stats nor training mode: nothing is exercised.
        if not (track_running_stats or training):
            return

        check = functools.partial(_vmap_test, check_propagates_grad=False)
        ensemble_size, hidden_dim = 10, 3

        make_ensemble = functional_init_with_buffers(torch.nn.BatchNorm2d, [ensemble_size])
        weights, buffers, _, _, _ = make_ensemble(
            hidden_dim, affine=affine, track_running_stats=track_running_stats)

        # ``inputs`` and ``in_dims`` are built in lockstep, one entry per op argument.
        inputs = [torch.randn(ensemble_size, 32, hidden_dim, 16, 16, device=device)]
        in_dims = [0]

        if track_running_stats:
            running_mean, running_var, _ = buffers
            inputs += [running_mean.to(device), running_var.to(device)]
            in_dims += [0, 0]
        else:
            inputs += [None, None]
            in_dims += [None, None]

        if affine:
            weight, bias = weights
            inputs += [weight.to(device), bias.to(device)]
            in_dims += [0, 0]
        else:
            inputs += [None, None]
            in_dims += [None, None]

        # The training flag is passed unbatched.
        inputs.append(training)
        in_dims.append(None)

        def op(inp, running_mean, running_var, weight, bias, training):
            out = F.batch_norm(inp, running_mean, running_var, weight, bias, training)
            # Also return the running statistics when they are tracked.
            return (out, running_mean, running_var) if track_running_stats else out

        check(self, op, tuple(inputs), in_dims=tuple(in_dims))

    def test_torch_return_types_returns(self, device):
        """vmap preserves the ``torch.return_types`` structured-tuple type of an op's output."""
        t = torch.randn(3, 2, 2, device=device)
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
        self.assertIsInstance(vmap(torch.min, (0, None))(t, 0), torch.return_types.min)
        self.assertIsInstance(vmap(torch.max, (0, None))(t, 0), torch.return_types.max)
        self.assertIsInstance(vmap(torch.topk, (0, None, None))(t, 1, 0), torch.return_types.topk)
        self.assertIsInstance(vmap(torch.linalg.eig, 0)(t), torch.return_types.linalg_eig)

    def test_namedtuple_returns(self, device):
        """vmap preserves a user-defined namedtuple type returned by the mapped function."""
        Point = namedtuple('Point', ['x', 'y'])

        def f(x, y):
            return Point(x=x, y=y)

        x = torch.randn(2, 5, device=device)
        y = torch.randn(2, 3, device=device)
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...)).
        self.assertIsInstance(vmap(f)(x, y), Point)

    def test_inplace_on_view(self, device):
        """Check the vjp of a function doing in-place ops on views, under vmap."""
        def func(leaf):
            squared = leaf * leaf
            window = squared.transpose(0, 1)
            window[2:4, 2:4] *= 2
            window[0:2, 0:2].diagonal().sin_()
            window = window[1:3, 1:3]
            window.cos_()
            return window

        def push_vjp(leaf, gout):
            # Only the vjp function is needed; the primal output is discarded.
            _, vjp_fn = vjp(func, leaf)
            (grad_leaf,) = vjp_fn(gout)
            return grad_leaf

        leaf = torch.randn(4, 4, device=device)
        gout = torch.randn(2, 2, device=device)

        for batched_args, in_dims, _ in generate_vmap_inputs((leaf, gout), {}):
            if in_dims[1] is None:
                # triggers some composite compliance problem
                continue
            self.vmap_outplace_test(push_vjp, batched_args, {}, in_dims)

    def test_advanced_indexing(self, device):
        """Compare loop and vmap results for tensor indexing along dims 0, 1 and 2."""
        def index_dim1(x, idx):
            return x[:, idx]

        def index_dim0(x, idx):
            return x[idx, :]

        def index_dim2(x, idx):
            return x[:, :, idx]

        def check(fn, fn_args):
            pairs = get_fallback_and_vmap_exhaustive(fn, fn_args, {})
            for expected, actual in pairs:
                self.assertEqual(expected, actual)

        base_idx = [0, 1, 2]
        inps = (torch.randn(5, 5, 5, device=device),
                torch.randn(5, 5, 5, 5, device=device),
                torch.randn(5, 5, 5, 5, 5, device=device))
        idxes = (torch.tensor(base_idx, device=device),
                 torch.tensor(base_idx, device=device).reshape(3, 1),
                 torch.tensor(base_idx, device=device).reshape(3, 1, 1))

        for inp, idx in itertools.product(inps, idxes):
            for fn in (index_dim1, index_dim0, index_dim2):
                check(fn, (inp, idx))

    def test_nested_advanced_indexing(self, device):
        """Nested vmap over tensor indexing matches a select/stack reference."""
        e = torch.rand(7, 4, device=device)
        idx = torch.tensor([0, 1], device=device).view(2, 1)

        # simple reference implementation for comparison
        def _fake_vmap(f, in_dims=0, out_dims=0):
            def mapped(batched):
                per_slice = []
                for i in range(batched.size(in_dims)):
                    per_slice.append(f(batched.select(in_dims, i)))
                return torch.stack(per_slice, out_dims)

            return mapped

        def with_vmap(_vmap):
            # Outer map runs over ``idx``; inner map runs over dim 1 of ``e``.
            def index_all_columns(idx_):
                def pick(e_):
                    return e_[idx_]

                return _vmap(pick, in_dims=1)(e)

            return _vmap(index_all_columns)(idx)

        from_vmap = with_vmap(vmap)
        from_reference = with_vmap(_fake_vmap)
        self.assertEqual(from_vmap, from_reference)

    @ops(filter(lambda op: "linalg" in op.name, op_db + additional_op_db), allowed_dtypes=(torch.float,))
    @skipOps('TestVmapOperatorsOpInfo', 'test_vmap_linalg_failure_1D_input', {
        xfail('linalg.vector_norm'),  # can accept vector inputs
        xfail('linalg.norm'),  # can accept vector inputs
        xfail('linalg.norm', 'subgradients_at_zero'),  # can accept vector inputs
        skip('linalg.multi_dot'),  # accepts list of tensor inputs, has its own special test
        xfail('linalg.vander'),
        xfail('linalg.vecdot'),
        skip('linalg.ldl_solve', ''),
    })
    def test_vmap_linalg_failure_1D_input(self, device, dtype, op):
        """A linalg op that rejects 1-D input must also reject it under vmap.

        For every non-empty 2-D sample, the first row is fed to the op directly
        and then, expanded to a square matrix, through ``vmap``; both calls must
        raise a ``RuntimeError`` mentioning "dimension".
        """
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            if sample.input.dim() != 2 or sample.input.shape[0] == 0:
                continue
            test_input = sample.input[0]  # using the sample input avoids numerical inconsistency issues
            with self.assertRaisesRegex(RuntimeError, "dimension"):
                op(test_input, *sample.args, **sample.kwargs)

            def op_wrapper(inp):
                return op(inp, *sample.args, **sample.kwargs)

            # square inputs are more likely to pass linalg checks
            test_input = test_input.expand(test_input.shape[0], test_input.shape[0])
            # No ``return`` here: the call is expected to raise, and the loop
            # must go on to the remaining samples.
            with self.assertRaisesRegex(RuntimeError, "dimension"):
                vmap(op_wrapper)(test_input)

    def test_vmap_multi_dot_failure_1D_input(self):
        """``linalg.multi_dot`` must reject a 1-D middle tensor, with and without vmap."""
        # The first and last tensors are special-cased, so use 3 items and make
        # the middle one 1-D to avoid those special cases.
        inputs = (torch.randn(3, 3), torch.randn(3), torch.randn(3, 3))
        with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"):
            torch.linalg.multi_dot(inputs)

        # square inputs are more likely to pass linalg checks
        inputs = tuple(i.expand(i.shape[0], i.shape[0]) for i in inputs)
        # The call is expected to raise, so there is nothing to return.
        with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"):
            vmap(torch.linalg.multi_dot)(inputs)


class TestRandomness(TestCase):
    def _reset_random(self, generator, orig_state, use_generator, seed):
        return generator.set_state(orig_state) if use_generator else torch.manual_seed(seed)

    def _get_image(self, batched_input, batch_size, device):
        if batched_input == "first":
            return torch.ones([batch_size, 3, 3, 14, 14], device=device)
        if batched_input == "last":
            return torch.ones([3, 3, 14, 14, batch_size], device=device)
        assert batched_input == "none"
        return torch.ones([3, 3, 14, 14], device=device)

    def _assert_all_slices_equal(self, tensor):
        expected = tensor[0]
        self.assertTrue((tensor == expected).all())

    def _assert_all_slices_unique(self, tensor):
        B0 = tensor.shape[0]
        slices_equal = vmap(vmap(lambda x, y: (x == y).all(), (0, None)), (None, 0))(tensor, tensor)
        assert slices_equal.shape == (B0, B0)
        slices_equal.diagonal().zero_()
        self.assertEqual(slices_equal, torch.zeros_like(slices_equal))

    def _assert_throws_in_error_mode(self, fn, args, in_dims):
        with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
            vmap(fn, in_dims=in_dims, randomness="error")(*args)

    def _assert_throws_in_different_mode_inplace(self, fn, args, in_dims):
        with self.assertRaisesRegex(RuntimeError, r"different inplace randomness on an unbatched tensor"):
            vmap(fn, in_dims=in_dims, randomness="different")(*args)

    def _assert_throws_in_same_mode_batched(self, fn, args, in_dims):
        with self.assertRaisesRegex(RuntimeError,
                                    r"Vmap does not currently support same randomness with a batched tensor input"):
            vmap(fn, in_dims=in_dims, randomness="same")(*args)

    def _in_dims(self, *batched_strings):

        def get_in_dim(batched_string):
            if batched_string == "first":
                return 0
            if batched_string == "last":
                return -1
            assert batched_string == "none"
            return None

        batched_strings = batched_strings + ("first",)  # for the always batched as first dim dummy argument
        return tuple(get_in_dim(batched_string) for batched_string in batched_strings)

    @parametrize('randomness', ['same', 'different', 'error'])
    @parametrize('use_generator', [True, False])
    def test_factory_ops(self, device, randomness, use_generator):
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {'device': device, 'generator': generator} if use_generator else {'device': device}
        ops = [
            lambda _, shape: torch.randn(shape, **kwargs),
            lambda _, shape: torch.rand(shape, **kwargs),
            lambda _, shape: torch.randint(100, shape, **kwargs),
            lambda _, shape: torch.randint(5, 100, shape, **kwargs),
            lambda _, shape: torch.normal(0., 1., shape, **kwargs),
        ]
        B0 = 4
        shape = (3, 3)
        seed = 1234567

        for op in ops:
            passed = torch.randn(B0, device=device)
            if randomness == 'error':
                self._assert_throws_in_error_mode(op, (passed, shape), in_dims=(0, None))
                return

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            vmap_result = vmap(op, in_dims=(0, None), randomness=randomness)(passed, shape)

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            if randomness == "different":
                expected = op(passed, [B0, *shape])
                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(vmap_result, expected)
            else:
                expected = op(passed, shape)
                self._assert_all_slices_equal(vmap_result)
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)

    @parametrize('randomness', ['same', 'different', 'error'])
    @parametrize('use_generator', [True, False])
    def test_randperm(self, device, randomness, use_generator):
        # needs a special case because randperm doesn't take a batch size
        B0 = 4
        seed = 1234567
        passed = torch.randn(B0, device=device)

        torch.manual_seed(seed)
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()

        kwargs = {'device': device, 'generator': generator} if use_generator else {'device': device}

        if randomness == 'error':
            with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
                vmap(lambda _: torch.randperm(10, **kwargs), randomness=randomness)(passed)
            return

        vmap_result = vmap(lambda _: torch.randperm(10, **kwargs), randomness=randomness)(passed)
        generator = generator.set_state(orig_state)
        torch.manual_seed(seed)
        if randomness == 'different':
            for i in range(B0):
                expected = torch.randperm(10, **kwargs)
                self.assertEqual(vmap_result[i], expected)
        else:
            expected = torch.randperm(10, **kwargs)
            for i in range(B0):
                self.assertEqual(vmap_result[i], expected)

    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    def test_dropout(self, device, randomness, batched_input):
        def op(t, ignored):
            return torch.nn.functional.dropout(torch.ones_like(t), training=True)

        B0 = 4
        always_batched = torch.randn((B0,))
        passed = self._get_image(batched_input, B0, device)
        in_dims = self._in_dims(batched_input)

        if randomness == 'error':
            with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
            return

        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)

        # Check that the randomness is within bounds...
        # ideally this is close to 0.5
        p_estimate = vmap_result.mean() / 2
        self.assertTrue(p_estimate < 0.75)
        self.assertTrue(p_estimate > 0.25)

        if randomness == 'different':
            self._assert_all_slices_unique(vmap_result)
            return

        assert randomness == 'same'
        self._assert_all_slices_equal(vmap_result)

    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    def test_alpha_dropout(self, device, randomness, batched_input):
        def op(t, ignored):
            return torch.nn.functional.alpha_dropout(torch.ones_like(t), training=True)

        B0 = 4
        always_batched = torch.randn((B0,))
        passed = self._get_image(batched_input, B0, device)
        in_dims = self._in_dims(batched_input)

        if randomness == 'error':
            with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
            return

        # I have no clue how to actually test corectness of alpha dropout because the docs
        # seem wrong: https://github.com/pytorch/pytorch/issues/74004
        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
        if randomness == 'different':
            self._assert_all_slices_unique(vmap_result)
            return

        assert randomness == 'same'
        self._assert_all_slices_equal(vmap_result)

    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    @parametrize('dim', [2, 3])
    def test_feature_dropout(self, device, randomness, batched_input, dim):
        def op(t, ignored):
            f = torch.nn.functional.dropout2d if dim == 2 else torch.nn.functional.dropout3d
            return f(torch.ones_like(t), training=True)

        B0 = 4
        always_batched = torch.randn((B0,))
        passed = self._get_image(batched_input, B0, device)
        if dim == 3:
            unsqueeze_dim = -2 if batched_input == "last" else -1
            passed = passed.unsqueeze(unsqueeze_dim)
        in_dims = self._in_dims(batched_input)

        if randomness == 'error':
            with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
            return

        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)

        # Check that the randomness is within bounds...
        # ideally this is close to 0.5
        p_estimate = vmap_result.mean() / 2
        self.assertTrue(p_estimate < 0.75)
        self.assertTrue(p_estimate > 0.25)

        # Check the "feature" pattern
        dims = [-1, -2] if dim == 2 else [-1, -2, -3]
        planes_numel = 2 * vmap_result.numel() / (vmap_result.shape[0] * vmap_result.shape[1] * vmap_result.shape[2])
        planes = vmap_result.sum(dims)
        result = (planes == 0) ^ (planes == planes_numel)
        self.assertEqual(result, torch.ones_like(result, dtype=torch.bool))

        if randomness == 'different':
            self._assert_all_slices_unique(vmap_result)
            return

        assert randomness == 'same'
        self._assert_all_slices_equal(vmap_result)

    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    def test_feature_alpha_dropout(self, device, randomness, batched_input):
        def op(t, ignored):
            return torch.nn.functional.feature_alpha_dropout(torch.ones_like(t), training=True)

        B0 = 4
        always_batched = torch.randn((B0,))
        passed = self._get_image(batched_input, B0, device)
        unsqueeze_dim = -2 if batched_input == "last" else -1
        passed = passed.unsqueeze(unsqueeze_dim)
        in_dims = self._in_dims(batched_input)

        if randomness == 'error':
            with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
                vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
            return

        vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)

        # I have no clue how to actually test corectness of alpha dropout because the docs
        # seem wrong: https://github.com/pytorch/pytorch/issues/74004

        # Check the "feature" pattern
        dims = [-1, -2, -3]
        planes = vmap_result.sum(dims)
        max_elt = planes.max()
        min_elt = planes.min()
        result = (planes == min_elt) ^ (planes == max_elt)
        self.assertEqual(result, torch.ones_like(result, dtype=torch.bool))

        if randomness == 'different':
            self._assert_all_slices_unique(vmap_result)
            return

        assert randomness == 'same'
        self._assert_all_slices_equal(vmap_result)

    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    def test_like_functions(self, device, randomness, batched_input):
        seed = 1234567
        supported_ops = [
            lambda t, _: torch.randint_like(t, 20),
            lambda t, _: torch.randint_like(t, 0, 20),
            lambda t, _: torch.rand_like(t),
            lambda t, _: torch.randn_like(t),
        ]
        B0 = 4

        for op in supported_ops:
            always_batched = torch.randn(B0)
            passed = self._get_image(batched_input, B0, device)
            in_dims = self._in_dims(batched_input)

            if randomness == 'error':
                with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
                    vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)
                return

            torch.manual_seed(seed)
            vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)

            torch.manual_seed(seed)

            if batched_input == "last":
                passed = passed.movedim(-1, 0)
            if randomness == 'different':
                if batched_input == "none":
                    passed = passed.expand(B0, *passed.shape)
                expected = op(passed, 0)

                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(expected, vmap_result)
                return

            assert randomness == 'same'
            if batched_input != "none":
                passed = passed[0]
            expected = op(passed, 0)
            self._assert_all_slices_equal(vmap_result)
            for i in range(B0):
                self.assertEqual(expected, vmap_result[i])

    @parametrize('use_generator', [True, False])
    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    def test_random_unary_inplace(self, device, use_generator, randomness, batched_input):
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {'generator': generator} if use_generator else {}
        ops = [
            lambda t, _: t.random_(**kwargs),
            lambda t, _: t.random_(100, **kwargs),
            lambda t, _: t.random_(-5, 100, **kwargs),
            lambda t, _: t.normal_(**kwargs),
            lambda t, _: t.bernoulli_(**kwargs),
            lambda t, _: t.cauchy_(**kwargs),
            lambda t, _: t.exponential_(**kwargs),
            lambda t, _: t.geometric_(0.5, **kwargs),
            lambda t, _: t.log_normal_(**kwargs),
            lambda t, _: t.uniform_(**kwargs),
        ]
        B0 = 4
        seed = 1234567
        in_dims = self._in_dims(batched_input)

        for op in ops:
            # because of in place updates, clone inputs
            always_batched = torch.randn(B0, device=device)
            passed = self._get_image(batched_input, B0, device)
            passed_expected = passed.clone()

            if randomness == 'error':
                self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
                return
            if randomness == 'different' and batched_input == "none":
                self._assert_throws_in_different_mode_inplace(op, (passed, always_batched), in_dims=in_dims)
                return

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)

            if batched_input == "last":
                passed_expected = passed_expected.movedim(-1, 0)
            generator = self._reset_random(generator, orig_state, use_generator, seed)
            if randomness == "different":
                expected = op(passed_expected, always_batched)
                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(vmap_result, expected)
            else:
                if batched_input != "none":
                    passed_expected = passed_expected[0].clone()  # bug in pytorch, normal_ on views doesn't work
                expected = op(passed_expected, always_batched)
                self._assert_all_slices_equal(vmap_result)
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)

    @parametrize('use_generator', [True, False])
    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    @parametrize('batched_probability', ["first", "last", "none"])
    def test_bernoulli_in_place(self, device, use_generator, randomness, batched_input, batched_probability):
        B0 = 4
        seed = 1234567
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {'generator': generator} if use_generator else {}
        in_dims = self._in_dims(batched_input, batched_probability)

        def op(t, p, ignored):
            return t.bernoulli_(p, **kwargs)

        # because of in place updates, clone inputs
        always_batched = torch.randn(B0, device=device)
        input = self._get_image(batched_input, B0, device)
        input_expected = input.clone()
        probability = self._get_image(batched_probability, B0, device) - 0.5

        if randomness == 'error':
            self._assert_throws_in_error_mode(op, (input, probability, always_batched), in_dims=in_dims)
            return
        if randomness == 'same' and batched_probability != "none":
            self._assert_throws_in_same_mode_batched(op, (input, probability, always_batched), in_dims=in_dims)
            return
        if batched_input == "none" and batched_probability != "none":
            regex = r"there exists a Tensor `other` in extra_args that has more elements than `self`"
            with self.assertRaisesRegex(RuntimeError, regex):
                vmap(op, in_dims=in_dims, randomness=randomness)(input, probability, always_batched)
            return
        if randomness == 'different' and batched_input == "none":
            self._assert_throws_in_different_mode_inplace(op, (input, probability, always_batched), in_dims=in_dims)
            return

        self._reset_random(generator, orig_state, use_generator, seed)
        vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(input, probability, always_batched)

        self._reset_random(generator, orig_state, use_generator, seed)
        if batched_input == "last":
            input_expected = input_expected.movedim(-1, 0)
        if batched_probability == "last":
            probability = probability.movedim(-1, 0)
        if randomness == "different":
            expected = op(input_expected, probability, always_batched)
            self._assert_all_slices_unique(vmap_result)
            self.assertEqual(vmap_result, expected)
        else:
            if batched_input != "none":
                input_expected = input_expected[0]
            expected = op(input_expected, probability, always_batched)
            self._assert_all_slices_equal(vmap_result)
            for i in range(B0):
                self.assertEqual(vmap_result[i], expected)

    @parametrize('use_generator', [True, False])
    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    @parametrize('batched_other', ["first", "last", "none"])
    def test_random_binary_out_of_place(self, device, use_generator, randomness, batched_input, batched_other):
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {'generator': generator} if use_generator else {}
        ops = [
            lambda t, o, _: torch.normal(t, o, **kwargs),
            lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
        ]

        B0 = 4
        seed = 1234567
        in_dims = self._in_dims(batched_input, batched_other)

        for op in ops:
            always_batched = torch.randn(B0, device=device)
            input = self._get_image(batched_input, B0, device)
            other = self._get_image(batched_other, B0, device)

            if randomness == 'error':
                self._assert_throws_in_error_mode(op, (input, other, always_batched), in_dims=in_dims)
                return
            if randomness == 'same' and (batched_input != "none" or batched_other != "none"):
                self._assert_throws_in_same_mode_batched(op, (input, other, always_batched), in_dims=in_dims)
                return

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(input, other, always_batched)

            if batched_input == "last":
                input = input.movedim(-1, 0)
            if batched_other == "last":
                other = other.movedim(-1, 0)

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            if randomness == "different":
                if batched_input == "none":
                    input = input.expand(B0, *input.shape)
                expected = op(input, other, always_batched)
                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(vmap_result, expected)
            else:
                assert batched_input == "none" and batched_other == "none"
                expected = op(input, other, always_batched)
                self._assert_all_slices_equal(vmap_result)
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)

    @parametrize('use_generator', [True, False])
    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_input', ["first", "last", "none"])
    def test_random_unary_out_of_place(self, device, use_generator, randomness, batched_input):
        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {'generator': generator} if use_generator else {}
        ops = [
            lambda t, _: torch.normal(0., torch.abs(t), **kwargs),
            lambda t, _: torch.normal(t, 1., **kwargs),
            lambda t, _: torch.bernoulli(t - 0.5, **kwargs),
            lambda t, _: torch.bernoulli(t, 0.5, **kwargs),
            lambda t, _: torch._standard_gamma(t, **kwargs),
            lambda t, _: torch._sample_dirichlet(t, **kwargs),
            lambda t, _: torch.poisson(t, **kwargs),
        ]

        B0 = 4
        seed = 1234567
        in_dims = self._in_dims(batched_input)

        for op in ops:
            always_batched = torch.randn(B0, device=device)
            passed = self._get_image(batched_input, B0, device)
            if randomness == 'error':
                self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
                return
            if randomness == 'same' and batched_input != "none":
                self._assert_throws_in_same_mode_batched(op, (passed, always_batched), in_dims=in_dims)
                return

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)

            generator = self._reset_random(generator, orig_state, use_generator, seed)
            if randomness == "different":
                if batched_input == "none":
                    passed = passed.expand(B0, *passed.shape)
                if batched_input == "last":
                    passed = passed.movedim(-1, 0)
                expected = op(passed, always_batched)
                self._assert_all_slices_unique(vmap_result)
                self.assertEqual(vmap_result, expected)
            else:
                expected = op(passed, always_batched)
                self._assert_all_slices_equal(vmap_result)
                for i in range(B0):
                    self.assertEqual(vmap_result[i], expected)

    @parametrize('use_generator', [True, False])
    @parametrize('randomness', ['error', 'same', 'different'])
    @parametrize('batched_call', [True, False])
    @parametrize('batched_input', ["first", "last", "none"])
    def test_multinomial(self, device, use_generator, randomness, batched_call, batched_input):
        def flatten_input(input, batch_call, batch_location):
            if batch_call and batch_location != "none":
                final_size = 3  # [B0, B, N]
            elif not batch_call and batch_location == "none":
                final_size = 1  # [N]
            else:
                final_size = 2  # [B0, N] or [B, N]

            start_idx = final_size - 1
            end_idx = -1
            if batch_location == "last":
                start_idx -= 1
                end_idx -= 1   # gets to correct final size because using negative indices

            ret = input.flatten(start_idx, end_idx)
            assert ret.dim() == final_size
            return ret

        def op(input, _):
            return torch.multinomial(input, 10, **kwargs)

        generator = torch.Generator(device=device)
        orig_state = generator.get_state()
        kwargs = {'generator': generator} if use_generator else {}

        B0 = 4
        seed = 1234567
        in_dims = self._in_dims(batched_input)

        always_batched = torch.randn(B0, device=device)
        passed = self._get_image(batched_input, B0, device)
        passed = flatten_input(passed, batched_call, batched_input)
        if randomness == 'error':
            self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
            return
        if randomness == 'same' and batched_input != "none":
            self._assert_throws_in_same_mode_batched(op, (passed, always_batched), in_dims=in_dims)
            return

        generator = self._reset_random(generator, orig_state, use_generator, seed)
        vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)

        generator = self._reset_random(generator, orig_state, use_generator, seed)

        if randomness == "different":
            if batched_input == "none":
                passed = passed.expand(B0, *passed.shape)
            if batched_input == "last":
                passed = passed.movedim(-1, 0)
            orig_passed_size = passed.shape[:2] if batched_call else passed.shape[:1]
            passed = passed.flatten(0, 1) if batched_call else passed
            expected = op(passed, always_batched)
            expected = expected.reshape(*orig_passed_size, 10)
            self._assert_all_slices_unique(vmap_result)
            self.assertEqual(vmap_result, expected)
        else:
            expected = op(passed, always_batched)
            self._assert_all_slices_equal(vmap_result)
            for i in range(B0):
                self.assertEqual(vmap_result[i], expected)

    def test_unsupported_random(self, device):
        x = torch.randn(3, device=device)
        y = x.abs()
        z = x.abs()
        with self.assertRaisesRegex(RuntimeError, "calling out variants"):
            def f(x):
                return torch.randn(3, device=device, out=y)
            vmap(f, randomness='same')(x)
        with self.assertRaisesRegex(RuntimeError, "calling out variants"):
            def f(x0, x1):
                return torch.normal(x, y, out=x)
            vmap(f, randomness='same')(z, z)
        with self.assertRaisesRegex(RuntimeError, "do not yet support"):
            def f(z):
                return torch.rrelu(x)
            vmap(f, randomness='same')(z)

    @parametrize('in_dim', [0, 1, 2])
    @parametrize('out_dim', [0, 1, 2])
    def test_chunk_vmap(self, in_dim, out_dim):
        # With randomness="different", chunk_vmap must yield distinct slices
        # along dim 0 of its output for every chunk count tried.
        x = torch.randn(4, 5, 6)

        def f(x):
            return x.sin() + torch.rand_like(x)

        chunk_counts = [1, 2, 3, 4, 7, 10, 16]
        for chunks in chunk_counts:
            chunked_f = chunk_vmap(
                f, in_dims=in_dim, out_dims=out_dim, randomness="different", chunks=chunks
            )
            self._assert_all_slices_unique(chunked_f(x))


    def test_jacfwd_with_random(self):
        # checks on behavior are above, this just checks that jacfwd respects
        # the randomness param

        x = torch.rand(3, 4)
        with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
            jacfwd(torch.bernoulli)(x)

        # x isn't batched so use bernoulli since it doesn't do inplace randomness
        jacfwd(torch.bernoulli, randomness="same")(x)
        jacfwd(torch.bernoulli, randomness="different")(x)


class TestTransformFailure(TestCase):
    @parametrize('transform', ['vmap', 'grad', 'grad_and_value', 'vjp', 'jvp', 'jacrev', 'jacfwd'])
    def test_fails_with_autograd_function(self, device, transform):
        # Each listed functorch transform must raise a RuntimeError mentioning
        # "autograd.Function" when applied to a function that calls a custom
        # torch.autograd.Function.
        class Test(torch.autograd.Function):
            @staticmethod
            def forward(_, input):
                return input

            @staticmethod
            def backward(_, grad_input):
                return grad_input

        def f(x):
            return Test.apply(x)

        transform = getattr(functorch, transform)

        # grad and grad_and_value get a scalar input, every other transform a vector
        wants_scalar = transform in (grad, grad_and_value)
        arg = torch.tensor(4.) if wants_scalar else torch.randn(5)

        if transform == vjp:
            # vjp receives the function and its input in one call
            transformed = functools.partial(transform, f)
        elif transform == jvp:
            # jvp receives tuples; the same tuple is later passed as the tangents
            arg = (arg,)
            transformed = functools.partial(transform, f, arg)
        else:
            transformed = transform(f)

        with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
            transformed(arg)

# Device types passed as `only_for` to every instantiation below.
only_for = ("cpu", "cuda")
# NOTE(review): presumably each call below generates the device-specific
# variants of the given generic test class in this module's globals —
# confirm against torch.testing._internal.common_device_type.
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)

instantiate_device_type_tests(
    TestVmapBatchedGradient,
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(TestTransformFailure, globals(), only_for=only_for)
instantiate_device_type_tests(TestRandomness, globals(), only_for=only_for)

# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
