# Owner(s): ["module: cuda"]
import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn
import sys
from itertools import product
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_device_type import (
    skipCUDAIfRocm, skipCUDAIf, instantiate_device_type_tests, dtypes, toleranceOverride, tol)
from torch.testing._internal.common_cuda import _get_torch_cuda_version

if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    TestCase = object  # noqa: F811
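
# A fused "alpha * x + beta * y" kernel compiled through the jiterator, plus an
# eager-mode reference implementation; both are shared by the dtype coverage tests below.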
code_string = "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)


def ref_fn(x, y, alpha=1, beta=1):
    return alpha * x + beta * y


class TestPythonJiterator(TestCase):
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)
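
    # Same dtype coverage as above, but with non-contiguous inputs, which exercises
    # the jiterator's dynamic-casting path.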
    @skipCUDAIfRocm
    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    @skipCUDAIf(_get_torch_cuda_version() < (11, 6),
                "On CUDA 11.3, nvrtcCompileProgram takes too long to compile jiterator-generated "
                "kernels for non-contiguous inputs that require dynamic casting.")
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)
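
    # Scalar extra arguments are passed by keyword at call time and override the
    # alpha/beta defaults baked in by create_jit_fn.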
    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)
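
    # Boolean extra arguments are supported as well; is_train toggles the masking
    # branch inside the kernel.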
    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        jitted_fn = create_jit_fn(code_string, is_train=False)

        def ref_fn(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_fn(a, b, is_train=is_train)
        result = jitted_fn(a, b, is_train=is_train)

        self.assertEqual(expected, result)
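
    # The code string may define helper functors alongside the kernel that gets called;
    # here main_fn (three inputs) calls the helper fn.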
    def test_multiple_functors(self, device):
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = ref_fn(a, b, c)
        result = jitted_fn(a, b, c)

        self.assertEqual(expected, result)
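
    # Build kernels taking 1, 5, and 8 tensor inputs and compare against a
    # stacked-sum reference.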
    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        inputs = []
        for i in range(num_inputs):
            inputs.append(torch.rand(3, device='cuda').mul(10))

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(*inputs):
            return torch.sum(torch.stack(inputs), dim=0)

        expected = ref_fn(*inputs)
        result = jitted_fn(*inputs)

        self.assertEqual(expected, result)
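
    # Multi-output kernels are created with create_multi_output_jit_fn and write
    # their results through reference parameters.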
    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        input = torch.rand(3, device='cuda')

        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = ""
        for i in range(num_outputs):
            function_body += f"out{i} = input + {i};\n"
        # NB: return type must be void, otherwise ROCm silently fails
        code_string = f"template <typename T> void my_kernel(T input, {output_string}) {{ {function_body} }}"
        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs)

        def ref_fn(input):
            outputs = []
            for i in range(num_outputs):
                outputs.append(input + i)

            if num_outputs == 1:
                return outputs[0]
            return tuple(outputs)

        expected = ref_fn(input)
        result = jitted_fn(input)

        for i in range(num_outputs):
            self.assertEqual(expected[i], result[i])
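
    # Malformed kernel signatures should be rejected when the jitted function is created.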
    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        with self.assertRaises(Exception):
            jitted_fn = create_jit_fn(code_string)


instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")

if __name__ == '__main__':
    run_tests()