1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
|
# Owner(s): ["module: nn"]
import math
import unittest
from typing import List, Tuple, Union
import torch
from torch._inductor import config
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
# Baseline per-dtype tolerances for comparing eager vs. torch.compile'd
# results; individual tests multiply them by a per-test "fudge" factor.
default_atol = {
    torch.float16: 1e-3,
    # Infinite tolerance: run_comp_nocomp still runs both paths for bfloat16
    # but skips the value comparison.
    torch.bfloat16: float("infinity"),
    torch.float32: 1e-5,
}
default_rtol = {
    torch.float16: 1e-3,
    # Infinite tolerance: value comparison is skipped for bfloat16.
    torch.bfloat16: float("infinity"),
    torch.float32: 1.3e-6,
}
def rand_math_tensor(
    shape: Tuple[int, ...],
    device: str,
    dtype: torch.dtype,
    requires_grad: bool = False,
    packed: bool = False,
) -> torch.Tensor:
    """Create a dense tensor filled with samples from ``torch.randn``.

    Args:
        shape (Tuple[int, ...]): Shape of the tensor to construct.
        device (str): Device to create the tensor on.
        dtype (torch.dtype): Dtype of the tensor.
        requires_grad (bool, optional): Whether the tensor requires grad.
            Defaults to False.
        packed (bool, optional): Accepted for signature compatibility only;
            it is currently ignored. Defaults to False.

    Returns:
        torch.Tensor: A new dense tensor with the requested shape and dtype.
    """
    # Only dense tensors are produced here; nested tensors and QKV packing are
    # not implemented, which is why ``packed`` is intentionally unused.
    return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)
def init_tensor(tensor_list, **kwargs) -> torch.Tensor:
    """Build a tensor from a (nested) Python list and convert it.

    The list is materialized with the ``torch.Tensor`` constructor, and any
    keyword arguments (e.g. ``dtype=``, ``device=``) are forwarded to
    ``Tensor.to``.
    """
    materialized = torch.Tensor(tensor_list)
    converted = materialized.to(**kwargs)
    return converted
def run_comp_nocomp(function, *inputs, **kwargs):
    """Run ``function`` eagerly and under ``torch.compile`` and compare results.

    Both the eager and the compiled call are always executed, so this acts as
    a compilation smoke test even when the value comparison is skipped.

    Args:
        function: Callable under test.
        *inputs: Positional arguments forwarded to both calls.
        **kwargs: Forwarded to ``torch.testing.assert_close`` (typically
            ``rtol``/``atol``). If either tolerance is infinite, the value
            comparison is skipped. A tolerance of ``None`` (assert_close's
            "use the dtype default") is accepted and treated as finite.

    Raises:
        AssertionError: If eager and compiled results differ beyond the
            given tolerances.
    """
    c_function = torch.compile(function)
    f_res = function(*inputs)
    cf_res = c_function(*inputs)

    # ``None`` is a legal assert_close tolerance; math.isinf(None) would raise
    # TypeError, so only test real numbers for infinity.
    tolerances = (kwargs.get("atol"), kwargs.get("rtol"))
    if any(tol is not None and math.isinf(tol) for tol in tolerances):
        # An infinite tolerance means "compile and run, but don't compare".
        return
    torch.testing.assert_close(f_res, cf_res, **kwargs)
# Thin wrappers around the matmul-family ops under test. Several tests share
# them, passing each one to run_comp_nocomp for eager vs. compiled execution.
def torch_mm(a, b):
    """Return the matrix product ``torch.mm(a, b)``."""
    product = torch.mm(a, b)
    return product
def torch_addmm(add, b, c):
    """Return ``torch.addmm(add, b, c)`` with the default alpha/beta."""
    fused = torch.addmm(add, b, c)
    return fused
def torch_bmm(a, b):
    """Return the batched matrix product ``torch.bmm(a, b)``."""
    batched = torch.bmm(a, b)
    return batched
def torch_baddbmm(add, b, c, alpha, beta):
    """Return ``torch.baddbmm`` of ``add`` plus the batched product ``b @ c``.

    ``alpha`` scales the product and ``beta`` scales ``add``.
    """
    scaled = torch.baddbmm(add, b, c, beta=beta, alpha=alpha)
    return scaled
# The shapes we test on, as (a1_0, a1_1, a2_0, a2_1): the first operand is
# a1_0 x a1_1 and the second is a2_0 x a2_1 (batched tests prepend a batch
# dimension to both).
ts_list = (
    # row vector (1 x k) @ column vector (k x 1) -> 1 x 1
    [(1, k, k, 1) for k in (32, 10, 3)]
    # column vector (m x 1) @ row vector (1 x n) -> m x n outer product
    + [(m, 1, 1, n) for m, n in ((32, 32), (3, 3), (4, 9), (9, 4))]
)
class TestDecomp(NNTestCase):
    """Check that compiled mm/addmm/bmm/baddbmm match eager execution.

    Each test feeds the ``torch_*`` wrappers to ``run_comp_nocomp``, which runs
    them eagerly and through ``torch.compile`` and compares the results using
    the per-dtype tolerances scaled by a per-test fudge factor.
    """

    _do_cuda_memory_leak_check = GPU_TYPE == "cuda"
    _do_cuda_non_default_stream = GPU_TYPE == "cuda"

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16])
    def test_simple_mm(self, device, dtype):
        """mm and addmm on every 2-D operand pair from ``ts_list``."""
        fudge = 10
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for a1_0, a1_1, a2_0, a2_1 in ts_list:
            t1 = rand_math_tensor((a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((a2_0, a2_1), dtype=dtype, device=device)
            # Bias shaped like the product t1 @ t2.
            tadd = rand_math_tensor((a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_mm, t1, t2, rtol=rtol, atol=atol)
            run_comp_nocomp(torch_addmm, tadd, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize(
        "dtype", [torch.float, torch.bfloat16] if SM80OrLater else [torch.float]
    )
    @parametrize("bs", [1, 2, 4, 10])
    def test_batched_mm(self, device, dtype, bs):
        """bmm and baddbmm on batched ``ts_list`` shapes over an alpha/beta grid."""
        fudge = 3
        rtol = default_rtol[dtype] * fudge
        atol = default_atol[dtype] * fudge

        for a1_0, a1_1, a2_0, a2_1 in ts_list:
            t1 = rand_math_tensor((bs, a1_0, a1_1), dtype=dtype, device=device)
            t2 = rand_math_tensor((bs, a2_0, a2_1), dtype=dtype, device=device)
            tadd = rand_math_tensor((bs, a1_0, a2_1), dtype=dtype, device=device)

            run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)
            for alpha in (0, 1, -1, 0.5, -0.5):
                for beta in (0, 1, -1, 0.5, -0.5):
                    run_comp_nocomp(
                        torch_baddbmm, tadd, t1, t2, alpha, beta, rtol=rtol, atol=atol
                    )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @config.patch(coordinate_descent_tuning=True)
    def test_bmm_batch2_last_dim_size_is_one(self, device):
        """bmm whose second operand has a trailing dimension of size 1.

        Runs with inductor's ``coordinate_descent_tuning`` enabled.
        """
        fudge = 3
        rtol = default_rtol[torch.float32] * fudge
        atol = default_atol[torch.float32] * fudge

        t1 = torch.randn(1, 32, 2, device=device)
        t2 = torch.randn(1, 2, 1, device=device)

        run_comp_nocomp(torch_bmm, t1, t2, rtol=rtol, atol=atol)

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    def test_some(self, device, dtype):
        """mm of small integer-valued column/row vectors, in both orders."""
        # This dtype is not fully supported on the GPU backend today. skipIf
        # cannot see parametrized arguments, so skip at runtime instead, using
        # skipTest (not a bare return) so the case is reported as skipped
        # rather than silently counted as a pass.
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            self.skipTest(f"{dtype} is not fully supported on {GPU_TYPE}")

        column = [[1], [2], [3], [4]]
        row = [[1, 2, 3, 4]]
        run_comp_nocomp(
            torch_mm,
            init_tensor(column, dtype=dtype, device=device),
            init_tensor(row, dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_mm,
            init_tensor(row, dtype=dtype, device=device),
            init_tensor(column, dtype=dtype, device=device),
        )

    @unittest.skipIf(not HAS_GPU, "GPU tests require triton")
    @parametrize("dtype", [torch.float, torch.bfloat16, torch.int])
    @parametrize("bs", [1, 2, 4, 10])
    def test_some_batched(self, device, dtype, bs):
        """bmm of small integer-valued column/row vectors, batched ``bs`` times."""
        # See test_some: skip (rather than silently pass) the unsupported
        # int-on-GPU combination, since skipIf cannot see parametrized args.
        if device.startswith(GPU_TYPE) and dtype == torch.int:
            self.skipTest(f"{dtype} is not fully supported on {GPU_TYPE}")

        column = [[1], [2], [3], [4]]
        row = [[1, 2, 3, 4]]
        run_comp_nocomp(
            torch_bmm,
            init_tensor([column] * bs, dtype=dtype, device=device),
            init_tensor([row] * bs, dtype=dtype, device=device),
        )
        run_comp_nocomp(
            torch_bmm,
            init_tensor([row] * bs, dtype=dtype, device=device),
            init_tensor([column] * bs, dtype=dtype, device=device),
        )
# Generate device-specific variants of TestDecomp into this module's namespace,
# restricted to CPU and the detected GPU backend.
device_types = ("cpu", GPU_TYPE)
instantiate_device_type_tests(
    TestDecomp,
    globals(),
    only_for=device_types,
)
# torch.compile() is not supported on Windows, so the suite is only run when
# executed as a script on other platforms.
if __name__ == "__main__" and not IS_WINDOWS:
    run_tests()
|