# Owner(s): ["module: nn"]
from functools import reduce
from itertools import repeat
import unittest
import subprocess
import sys
import os
import random
import itertools
import math

from torch._six import inf, nan
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, set_default_dtype, \
    instantiate_parametrized_tests, slowTest, parametrize as parametrize_test, subtest, skipIfMps
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_nn import NNTestCase, _test_bfloat16_ops, _test_module_empty_input
from torch.testing._internal.common_device_type import largeTensorTest, onlyNativeDeviceTypes, dtypes, \
    instantiate_device_type_tests, skipCUDAIfRocm, expectedFailureMeta, dtypesIfCUDA, onlyCPU, onlyCUDA, \
    TEST_WITH_ROCM
from torch.testing._internal.common_dtype import floating_types_and
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import gradcheck, gradgradcheck


class TestAvgPool(TestCase):
    def _sum_pool2d(self, x, kernel_size):
        windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
        return torch.sum(windows, dim=1)

    def _sum_pool3d(self, x, kernel_size):
        # Because unfold does not support 3D sliding window we will split tensor to multiple tensors and calculate sum
        h = kernel_size[0]
        splited_x = [t.sum(0) for t in x.split(h) if t.size(0) == h]
        # sum_pool2d assumes tensor in (1, 1, n, m) view, so unsqueeze two times
        splited_x = [self._sum_pool2d(t.unsqueeze(0).unsqueeze(0), kernel_size[1:]) for t in splited_x]
        joined_x = torch.cat(splited_x)
        return joined_x.view(1, joined_x.numel())

    def _avg_pool2d(self, x, kernel_size):
        size = reduce((lambda x, y: x * y), kernel_size)
        return self._sum_pool2d(x, kernel_size) / size

    def _avg_pool3d(self, x, kernel_size):
        size = reduce((lambda x, y: x * y), kernel_size)
        return self._sum_pool3d(x, kernel_size) / size

    def test_doubletensor_avg_pool2d(self):
        n, m = 5, 8
        input = torch.rand(1, 1, n, m, dtype=torch.double)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                actual = torch.nn.functional.avg_pool2d(input[0], (i, j))
                actual = actual.view(1, actual.numel())
                expected = self._avg_pool2d(input, (i, j))
                self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool2d_with_divisor(self):
        n, m = 3, 3
        input = torch.rand(1, 1, n, m, dtype=torch.double)
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                for divisor in [1, 7, i * j]:
                    actual = F.avg_pool2d(input[0], (i, j), divisor_override=divisor)
                    actual = actual.view(1, actual.numel())
                    expected = self._sum_pool2d(input, (i, j)) / divisor
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool3d(self):
        h, w, d = 5, 6, 7
        input = torch.rand(h, w, d, dtype=torch.double)
        for i in range(1, h + 1):
            for j in range(1, w + 1):
                for k in range(1, d + 1):
                    actual = torch.nn.functional.avg_pool3d(input.unsqueeze(0), (i, j, k))
                    actual = actual.view(1, actual.numel())
                    expected = self._avg_pool3d(input, (i, j, k))
                    self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_doubletensor_avg_pool3d_with_divisor(self):
        h, w, d = 6, 5, 7
        input = torch.rand(h, w, d, dtype=torch.double)
        for i in range(1, h + 1):
            for j in range(1, w + 1):
                for k in range(1, d + 1):
                    for divisor in [1, 7, i * j]:
                        actual = torch.nn.functional.avg_pool3d(input.unsqueeze(0), (i, j, k), divisor_override=divisor)
                        actual = actual.view(1, actual.numel())
                        expected = self._sum_pool3d(input, (i, j, k)) / divisor
                        self.assertEqual(actual, expected, rtol=0, atol=1e-5)

    def test_avg_pool1d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4))
        y = torch.nn.functional.avg_pool1d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=1, stride=2)
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool1d(
                x.to('cuda'), ceil_mode=True, count_include_pad=True, kernel_size=1, stride=2)
            self.assertTrue(not torch.isnan(y).any())

    def test_avg_pool2d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4))
        y = torch.nn.functional.avg_pool2d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
            padding=(0, 1), stride=2)
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool2d(
                x.to('cuda'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2),
                padding=(0, 1), stride=2)
            self.assertTrue(not torch.isnan(y).any())

    def test_avg_pool3d_ceil_mode(self):
        # Regression test for gh-36977
        x = 10 * torch.randn((1, 16, 4, 4, 4))
        y = torch.nn.functional.avg_pool3d(
            x, ceil_mode=True, count_include_pad=True, kernel_size=(1, 2, 3), stride=2)
        self.assertTrue(not torch.isnan(y).any())

        if TEST_CUDA:
            y = torch.nn.functional.avg_pool3d(
                x.to('cuda'), ceil_mode=True, count_include_pad=True, kernel_size=(1, 2, 3), stride=2)
            self.assertTrue(not torch.isnan(y).any())


class TestPoolingNN(NNTestCase):
    def test_adaptive_pooling_input_size(self):
        for numel in (2, 3):
            for pool_type in ('Max', 'Avg'):
                cls_name = 'Adaptive{}Pool{}d'.format(pool_type, numel)
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * numel
                module = module_cls(output_size)

                input = torch.randn(output_size)
                self.assertRaises(ValueError, lambda: module(input))

    def test_adaptive_pooling_size_none(self):
        for numel in (2, 3):
            for pool_type in ('Max', 'Avg'):
                cls_name = 'Adaptive{}Pool{}d'.format(pool_type, numel)
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * (numel - 1) + (None,)
                module = module_cls(output_size)

                input = torch.randn((4,) * (numel + 1))
                output = module(input)
                self.assertEqual(output.size(), (4,) + (2,) * (numel - 1) + (4,))

    @unittest.skipIf(TEST_WITH_UBSAN, "signed integer overflow error with UBSAN")
    def test_adaptive_pooling_size_overflow(self):
        # 0x0x3fffffffffffffff * 2 * 2 = 0xfffffffffffffffc = -4 as int64_t
        # Tensor::numel() return int64_t, so following check that negative allocs are correctly handled
        self.assertRaises(
            RuntimeError,
            lambda: torch.nn.AdaptiveMaxPool1d(0x3fffffffffffffff)(torch.empty([2, 2, 2])))

    def test_adaptive_pooling_avg_nhwc(self):
        device_list = ['cpu']
        if TEST_CUDA:
            device_list.append('cuda')

        for device in device_list:
            input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device)
            pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

    def test_adaptive_pooling_avg_nhwc_non_contiguous(self):
        device_list = ['cpu']
        if TEST_CUDA:
            device_list.append('cuda')

        for device in device_list:
            input = torch.randint(1, 10, (4, 8, 8, 8), dtype=torch.float32).to(device)
            input = input.contiguous(memory_format=torch.channels_last)
            input = input[:, ::2, :, :].requires_grad_()
            grad = torch.randint(1, 10, (4, 8, 7, 7), dtype=torch.float32).to(device)
            grad = grad[:, ::2, :, :]
            pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveAvgPool2d((7, 7)).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

    def test_adaptive_pooling_bfloat16(self):
        def _test_adaptive_pooling_bfloat16(self, device, mod, memory_format):
            input = torch.randint(1, 10, (3, 19, 8, 8), dtype=torch.float32)
            input = input.to(device).to(memory_format=memory_format).requires_grad_()
            pool = mod((7, 7)).to(device)

            input2 = input.detach().clone().bfloat16().requires_grad_(True)

            out = pool(input)
            out.sum().backward()
            out2 = pool(input2)
            out2.sum().backward()

            self.assertTrue(out2.is_contiguous(memory_format=memory_format))
            self.assertEqual(out2.dtype, torch.bfloat16)
            self.assertEqual(input2.grad.dtype, torch.bfloat16)
            self.assertEqual(out, out2.float(), atol=0.1, rtol=0)
            self.assertEqual(input.grad, input2.grad.float(), atol=0.1, rtol=0)

        device_list = ['cpu']
        for device in device_list:
            _test_adaptive_pooling_bfloat16(self, device, torch.nn.AdaptiveAvgPool2d, torch.contiguous_format)
            _test_adaptive_pooling_bfloat16(self, device, torch.nn.AdaptiveAvgPool2d, torch.channels_last)
            _test_adaptive_pooling_bfloat16(self, device, torch.nn.AdaptiveMaxPool2d, torch.contiguous_format)
            _test_adaptive_pooling_bfloat16(self, device, torch.nn.AdaptiveMaxPool2d, torch.channels_last)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @largeTensorTest('12GB', device='cuda')
    def test_adaptive_pooling_avg_nhwc_launch_config_backward(self):
        """Channels-last ``AdaptiveAvgPool2d`` backward on CUDA with a very tall (2**17 + 1 rows) input.

        Output and input gradient are compared with the contiguous reference.
        """
        nhwc_in = torch.randint(1, 10, (1, 32, 2 ** 17 + 1, 32), dtype=torch.float32, device="cuda")
        nhwc_in = nhwc_in.contiguous(memory_format=torch.channels_last).requires_grad_()
        grad_out = torch.randint(1, 10, (1, 32, 10, 32), dtype=torch.float32, device="cuda")
        nhwc_pool = torch.nn.AdaptiveAvgPool2d((10, 32)).cuda()

        # Reference path: identical values in the default contiguous layout
        nchw_in = nhwc_in.detach().clone().contiguous().requires_grad_(True)
        nchw_grad = grad_out.detach().clone().contiguous()
        nchw_pool = torch.nn.AdaptiveAvgPool2d((10, 32)).cuda()

        nhwc_out = nhwc_pool(nhwc_in)
        nhwc_out.backward(grad_out)
        nchw_out = nchw_pool(nchw_in)
        nchw_out.backward(nchw_grad)

        self.assertTrue(nhwc_out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(nchw_out.is_contiguous())
        self.assertEqual(nhwc_out, nchw_out)
        self.assertEqual(nhwc_in.grad, nchw_in.grad)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @largeTensorTest('12GB', device='cuda')
    def test_adaptive_pooling_avg_nhwc_launch_config_forward(self):
        """Channels-last ``AdaptiveAvgPool2d`` forward on CUDA with a very tall (2**17 + 1 rows) output.

        The output is compared with the contiguous reference.
        """
        target_size = (2 ** 17 + 1, 32)
        nhwc_in = torch.randint(1, 10, (1, 32, 16, 16), dtype=torch.float32, device="cuda")
        nhwc_in = nhwc_in.contiguous(memory_format=torch.channels_last).requires_grad_()
        nhwc_pool = torch.nn.AdaptiveAvgPool2d(target_size).cuda()

        # Reference path: identical values in the default contiguous layout
        nchw_in = nhwc_in.detach().clone().contiguous().requires_grad_(True)
        nchw_pool = torch.nn.AdaptiveAvgPool2d(target_size).cuda()

        nhwc_out = nhwc_pool(nhwc_in)
        nchw_out = nchw_pool(nchw_in)

        self.assertTrue(nhwc_out.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(nchw_out.is_contiguous())
        self.assertEqual(nhwc_out, nchw_out)

    def test_MaxUnpool2d_output_size(self):
        m = nn.MaxPool2d(3, stride=2, return_indices=True)
        mu = nn.MaxUnpool2d(3, stride=2)
        big_t = torch.rand(1, 1, 6, 6)
        big_t[0][0][4][4] = 100
        output_big, indices_big = m(big_t)
        self.assertRaises(RuntimeError, lambda: mu(output_big, indices_big))

        small_t = torch.rand(1, 1, 5, 5)
        for i in range(0, 4, 2):
            for j in range(0, 4, 2):
                small_t[:, :, i, j] = 100
        output_small, indices_small = m(small_t)
        for h in range(3, 10):
            for w in range(3, 10):
                if 4 <= h <= 6 and 4 <= w <= 6:
                    size = (h, w)
                    if h == 6:
                        size = (1, 1) + size

                    mu(output_small, indices_small, output_size=size)
                else:
                    self.assertRaises(ValueError, lambda: mu(output_small, indices_small, (h, w)))

    def test_max_unpool2d_nhwc_cpu(self):
        input = torch.randn(2, 10, 9, 9).float().cpu()
        input = input.contiguous(memory_format=torch.channels_last)
        ref_input = input.clone().contiguous()

        pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()
        ref_pool = nn.MaxPool2d(3, stride=2, return_indices=True).cpu()

        out, ind = pool(input)
        ref_out, ref_ind = ref_pool(ref_input)
        out.requires_grad_()
        ref_out.requires_grad_()

        unpool = nn.MaxUnpool2d(3, stride=2).cpu()
        ref_unpool = nn.MaxUnpool2d(3, stride=2).cpu()

        upout = unpool(out, ind)
        ref_upout = ref_unpool(ref_out, ref_ind)

        grad = torch.randn(upout.size()).float().cpu()
        grad = grad.contiguous(memory_format=torch.channels_last)
        ref_grad = grad.clone().contiguous()

        upout.backward(grad)
        ref_upout.backward(ref_grad)

        self.assertTrue(upout.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(ref_upout.is_contiguous())
        self.assertTrue(torch.allclose(upout, ref_upout))
        self.assertTrue(torch.allclose(out.grad, ref_out.grad))

    def test_max_unpool(self):
        with set_default_dtype(torch.double):
            # Test 1D
            output, indices = F.max_pool1d(torch.randn([1, 1, 4]), 2, stride=2, return_indices=True)
            self.assertEqual(F.max_unpool1d(output, indices, 2), F.max_unpool1d(output, indices, 2, stride=2))

            # Test list / tuple passed as argument to max_unpool1d
            input = torch.randn([1, 1, 5], requires_grad=True)
            output, indices = F.max_pool1d(input, 2, stride=2, return_indices=True)
            self.assertEqual(F.max_unpool1d(output, indices, 2, stride=2, output_size=input.shape),
                             F.max_unpool1d(output, indices, 2, stride=2, output_size=input.size()))
            gradcheck(F.max_unpool1d, (output, indices, 2), check_forward_ad=True)

            # Test 2D
            output, indices = F.max_pool2d(torch.randn(
                [1, 1, 4, 4], requires_grad=True), 2, stride=2, return_indices=True)
            self.assertEqual(F.max_unpool2d(output, indices, 2), F.max_unpool2d(output, indices, 2, stride=2))
            gradcheck(F.max_unpool2d, (output, indices, 2), check_forward_ad=True)

            # Test 3D
            output, indices = F.max_pool3d(torch.randn(
                [4, 4, 4, 4, 4], requires_grad=True), 2, stride=2, return_indices=True)
            self.assertEqual(F.max_unpool3d(output, indices, 2), F.max_unpool3d(output, indices, 2, stride=2))
            gradcheck(F.max_unpool3d, (output, indices, 2), check_forward_ad=True)


class TestPoolingNNDeviceType(NNTestCase):
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_adaptive_pooling_zero_batch(self, dtype, device):
        inp = torch.ones(0, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool1d(5).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        inp = torch.ones(0, 10, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool2d((5, 5)).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        inp = torch.ones(0, 10, 10, 10, dtype=dtype, device=device)
        mod = torch.nn.AdaptiveAvgPool3d((5, 5, 5)).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_batch(self, device):
        mod = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
        inp = torch.ones(0, 16, 50, 32, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected input"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_batch(self, device):
        mod = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5)).to(device)
        inp = torch.ones(0, 16, 50, 32, 32, device=device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected input"):
            inp = torch.randn(1, 0, 50, 32, 32, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool2d_zero_out_size(self, device):
        mod = nn.FractionalMaxPool2d([2, 2], output_size=[0, 1])
        inp = torch.rand([16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((16, 50, 0, 1), device=device))

    @onlyNativeDeviceTypes
    def test_FractionalMaxPool3d_zero_out_size(self, device):
        mod = nn.FractionalMaxPool3d([3, 2, 2], output_size=[0, 1, 1])
        inp = torch.rand([16, 50, 32, 32], device=device)
        out = mod(inp)
        self.assertEqual(out, torch.empty((16, 0, 1, 1), device=device))

    @onlyNativeDeviceTypes
    def test_MaxPool_zero_batch_dim(self, device):
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.MaxPool1d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        # 1D is supposed to be okay with 0 numel() inputs so dont test
        # error raising for that case.

        inp = torch.randn(0, 16, 50, 32, device=device)
        mod = torch.nn.MaxPool2d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

        inp = torch.ones(0, 16, 50, 44, 31, device=device)
        mod = torch.nn.MaxPool3d(3, stride=2).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.ones(1, 0, 50, 44, 31, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_MaxUnpool_zero_batch_dim(self, device):
        pool = torch.nn.MaxPool1d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool1d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        output.requires_grad_(True)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

        pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool2d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

        pool = torch.nn.MaxPool3d(2, stride=2, return_indices=True).to(device)
        unpool = torch.nn.MaxUnpool3d(2, stride=2).to(device)
        inp = torch.randn(0, 10, 10, 10, 10, requires_grad=True, device=device)
        output, indices = pool(inp)
        output.requires_grad_(True)
        unpool_out = unpool(output, indices)
        unpool_out.sum().backward()

        self.assertEqual(inp.grad, torch.zeros_like(inp))
        self.assertEqual(unpool_out, torch.zeros_like(unpool_out))

    @slowTest
    @onlyNativeDeviceTypes
    @skipCUDAIfRocm
    @parametrize_test("module_name,module_size,output_size,test_index,should_error", [
        subtest(('MaxUnpool2d', (2, 2), (1, 3, 4, 5), -1, True), name='case1'),
        subtest(('MaxUnpool2d', (2, 2), (1, 3, 4, 5), 2 * 2 * 4 * 5, True), name='case2'),
        subtest(('MaxUnpool2d', (2, 2), (1, 3, 4, 5), (2 * 2 * 4 * 5) - 1, False), name='case3'),
        subtest(('MaxUnpool2d', (2, 3), (2, 1, 4, 2), 2 * 3 * 4 * 2, True), name='case4'),
        subtest(('MaxUnpool2d', (2, 3), (2, 1, 4, 2), (2 * 3 * 4 * 2) - 1, False), name='case5'),

        subtest(('MaxUnpool3d', (2, 2, 2), (1, 3, 4, 5), -1, True), name='case6'),
        subtest(('MaxUnpool3d', (2, 2, 2), (1, 3, 4, 5), 2 * 2 * 2 * 3 * 4 * 5, True), name='case7'),
        subtest(('MaxUnpool3d', (2, 2, 2), (1, 3, 4, 5), (2 * 2 * 2 * 3 * 4 * 5) - 1, False), name='case8'),
        subtest(('MaxUnpool3d', (2, 2, 2), (2, 3, 4, 1), 2 * 2 * 2 * 3 * 4 * 1, True), name='case9'),
        subtest(('MaxUnpool3d', (2, 2, 2), (2, 3, 4, 1), (2 * 2 * 2 * 3 * 4 * 1) - 1, False), name='case10'),
    ])
    def test_MaxUnpool_index_errors(self, device, module_name, module_size, output_size, test_index, should_error):
        """Check how ``MaxUnpool2d`` / ``MaxUnpool3d`` react to ``test_index`` placed in the indices tensor.

        Args:
            device: device string the module and tensors are created on.
            module_name: name of the ``torch.nn`` unpooling class under test.
            module_size: kernel size passed to the module constructor.
            output_size: shape used for both the value tensor and the indices tensor.
            test_index: value written into the first element of the indices tensor.
            should_error: whether unpooling with that index is expected to fail.
        """
        # NOTE: CUDA tests need to be run in a subprocess because they cause device asserts
        if torch.device(device).type == 'cuda':
            # Device-side assertion text expected in the subprocess output, keyed by module name
            error_msgs = {
                'MaxUnpool2d': r'Assertion `maxind >= 0 && maxind < outputImageSize` failed',
                'MaxUnpool3d': r'Assertion `index >= 0 && index < outputImageSize` failed'}

            # Stand-alone script mirroring the non-CUDA branch below; synchronize() is the last call
            script = f'''
import torch
unpool = torch.nn.{module_name}({module_size}).to('{device}')
output = torch.rand({output_size}, dtype=torch.float32, device='{device}')
indices = torch.zeros({output_size}, dtype=torch.int64, device='{device}')
indices.flatten()[0] = {test_index}
unpool(output, indices)
torch.cuda.synchronize()
'''
            p = subprocess.run(
                [sys.executable, '-c', script],
                cwd=os.path.dirname(os.path.realpath(__file__)),
                capture_output=True,
                text=True,
            )

            # Search stdout and stderr together for the expected text
            output = p.stdout + '\n' + p.stderr

            error_msg = error_msgs[module_name]

            if should_error:
                self.assertIn(
                    error_msg,
                    output,
                    'The expected error was not found')
            else:
                # A passing case must not print anything containing 'Error'
                self.assertNotIn(
                    'Error',
                    output,
                    'Should not have produced an error')
        else:
            module_class = getattr(torch.nn, module_name)
            unpool = module_class(module_size).to(device)
            output = torch.rand(output_size, dtype=torch.float32, device=device)
            indices = torch.zeros(output_size, dtype=torch.int64, device=device)
            # Only the first element carries the index under test; all others stay 0
            indices.flatten()[0] = test_index

            if should_error:
                with self.assertRaisesRegex(RuntimeError, r'Found an invalid max index:'):
                    unpool(output, indices)
            else:
                unpool(output, indices)

    @onlyNativeDeviceTypes
    def test_AdaptiveMaxPool_zero_batch_dim(self, device):
        inp = torch.randn(0, 16, 50, device=device)
        mod = torch.nn.AdaptiveMaxPool1d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, device=device)
            mod(inp)

        inp = torch.randn(0, 16, 50, 32, device=device)
        mod = torch.nn.AdaptiveMaxPool2d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.randn(1, 0, 50, 32, device=device)
            mod(inp)

        inp = torch.ones(0, 16, 50, 44, 31, device=device)
        mod = torch.nn.AdaptiveMaxPool3d(3).to(device)
        _test_module_empty_input(self, mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Expected"):
            inp = torch.ones(1, 0, 50, 44, 31, device=device)
            mod(inp)

    @onlyNativeDeviceTypes
    def test_AvgPool2d_empty(self, device):
        avgpool = torch.nn.AvgPool2d(3, stride=2).to(device)
        inp = torch.randn(0, 16, 20, 32, device=device)
        _test_module_empty_input(self, avgpool, inp, check_size=False)

        clast_inp = torch.randn(0, 16, 20, 32, device=device).contiguous(memory_format=torch.channels_last)
        _test_module_empty_input(self, avgpool, clast_inp, check_size=False)

        # test with empty non-batch input
        with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
            inp = torch.randn(16, 0, 20, 32, device=device)
            avgpool(inp)

    def test_pooling_shape(self, device):
        ''' Test the output shape calculation for pooling functions '''

        # Checks output shape against expected for 1D, 2D and 3D
        def check(expected_out_shape, sizes, *args, **kwargs):
            for kernel in ['max', 'avg']:
                for i in [1, 2, 3]:
                    if hasattr(torch.nn.functional, f'{kernel}_pool{i}d'):
                        op = getattr(torch.nn.functional, f'{kernel}_pool{i}d')
                        t = torch.randn(sizes[:i + 2], device=device)
                        self.assertEqual(op(t, *args, **kwargs).shape, expected_out_shape[:i + 2])

        check((1, 1, 3, 3, 4), (1, 1, 5, 6, 7), kernel_size=1, stride=2, padding=0, ceil_mode=True)
        check((1, 1, 2, 3, 3), (1, 1, 3, 4, 5), kernel_size=2, stride=2, padding=1, ceil_mode=False)
        check((1, 1, 2, 3, 3), (1, 1, 3, 4, 5), kernel_size=2, stride=2, padding=1, ceil_mode=True)

        # Test case from issue https://github.com/pytorch/pytorch/issues/45357
        x = torch.randn(1, 1, 6, 7, device=device)
        y = torch.nn.functional.max_pool2d(x, 1, stride=(2, 2), padding=0, ceil_mode=True)
        self.assertEqual(y.size(), (1, 1, 3, 4))

    @onlyNativeDeviceTypes   # TODO: fix on XLA
    def test_adaptive_avg_pool2d_output_size_one(self, device):
        def helper(size, memory_format):
            x = torch.randint(1, 10, size, dtype=torch.float, device=device, requires_grad=True)
            if memory_format == 'non_contiguous':
                x = x[::2, ::2, ::2, ::2]
            else:
                x = x.to(memory_format=memory_format)

            net = torch.nn.AdaptiveAvgPool2d((1, 1))
            out = net(x)
            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))

            out.sum().backward()    # make sure it doesn't crash

            self.assertEqual(out, ref_out)
            if memory_format == torch.channels_last:
                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, c, c])
            else:
                self.assertTrue(out.is_contiguous())
                c = out.size(1)
                self.assertEqual(out.stride(), [c, 1, 1, 1])

        for mf in (torch.contiguous_format, torch.channels_last, 'non_contiguous'):
            helper((2, 3, 6, 6), mf)

    @onlyNativeDeviceTypes
    def test_adaptive_avg_pool3d_output_size_one(self, device):
        x = torch.randn((2, 3, 6, 6, 6), dtype=torch.float, device=device, requires_grad=True)

        net = torch.nn.AdaptiveAvgPool3d(1)
        out = net(x)
        ref_out = x.contiguous().mean((-1, -2, -3)).view(out.shape)

        out.sum().backward()    # make sure it doesn't crash

        self.assertEqual(out, ref_out)
        self.assertTrue(out.is_contiguous())
        c = out.size(1)
        self.assertEqual(out.stride(), [c, 1, 1, 1, 1])

    @expectedFailureMeta  # Runtime Error not raised for meta
    @onlyNativeDeviceTypes
    @dtypes(torch.uint8, torch.int8, torch.short, torch.int, torch.long)
    def test_adaptive_pooling_no_suppot_input(self, device, dtype):
        for numel in (2, 3):
            for pool_type in ('Max', 'Avg'):
                cls_name = 'Adaptive{}Pool{}d'.format(pool_type, numel)
                module_cls = getattr(nn, cls_name)
                output_size = (2,) * numel
                module = module_cls(output_size)
                input = torch.randn((4,) * (numel + 1), device=device).to(dtype)
                with self.assertRaisesRegex(RuntimeError, "not implemented"):
                    output = module(input)

    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    def test_avg_pool2d_nhwc(self, device, dtype):
        """``AvgPool2d`` on channels-last input must match the contiguous reference, forward and backward.

        ``helper`` builds a channels-last input, runs the pool and its backward,
        and compares the output and input gradient with the same computation on a
        contiguous copy. ``padding`` is passed on to the module and included in the
        shape of the incoming gradient, so the padded cases really use padding.
        """
        def helper(n, c, h, w, kernel_size, stride=None,
                   count_include_pad=True, divisor_override=None, padding=0):
            if stride is None:
                stride = kernel_size
            input = torch.randn(n, c, h, w, dtype=dtype, device=device)
            input = input.contiguous(memory_format=torch.channels_last).requires_grad_()
            # Floor-mode output extent, counting the zero padding on both sides
            out_h = (h + 2 * padding - kernel_size) // stride + 1
            out_w = (w + 2 * padding - kernel_size) // stride + 1
            grad = torch.randn(n, c, out_h, out_w, dtype=dtype, device=device)
            pool = torch.nn.AvgPool2d(kernel_size, stride=stride, padding=padding,
                                      count_include_pad=count_include_pad,
                                      divisor_override=divisor_override).to(device)

            # Reference path: identical values in the default contiguous layout
            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AvgPool2d(kernel_size, stride=stride, padding=padding,
                                          count_include_pad=count_include_pad,
                                          divisor_override=divisor_override).to(device)

            out = pool(input)
            out.backward(grad)
            ref_out = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(input.grad, ref_input.grad)

        # NOTE: AvgPool2d requires padding <= kernel_size // 2, so kernel 3 allows at most padding=1
        helper(4, 8, 8, 8, 3)
        helper(4, 8, 8, 8, 3, count_include_pad=False, padding=1)
        helper(4, 8, 8, 8, 3, count_include_pad=False, padding=1, stride=2)
        helper(4, 8, 8, 8, 3, divisor_override=42)
        helper(4, 8, 8, 8, 7)
        # ROCm 16GB MI25 hits OOM error. Clear caching allocator prior to running large subtest.
        if TEST_WITH_ROCM and 'cuda' in device:
            torch.cuda.empty_cache()
        helper(200, 512, 28, 28, 2)
        helper(4, 8, 7, 7, 3, stride=1)
        helper(4, 8, 7, 7, 3, padding=1, stride=1)
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_max_pool1d_corner_cases(self, device, dtype):
        """Check MaxPool1d on very small inputs with padding, dilation and ceil_mode."""
        neg_inf = float('-inf')
        # Each entry: (input values, pooling args, expected output)
        # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode)
        cases = [
            ([[1]], (1, None, 0, 1, False, False), [[1]]),
            ([[1]], (2, None, 1, 2, False, False), [[neg_inf]]),
            ([[1], [1]], (2, None, 1, 2, False, False), [[neg_inf], [neg_inf]]),
            ([[1, 2]], (2, 1, 1, 2, False, False), [[2, 1]]),
            ([[1, 2]], (2, 2, 1, 2, False, True), [[2, 2]]),
        ]

        for values, pool_args, wanted in cases:
            pool = torch.nn.MaxPool1d(*pool_args)
            x = torch.tensor(values, device=device, dtype=dtype)
            expected = torch.tensor(wanted, device=device, dtype=dtype)
            self.assertEqual(pool(x), expected)

    @onlyCPU
    @dtypes(torch.float, torch.double)
    def test_max_pool1d(self, device, dtype):
        """Compare MaxPool1d output with the output of its return_indices=True variant."""
        # FIXME For now compare against max_pool1d with indices
        def check(x, *args, **kwargs):
            plain = torch.nn.MaxPool1d(*args, **kwargs)
            with_indices = torch.nn.MaxPool1d(*args, **kwargs, return_indices=True)
            self.assertEqual(plain(x), with_indices(x)[0])

        # Randomly drawn 3-d input shapes and pooling hyper-parameters.
        sizes = [random.sample(range(8, 128), 3) for _ in range(3)]
        kernel_sizes = random.sample(range(1, 5), 3)
        strides = random.sample(range(1, 5), 3)
        dilations = random.sample(range(1, 5), 3)

        configs = itertools.product(sizes, kernel_sizes, strides, dilations, [True, False])
        for size, kernel_size, stride, dilation, ceil_mode in configs:
            # Draw one padding value (as a 1-element list) from [0, kernel_size / 2].
            max_padding = math.floor(kernel_size / 2)
            padding = random.sample(range(0, max_padding + 1), 1)
            x = torch.randn(size, device=device, dtype=dtype)
            check(x, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode)

        # Non-contiguous test
        strided = torch.randn(5, 151, 33, device=device, dtype=dtype)[::2, ::3, ::2]
        for candidate in (strided, strided.transpose(1, 2)):
            check(candidate, 3, 2, 1, 2, ceil_mode=True)

    @onlyCUDA
    def test_max_pool2d(self, device):
        """Check MaxPool2d forward and backward on `device` against a CPU run."""
        def helper(n, c, h, w, ks):
            # Allocate on the device handed to the test instead of a hard-coded
            # 'cuda' string, so the test honours the device it was instantiated for.
            x = torch.randn(n, c, h, w, device=device, dtype=torch.float, requires_grad=True)
            ref_x = x.detach().clone().cpu().requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks)

            y = pool(x)
            ref_y = pool(ref_x)

            y.sum().backward()
            ref_y.sum().backward()

            self.assertEqual(y, ref_y)
            self.assertEqual(x.grad, ref_x.grad)

        helper(2, 8, 4, 4, ks=2)
        helper(1, 100000, 32, 32, ks=4)
        helper(1, 100000, 1, 4, ks=(1, 4))  # test for max_pool1d

    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    def test_max_pool2d_nhwc(self, device, dtype):
        """Compare MaxPool2d (with indices) on channels-last input with the contiguous path."""
        def helper(n, c, h, w, kernel_size, stride=None):
            stride = kernel_size if stride is None else stride
            out_h = (h - kernel_size) // stride + 1
            out_w = (w - kernel_size) // stride + 1

            nhwc_input = torch.randn(n, c, h, w, dtype=dtype, device=device)
            nhwc_input = nhwc_input.contiguous(memory_format=torch.channels_last).requires_grad_()
            grad = torch.randn(n, c, out_h, out_w, dtype=dtype, device=device)
            nhwc_pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(device)

            # Reference run: same values, default contiguous layout.
            nchw_input = nhwc_input.detach().clone().contiguous().requires_grad_(True)
            nchw_grad = grad.detach().clone().contiguous()
            nchw_pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(device)

            nhwc_out, nhwc_ind = nhwc_pool(nhwc_input)
            nhwc_out.backward(grad)
            nchw_out, nchw_ind = nchw_pool(nchw_input)
            nchw_out.backward(nchw_grad)

            self.assertTrue(nhwc_out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(nchw_out.is_contiguous())
            self.assertTrue(nhwc_ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(nchw_ind.is_contiguous())
            self.assertEqual(nhwc_out, nchw_out)
            self.assertEqual(nhwc_ind, nchw_ind)
            self.assertEqual(nhwc_input.grad, nchw_input.grad)

        helper(4, 8, 8, 8, 7)
        helper(200, 512, 28, 28, 2)
        helper(4, 8, 7, 7, 3, stride=1)
        helper(10, 512, 31, 31, 3, stride=2)
        helper(1, 129, 8, 8, 3, stride=2)

    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.float, torch.double)
    @onlyCUDA
    def test_max_pool3d_ndhwc(self, device, dtype):
        """Compare MaxPool3d (with indices) on channels_last_3d input with the contiguous path."""
        def is_ndhwc(t):
            # A 4-d tensor gets a leading dim added before the 5-d layout check.
            if len(t.shape) == 4:
                t = t.unsqueeze(0)
            return t.is_contiguous(memory_format=torch.channels_last_3d)

        def helper(n, c, h, w, d, kernel_size, stride=None):
            # A falsy n requests a 4-d input: the tensor is created with a
            # leading dim of 1 which is squeezed away afterwards.
            unbatched = not n
            batch = 1 if unbatched else n

            input = torch.randn(batch, c, d, h, w, dtype=dtype, device=device)
            input = input.contiguous(memory_format=torch.channels_last_3d).requires_grad_()
            if unbatched:
                input = input.squeeze(0).detach().clone().requires_grad_()

            if isinstance(kernel_size, int):
                kernel_size = [kernel_size] * 3
            if stride is None:
                stride = kernel_size
            elif isinstance(stride, int):
                stride = [stride] * 3

            out_dims = [(extent - k) // s + 1
                        for extent, k, s in zip((d, h, w), kernel_size, stride)]
            grad = torch.randn(batch, c, *out_dims, dtype=dtype, device=device)
            grad = grad.contiguous(memory_format=torch.channels_last_3d)
            if unbatched:
                grad = grad.squeeze(0)
            pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool3d(kernel_size, stride, return_indices=True).to(device)

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(is_ndhwc(out))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(is_ndhwc(ind))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            # Half-precision gradients are compared with explicit, looser tolerances.
            tolerances = {'atol': 0.05, 'rtol': 0.01} if dtype == torch.half else {}
            self.assertEqual(input.grad, ref_input.grad, **tolerances)

        helper(4, 8, 8, 8, 8, 7)
        helper(4, 8, 8, 8, 8, (5, 6, 7))
        helper(1, 8, 8, 8, 8, (5, 6, 7))
        helper(0, 6, 12, 13, 14, (5, 6, 7))
        helper(4, 8, 7, 7, 7, 3, stride=1)
        helper(10, 128, 19, 19, 19, 3, stride=2)
        helper(10, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(1, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(0, 128, 19, 19, 19, (1, 2, 3), stride=2)
        helper(1, 79, 4, 4, 4, 3, stride=2)
        helper(0, 79, 4, 4, 4, 3, stride=2)

    @onlyCPU
    def test_max_pool2d_bfloat16(self, device):
        """Check bfloat16 MaxPool2d against the same pooling applied in float32."""
        def helper(n, c, h, w, kernel_size, stride, memory_format):
            bf16_input = torch.randn(n, c, h, w, dtype=torch.float32, device=device).bfloat16()
            bf16_input = bf16_input.to(memory_format=memory_format).requires_grad_()
            # Float32 copy of the already-rounded bfloat16 values.
            fp32_input = bf16_input.detach().clone().float().requires_grad_(True)

            pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(device)

            bf16_out, bf16_ind = pool(bf16_input)
            bf16_out.sum().backward()
            fp32_out, fp32_ind = pool(fp32_input)
            fp32_out.sum().backward()

            self.assertTrue(bf16_out.is_contiguous(memory_format=memory_format))
            self.assertEqual(bf16_out.dtype, torch.bfloat16)
            self.assertEqual(bf16_input.grad.dtype, torch.bfloat16)
            self.assertEqual(bf16_out, fp32_out.bfloat16())
            self.assertEqual(bf16_ind, fp32_ind)
            self.assertEqual(bf16_input.grad, fp32_input.grad.bfloat16())

        helper(4, 30, 8, 8, 7, 1, torch.contiguous_format)
        helper(4, 65, 8, 8, 7, 1, torch.channels_last)
        helper(1, 19, 20, 10, 8, 2, torch.contiguous_format)
        helper(1, 19, 20, 10, 8, 2, torch.channels_last)

    @onlyCUDA
    def test_max_pool2d_indices(self, device):
        """Check MaxPool2d values, indices and gradients on `device` against a CPU run.

        Passing ``n=None`` to the helper builds a 3-d (c, h, w) input instead
        of a 4-d (n, c, h, w) one.
        """
        def helper(n, c, h, w, ks):
            shape = (c, h, w) if n is None else (n, c, h, w)
            # Use the device handed to the test rather than a hard-coded 'cuda'.
            x = torch.randn(*shape, device=device, dtype=torch.float, requires_grad=True)

            ref_x = x.detach().clone().cpu().requires_grad_()

            pool = torch.nn.MaxPool2d(kernel_size=ks, return_indices=True)

            y, idx = pool(x)
            ref_y, ref_idx = pool(ref_x)

            y.sum().backward()
            ref_y.sum().backward()

            self.assertEqual(y, ref_y)
            self.assertEqual(idx, ref_idx)  # assertEqual implicitly compares shape for tensors
            self.assertEqual(x.grad, ref_x.grad)

        helper(2, 8, 4, 4, ks=2)
        helper(None, 3, 50, 50, ks=5)

    @onlyCPU
    def test_avg_pool2d_bfloat16(self, device):
        """Check bfloat16 AvgPool2d against the same pooling applied in float32."""
        def helper(n, c, h, w, kernel_size, stride, memory_format):
            bf16_input = torch.randn(n, c, h, w, dtype=torch.float32, device=device).bfloat16()
            bf16_input = bf16_input.to(memory_format=memory_format).requires_grad_()
            # Float32 copy of the already-rounded bfloat16 values.
            fp32_input = bf16_input.detach().clone().float().requires_grad_(True)

            pool = torch.nn.AvgPool2d(kernel_size, stride).to(device)

            bf16_out = pool(bf16_input)
            bf16_out.sum().backward()
            fp32_out = pool(fp32_input)
            fp32_out.sum().backward()

            self.assertTrue(bf16_out.is_contiguous(memory_format=memory_format))
            self.assertEqual(bf16_out.dtype, torch.bfloat16)
            self.assertEqual(bf16_input.grad.dtype, torch.bfloat16)
            self.assertEqual(bf16_out, fp32_out.bfloat16())
            self.assertEqual(bf16_input.grad, fp32_input.grad.bfloat16())

        helper(4, 30, 8, 8, 7, 1, torch.contiguous_format)
        helper(4, 65, 8, 8, 7, 1, torch.channels_last)
        helper(1, 19, 20, 10, 8, 2, torch.contiguous_format)
        helper(1, 19, 20, 10, 8, 2, torch.channels_last)

    @dtypes(torch.float, torch.double)
    def test_adaptive_pooling_max_nhwc(self, device, dtype):
        """Compare AdaptiveMaxPool2d on channels-last input with the contiguous path.

        With ``contig=False`` the input and the incoming gradient are sliced
        to every other channel, making them non-contiguous views.
        """
        def helper(n, c, h, w, output_height, output_width, contig):
            input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
            input = input.contiguous(memory_format=torch.channels_last)
            # The gradient must follow the requested batch/channel sizes; a
            # hard-coded (4, 8, ...) shape only works for one choice of n and c.
            grad = torch.randint(1, 10, (n, c, output_height, output_width), device=device, dtype=dtype)
            grad = grad.contiguous(memory_format=torch.channels_last)
            if not contig:
                input = input[:, ::2, :, :]
                grad = grad[:, ::2, :, :]
            input.requires_grad_(True)
            pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.AdaptiveMaxPool2d((output_height, output_width), return_indices=True).to(device)

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        for contig in [True, False]:
            helper(4, 8, 10, 10, 7, 7, contig)
            helper(4, 8, 9, 14, 5, 8, contig)
            helper(4, 8, 11, 11, 1, 1, contig)

    @dtypes(torch.float, torch.double)
    def test_pooling_max_nhwc(self, device, dtype):
        """Compare padded, dilated MaxPool2d on channels-last input with the contiguous path."""
        def helper(n, c, h, w, kernel_size, stride, padding, dilation, contig, device):
            def out_extent(extent, axis):
                # Output size for ceil_mode=False along one spatial axis.
                window = dilation[axis] * (kernel_size[axis] - 1)
                return math.floor((extent + 2 * padding[axis] - window - 1) / stride[axis] + 1)

            output_height = out_extent(h, 0)
            output_width = out_extent(w, 1)

            input = torch.randint(1, 10, (n, c, h, w), device=device, dtype=dtype)
            input = input.contiguous(memory_format=torch.channels_last)
            grad = torch.randint(1, 10, (n, c, output_height, output_width), device=device, dtype=dtype)
            grad = grad.contiguous(memory_format=torch.channels_last)
            if not contig:
                # Keep every other channel so both tensors become non-contiguous views.
                input, grad = input[:, ::2, :, :], grad[:, ::2, :, :]
            input.requires_grad_(True)

            pool_args = (kernel_size, stride, padding, dilation)
            pool = torch.nn.MaxPool2d(*pool_args, return_indices=True, ceil_mode=False)

            ref_input = input.detach().clone().contiguous().requires_grad_(True)
            ref_grad = grad.detach().clone().contiguous()
            ref_pool = torch.nn.MaxPool2d(*pool_args, return_indices=True, ceil_mode=False).to(device)

            out, ind = pool(input)
            out.backward(grad)
            ref_out, ref_ind = ref_pool(ref_input)
            ref_out.backward(ref_grad)

            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_out.is_contiguous())
            self.assertTrue(ind.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(ref_ind.is_contiguous())
            self.assertEqual(out, ref_out)
            self.assertEqual(ind, ref_ind)
            self.assertEqual(input.grad, ref_input.grad)

        for contig in [True, False]:
            helper(4, 8, 10, 10, (2, 2), (1, 1), (1, 1), (2, 2), contig, device)
            helper(4, 8, 9, 14, (2, 2), (1, 1), (1, 1), (2, 2), contig, device)
            helper(4, 8, 11, 11, (4, 4), (2, 2), (2, 2), (2, 2), contig, device)

    @onlyCUDA
    def test_pool3d_size_one_feature_dim(self, device):
        """Run 3-d max/avg pooling on tensors with an unusual stride in the size-1 dim."""
        # Tests crazy strides for feature dim of size 1
        strange_strides = [30, 1234, 6, 2, 1]
        base = torch.randn(7, 1, 5, 3, 2, device=device)
        on_device = base.as_strided(base.size(), strange_strides)
        on_cpu = base.cpu().as_strided(base.size(), strange_strides)

        window = (5, 1, 1)
        to_test = {
            'max_pool3d': F.max_pool3d,
            'avg_pool3d': F.avg_pool3d,
        }

        for test, pool_fn in to_test.items():
            # Should not crash
            device_out = pool_fn(on_device, window, stride=window)
            cpu_out = pool_fn(on_cpu, window, stride=window)
            self.assertEqual(device_out, cpu_out.to(device), msg=test)

    @onlyCUDA
    @largeTensorTest('18GB')
    @largeTensorTest('180GB', 'cpu')
    def test_pool3d_large_size_int64(self, device):
        """Run half-precision max_pool3d on a very large tensor and compare with a float CPU run."""
        # See https://github.com/pytorch/pytorch/issues/52822
        big_input = torch.randn(70, 32, 100, 100, 100, dtype=torch.half, device=device, requires_grad=True)
        pooled = F.max_pool3d(big_input, 5)
        grad_out = torch.randn_like(pooled, dtype=torch.half)
        torch.cuda.synchronize()
        pooled.backward(grad_out)
        torch.cuda.synchronize()

        # max_pool3d_cpu is not implemented for half
        cpu_input = big_input.detach().cpu().float().requires_grad_(True)
        cpu_grad_out = grad_out.cpu().float()
        cpu_pooled = F.max_pool3d(cpu_input, 5)
        cpu_pooled.backward(cpu_grad_out)

        # dtypes differ (half vs. float), so only values are compared.
        self.assertEqual(pooled, cpu_pooled, exact_dtype=False)
        self.assertEqual(big_input.grad, cpu_input.grad, exact_dtype=False)

    @onlyCUDA
    def test_AvgPool3d_backward_after_cat_dim1_device(self, device):
        """Run avg_pool3d backward with a gradient whose dim-0 stride is doubled."""
        # x has to have batch_size 1 to test contiguous checks
        x = torch.randn(1, 3, 4, 4, 4, device=device, requires_grad=True)
        y = F.avg_pool3d(x, kernel_size=3, padding=1, stride=2)

        grad = torch.randn(y.size(), device=device)
        # increase the stride in dimension 0. the tensor is still contiguous because size[0] is 1
        batch_stride, *inner_strides = grad.stride()
        widened = [batch_stride * 2] + inner_strides
        grad.set_(grad.storage(), 0, grad.size(), widened)
        assert grad.is_contiguous()

        y.backward(grad)

    def test_pooling_size_empty(self, device):
        """Every adaptive pooling function must raise RuntimeError for an empty output size."""
        t = torch.rand([1, 2, 3, 4], device=device)
        adaptive_pools = (
            F.adaptive_avg_pool1d,
            F.adaptive_avg_pool2d,
            F.adaptive_avg_pool3d,
            F.adaptive_max_pool1d,
            F.adaptive_max_pool2d,
            F.adaptive_max_pool3d,
        )
        for pool_fn in adaptive_pools:
            with self.assertRaises(RuntimeError):
                pool_fn(t, [])

    def _test_maxpool_indices(self, num_dim, adaptive=False, device="cpu", dtype=torch.float):
        """Shared body for the (Adaptive)MaxPool{1,2,3}d return_indices tests.

        Pools an arange-filled input of shape (2, 2, 4, ...) with size argument
        2, checks values and indices for 1-d and 2-d pooling, checks the
        gradient, checks that backward fails after the indices are modified,
        and finally checks pooling of a single -inf element.
        """
        def build_expected_indices(dim, dtype):
            if dim == 1:
                return torch.tensor([1, 3], dtype=dtype).repeat(2, 2, 1)
            if dim == 2:
                return torch.tensor([[5, 7], [13, 15]], dtype=dtype).repeat(2, 2, 1, 1)

        def build_expected_grad(dim, dtype):
            if dim == 1:
                return torch.tensor([0, 1, 0, 1], dtype=dtype).repeat(2, 2, 1)
            # Built recursively: zero and lower-dim slices are stacked alternately along dim 2.
            lower = build_expected_grad(dim - 1, dtype=dtype)
            zero = torch.zeros(lower.size(), dtype=dtype)
            return torch.stack((zero, lower, zero, lower), 2)

        def build_expected_output(dim, dtype):
            if dim == 1:
                return torch.arange(2, 17, 2, dtype=dtype).view(2, 2, 2)
            if dim == 2:
                col = torch.arange(6, 63, 8, dtype=dtype)
                return torch.stack([col, col + 2], 1).view(2, 2, 2, 2)

        cls_name = ('AdaptiveMaxPool{}d' if adaptive else 'MaxPool{}d').format(num_dim)
        module = getattr(nn, cls_name)(2, return_indices=True).to(device, dtype=dtype)
        numel = 4 ** (num_dim + 1)
        input = torch.arange(1, numel + 1).view(2, 2, *repeat(4, num_dim)).to(device, dtype=dtype)
        input_var = input.clone().detach().requires_grad_()

        # Check forward
        output, indices = module(input_var)
        if num_dim != 3:
            want_indices = build_expected_indices(num_dim, dtype=indices.data.dtype)
            want_output = build_expected_output(num_dim, dtype=output.data.dtype)
            self.assertEqual(indices.dim(), input.dim())
            self.assertEqual(indices.data.squeeze(), want_indices)
            self.assertEqual(output.data.squeeze(), want_output)
        self.assertTrue(output.requires_grad)
        self.assertFalse(indices.requires_grad)

        # Make sure backward works
        grad_output = torch.ones(output.size(), device=device, dtype=dtype)
        output.backward(grad_output, retain_graph=True)
        want_grad = build_expected_grad(num_dim, dtype=input_var.grad.data.dtype)
        self.assertEqual(input_var.grad.data, want_grad.view_as(input))

        # Make sure backward after changing indices will result in an error
        indices.add_(1)
        with self.assertRaises(RuntimeError):
            output.backward(grad_output)

        # Make sure -Infinity is handled correctly
        neg_inf = float("-inf")
        single_element_cases = (
            (nn.MaxPool1d, [[[neg_inf]]]),
            (nn.MaxPool2d, [[[neg_inf]]]),
            (nn.MaxPool3d, [[[[neg_inf]]]]),
        )
        for pool_cls, nested in single_element_cases:
            t = torch.tensor(nested)
            m = pool_cls(kernel_size=1, return_indices=True)
            output, indices = m(t)
            first = (0,) * t.dim()
            self.assertEqual(output[first], neg_inf)
            self.assertEqual(indices[first], 0)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool1d_indices(self, device, dtype):
        """Run the shared max-pool indices check for MaxPool1d."""
        self._test_maxpool_indices(num_dim=1, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool2d_indices(self, device, dtype):
        """Run the shared max-pool indices check for MaxPool2d."""
        self._test_maxpool_indices(num_dim=2, device=device, dtype=dtype)

    @skipIfMps
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_MaxPool3d_indices(self, device, dtype):
        """Run the shared max-pool indices check for MaxPool3d."""
        self._test_maxpool_indices(num_dim=3, device=device, dtype=dtype)

    @skipIfMps
    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @dtypes(torch.float)
    def test_AdaptiveMaxPool1d_indices(self, device, dtype):
        """Run the shared max-pool indices check for AdaptiveMaxPool1d."""
        self._test_maxpool_indices(num_dim=1, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_AdaptiveMaxPool2d_indices(self, device, dtype):
        """Run the shared max-pool indices check for AdaptiveMaxPool2d."""
        self._test_maxpool_indices(num_dim=2, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_AdaptiveMaxPool3d_indices(self, device, dtype):
        """Run the shared max-pool indices check for AdaptiveMaxPool3d."""
        self._test_maxpool_indices(num_dim=3, adaptive=True, device=device, dtype=dtype)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_maxpool_indices_no_batch_dim(self, device, dtype):
        """Indices for an input without a batch dim must equal those of the same input as a batch of one."""
        # (module class, unbatched input shape); every module is built with size argument 3.
        specs = [
            (nn.MaxPool1d, (3, 5)),
            (nn.MaxPool2d, (3, 5, 6)),
            (nn.MaxPool3d, (3, 5, 6, 7)),
            (nn.AdaptiveMaxPool1d, (3, 5)),
            (nn.AdaptiveMaxPool2d, (3, 5, 6)),
            (nn.AdaptiveMaxPool3d, (3, 5, 6, 7)),
        ]
        max_pool_cases = [
            (pool_cls(3, return_indices=True), torch.randn(*shape, device=device, dtype=dtype))
            for pool_cls, shape in specs
        ]

        for module, unbatched in max_pool_cases:
            _, indices_no_batch = module(unbatched)
            _, indices_single_batch = module(unbatched.unsqueeze(0))
            self.assertEqual(indices_no_batch, indices_single_batch.squeeze(0))

    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # TODO: Fails on XLA
    def test_max_pool_nan_inf(self, device, dtype):
        """Max pooling of an all-NaN or all -inf input must yield NaN or inf respectively.

        Each case is run with and without requires_grad on the input.
        """
        fill_checks = ((nan, math.isnan), (-inf, math.isinf))
        for adaptive, num_dim in itertools.product(['', 'adaptive_'], [1, 2, 3]):
            fn = getattr(F, '{}max_pool{}d'.format(adaptive, num_dim))
            # Adaptive variants take an output size, the others a kernel size.
            size_arg = 1 if adaptive else 3
            shape = [1, 1] + num_dim * [3]

            for fill, has_property in fill_checks:
                x = torch.full(shape, fill, device=device, dtype=dtype, requires_grad=True)
                res = fn(x, size_arg)
                res.backward(torch.randn_like(res))
                self.assertTrue(has_property(res.item()))

                x.requires_grad_(False)
                res = fn(x, size_arg)
                self.assertTrue(has_property(res.item()))

    @expectedFailureMeta  # RuntimeError: Unrecognized tensor type ID: Meta
    @onlyNativeDeviceTypes
    def test_fractional_max_pool2d(self, device):
        """Check shapes, gradients and argument validation of F.fractional_max_pool2d."""
        with set_default_dtype(torch.double):
            batched = torch.randn(1, 2, 7, 7, requires_grad=True, device=device)
            samples = batched.new(1, 2, 2).uniform_()

            def func(x):
                # Fixed random samples keep the pooling regions identical across calls.
                return F.fractional_max_pool2d(
                    x, (2, 2), output_size=(3, 3), _random_samples=samples)

            self.assertEqual(func(batched).shape, (1, 2, 3, 3))
            gradcheck(func, [batched])
            gradgradcheck(func, [batched])

            unbatched = torch.randn(2, 7, 7, requires_grad=True, device=device)
            self.assertEqual(func(unbatched).shape, (2, 3, 3))
            if self.device_type != 'cuda':
                # Reference: https://github.com/pytorch/pytorch/issues/52427
                # Raises -> RuntimeError: TensorAccessor expected 4 dims but tensor has 3
                # on CUDA in gradcheck
                gradcheck(func, [unbatched])
                gradgradcheck(func, [unbatched])

            # Incorrect kernel_size
            for bad_kernel in [(), (1,)]:
                with self.assertRaisesRegex(RuntimeError, "kernel_size must either"):
                    F.fractional_max_pool2d(unbatched, kernel_size=bad_kernel, output_size=(3, 3),
                                            _random_samples=samples)

            # Incorrect output_size
            err_large_msg = "too large relative to input "
            err_out_size_msg = "output_size must either"
            bad_output_sizes = [
                ((9, 3), err_large_msg + "height"),
                ((3, 9), err_large_msg + "width"),
                ((3,), err_out_size_msg),
                ((), err_out_size_msg),
            ]
            for bad_size, msg in bad_output_sizes:
                with self.assertRaisesRegex(RuntimeError, msg):
                    F.fractional_max_pool2d(unbatched, (2, 2), output_size=bad_size, _random_samples=samples)

    @expectedFailureMeta  # RuntimeError: Unrecognized tensor type ID: Meta
    @onlyNativeDeviceTypes
    def test_fractional_max_pool3d(self, device):
        """Check shapes, gradients and argument validation of F.fractional_max_pool3d."""
        with set_default_dtype(torch.double):
            batched = torch.randn(1, 2, 7, 7, 7, requires_grad=True, device=device)
            samples = batched.new(1, 2, 3).uniform_()

            def func(x):
                # Fixed random samples keep the pooling regions identical across calls.
                return F.fractional_max_pool3d(
                    x, (2, 2, 2), output_size=(3, 3, 3), _random_samples=samples)

            self.assertEqual(func(batched).shape, (1, 2, 3, 3, 3))
            gradcheck(func, [batched])
            gradgradcheck(func, [batched])

            unbatched = torch.randn(2, 7, 7, 7, requires_grad=True, device=device)
            self.assertEqual(func(unbatched).shape, (2, 3, 3, 3))
            gradcheck(func, [unbatched])
            gradgradcheck(func, [unbatched])

            # Incorrect kernel_size
            for bad_kernel in [(), (1,), (1, 1)]:
                with self.assertRaisesRegex(RuntimeError, "kernel_size must either"):
                    F.fractional_max_pool3d(unbatched, kernel_size=bad_kernel, output_size=(3, 3, 3),
                                            _random_samples=samples)

            # Incorrect output_size
            err_large_msg = "too large relative to input "
            err_out_size_msg = "output_size must either"
            bad_output_sizes = [
                ((9, 3, 3), err_large_msg + "time"),
                ((3, 9, 3), err_large_msg + "height"),
                ((3, 3, 9), err_large_msg + "width"),
                ((3, 3), err_out_size_msg),
                ((3,), err_out_size_msg),
                ((), err_out_size_msg),
            ]
            for bad_size, msg in bad_output_sizes:
                with self.assertRaisesRegex(RuntimeError, msg):
                    F.fractional_max_pool3d(unbatched, (2, 2, 2), output_size=bad_size, _random_samples=samples)

    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    @dtypes(torch.float)
    @onlyNativeDeviceTypes  # TODO: Fails on XLA
    def test_fractional_max_pool_nan_inf(self, device, dtype):
        """Fractional max pooling of an all-NaN or all -inf input must yield NaN or inf respectively."""
        fill_checks = ((nan, math.isnan), (-inf, math.isinf))
        for num_dim in [2, 3]:
            module_cls = getattr(nn, 'FractionalMaxPool{}d'.format(num_dim))
            fn = module_cls(kernel_size=2, output_size=1)
            shape = [1, 1] + num_dim * [3]

            for fill, has_property in fill_checks:
                x = torch.full(shape, fill, device=device, dtype=dtype, requires_grad=True)
                res = fn(x)
                res.backward(torch.randn_like(res))
                self.assertTrue(has_property(res.item()))

    @onlyNativeDeviceTypes  # TODO: RuntimeError message different on XLA
    def test_pooling_zero_stride(self, device):
        """A stride of zero must raise RuntimeError in both functional and module pooling."""
        zero_stride_msg = r"stride should not be zero|stride must be greater than zero"
        for op, num_dim in itertools.product(('max', 'avg'), [1, 2, 3]):
            x = torch.ones([1, 2] + num_dim * [4], device=device, dtype=torch.float)

            # Functional form, e.g. F.max_pool2d
            fn = getattr(F, '{}_pool{}d'.format(op, num_dim))
            with self.assertRaisesRegex(RuntimeError, zero_stride_msg):
                fn(x, kernel_size=2, stride=0)

            # Module form, e.g. nn.MaxPool2d; the module is constructed outside
            # the raises-context, only the call is expected to fail.
            fn_module = getattr(nn, '{}Pool{}d'.format(op.title(), num_dim))(kernel_size=2, stride=0)
            with self.assertRaisesRegex(RuntimeError, zero_stride_msg):
                fn_module(x)

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_pool_large_size(self, device, dtype):
        """With kernel size 1 and stride 1 the pooled extent must equal a very large input extent."""
        # 16777217 is the smallest integer not expressible in float32
        large_extent = 16777217
        for op, num_dim in itertools.product(('max', 'avg'), [1, 2, 3]):
            pool_fn = getattr(F, '{}_pool{}d'.format(op, num_dim))
            shape = [1, 1, large_extent] + (num_dim - 1) * [1]
            x = torch.ones(shape, device=device, dtype=dtype)
            res = pool_fn(x, 1, stride=1, padding=0)
            # check if the output shape was still computed correctly
            self.assertEqual(x.shape[2], res.shape[2])

    @onlyCUDA
    @largeTensorTest('6GB')
    def test_pooling_large(self, device):
        """Run several 2-d pooling modules on an input with more than 2**31 - 1 elements."""
        def helper(pool):
            # Allocate on the device handed to the test instead of a hard-coded "cuda".
            inp = torch.randn(2**7 + 10, 2**8, 2**8, 2**8, dtype=torch.half, device=device)
            self.assertTrue(inp.numel() > 2**31 - 1)
            out = pool(inp)  # result is only held until the synchronize below
            torch.cuda.synchronize()    # asserts test finishes normally without raising errors

        helper(nn.MaxPool2d(4, 4))
        helper(nn.AvgPool2d(4, 4))
        helper(nn.FractionalMaxPool2d(4, 4))
        helper(nn.AdaptiveMaxPool2d((2**6, 2**6)))
        helper(nn.AdaptiveAvgPool2d((2**6, 2**6)))

    @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
    @skipIfMps
    @dtypes(torch.float)
    def test_pool_invalid_size(self, device, dtype):
        """Pooling configurations whose output would be empty must raise RuntimeError."""
        for op, num_dim in itertools.product(('max', 'avg'), [1, 2, 3]):
            fn_name = '{}_pool{}d'.format(op, num_dim)
            if op == 'max':
                # New implementation without indices supports empty tensors
                # TODO(Heitor) change once with_indices code is updated
                fn_name += '_with_indices'
            pool_fn = getattr(F, fn_name)
            # use a configuration that gives zero outputs only
            # when doing a correct floor division by the stride
            x = torch.ones([1, 1] + num_dim * [4], device=device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, r"too small|smaller than"):
                try:
                    pool_fn(x, 3, stride=2, padding=0, dilation=2)
                except TypeError:
                    # some implementations do not support dilation
                    pool_fn(x, 6, stride=2, padding=0)

    @onlyCUDA
    def test_pooling_bfloat16(self, device):
        # Run the shared bfloat16 check over the average and adaptive average
        # pooling modules in 1D, 2D and 3D. Each entry pairs a module with the
        # input dimensions it is exercised on.
        cases = (
            (torch.nn.AvgPool1d(3, stride=2), (8, 4, 16)),
            (torch.nn.AvgPool2d(3, stride=2), (8, 4, 16, 16)),
            (torch.nn.AvgPool3d(3, stride=2), (8, 4, 16, 16, 16)),
            (torch.nn.AdaptiveAvgPool1d(3), (8, 4, 16)),
            (torch.nn.AdaptiveAvgPool2d((3, 5)), (8, 4, 16, 16)),
            (torch.nn.AdaptiveAvgPool3d((3, 5, 7)), (8, 4, 16, 16, 16)),
        )
        for module, dims in cases:
            _test_bfloat16_ops(self, module, device, inp_dims=dims, prec=0.05)

    def test_maxpool3d_non_square_backward(self, device):
        # previous CUDA routine of this backward calculates kernel launch grid size
        # with last two dimensions interchanged, so the tailing along the longer dim
        # get ignored. Here we test whether every position gets gradient.
        #
        # The input is 4-D, so its valid axes are 0..3. Stretch each of the last
        # three axes in turn; an index of 4 would match no axis and silently
        # test a tensor with all sides equal.
        for dim in (1, 2, 3):
            shape = tuple(256 if i == dim else 32 for i in range(4))
            x = torch.randn(shape, device=device, requires_grad=True)
            # A 1x1x1 window makes the pooling an identity map, so after summing
            # the output every input element must receive a gradient of exactly 1.
            F.max_pool3d(x, kernel_size=(1, 1, 1)).sum().backward()
            self.assertEqual(x.grad, torch.ones_like(x.grad))

    def test_adaptive_pool_invalid(self, device):
        # Adaptive average pooling must raise a RuntimeError when output_size
        # contains a negative entry. Each case is (module class, input tensor,
        # invalid output_size).
        t_1d = torch.randn(1, 1, 1, device=device)
        t_2d = torch.randn(1, 1, 1, 1, device=device)
        t_3d = torch.randn(1, 1, 1, 1, 1, device=device)
        cases = (
            (torch.nn.AdaptiveAvgPool1d, t_1d, (-1,)),
            (torch.nn.AdaptiveAvgPool2d, t_2d, (-1, 0)),
            (torch.nn.AdaptiveAvgPool3d, t_3d, (-1, 0, 2)),
        )
        expected_msg = r"elements of output_size must be greater than or equal to 0"

        for module_cls, t, output_size in cases:
            with self.assertRaisesRegex(RuntimeError, expected_msg):
                module_cls(output_size)(t)

    @slowTest
    def test_adaptive_pool_odd_size(self, device):
        # See https://github.com/pytorch/pytorch/issues/81409
        # Smoke test: adaptive average/max pooling is run on large inputs with
        # odd input and output sizes; results are discarded, only completion
        # without an error matters.
        in_h, in_w, out_h, out_w = 5873, 3693, 3527, 2219
        batch_2d = torch.randint(low=0, high=256, size=(11, in_h, in_w), dtype=torch.float)
        for pool_fn in (F.adaptive_avg_pool2d, F.adaptive_max_pool2d):
            pool_fn(batch_2d, (out_h, out_w))

        in_d, in_h, in_w, out_d, out_h, out_w = 3, 5873, 3693, 3, 3527, 2219
        batch_3d = torch.randint(low=0, high=256, size=(3, in_d, in_h, in_w), dtype=torch.float)
        for pool_fn in (F.adaptive_avg_pool3d, F.adaptive_max_pool3d):
            pool_fn(batch_3d, (out_d, out_h, out_w))

# Register the test classes with the PyTorch test helpers imported at the top
# of the file: TestPoolingNNDeviceType is passed together with this module's
# globals(), and TestPoolingNN is passed to the parametrization helper.
# NOTE(review): presumably these calls generate the per-device and
# per-parameter test variants that the runner discovers -- confirm against
# torch.testing._internal.
instantiate_device_type_tests(TestPoolingNNDeviceType, globals())
instantiate_parametrized_tests(TestPoolingNN)

# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
