1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
|
# Owner(s): ["module: inductor"]
import ctypes
import unittest
import torch
from torch._inductor import config
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.codecache import CUDACodeCache
from torch._inductor.codegen.cuda.cuda_env import nvcc_exist
from torch._inductor.exc import CUDACompileError
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import fresh_inductor_cache
# CUDA translation unit shared by the tests below. It defines a SAXPY kernel
# (`y[i] = a*x[i] + y[i]` for i < n) and an exported C entry point `saxpy`
# that launches it with 256-thread blocks and always returns 0.
# The text is compiled verbatim. test_compilation_error replaces the FIRST
# occurrence of "saxpy_device" (the kernel definition) to leave the launch
# referring to an undefined name, so keep the definition ahead of the call.
_SOURCE_CODE = r"""
#include <stdio.h>
__global__
void saxpy_device(int n, float a, float *x, float *y)
{
int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i < n) y[i] = a*x[i] + y[i];
}
extern "C" {
__attribute__((__visibility__("default")))
int saxpy(int n, float a, float *x, float *y) {
// Perform SAXPY
saxpy_device<<<(n+255)/256, 256>>>(n, a, x, y);
return 0;
}
}
"""
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup")
@unittest.skipIf(not nvcc_exist(), "requires the nvcc CUDA compiler")
class TestCUDACodeCache(InductorTestCase):
    """Compile, cache and load the SAXPY CUDA source through CUDACodeCache."""

    def _check_saxpy(self, saxpy, n, a):
        """Call a compiled ``saxpy`` entry point on random CUDA data and verify it.

        Args:
            saxpy: the callable looked up on the loaded shared library.
            n: number of float elements in ``x`` and ``y``. It is used both for
                the tensor sizes and the ``n`` argument, so the two cannot drift.
            a: the scalar multiplier.
        """
        x = torch.rand(n).float().cuda()
        y = torch.rand(n).float().cuda()
        # The kernel updates y in place, so compute the reference first.
        expected_y = a * x + y
        saxpy(
            ctypes.c_int(n),
            ctypes.c_float(a),
            ctypes.c_void_p(x.data_ptr()),
            ctypes.c_void_p(y.data_ptr()),
        )
        # The kernel is launched by a raw `<<<>>>` call inside the library,
        # not by a PyTorch op. Synchronize explicitly so the in-place update
        # has finished before y is compared.
        torch.cuda.synchronize()
        torch.testing.assert_close(y, expected_y)

    @unittest.skipIf(not torch.cuda.is_available(), "requires a CUDA device")
    def test_cuda_load(self):
        with fresh_inductor_cache():
            # Test both .o and .so compilation. The .o path itself is unused.
            _, object_hash_key, source_code_path0 = CUDACodeCache.compile(
                _SOURCE_CODE, "o"
            )
            dll_wrapper, so_hash_key, source_code_path1 = CUDACodeCache.load(
                _SOURCE_CODE, "so"
            )
            # The same source compiled to different output kinds must get
            # distinct cache entries.
            self.assertNotEqual(source_code_path0, source_code_path1)
            self.assertNotEqual(object_hash_key, so_hash_key)
            # Test load and call functions in .so.
            self._check_saxpy(dll_wrapper.saxpy, n=10, a=5.0)

    def test_compilation_error(self):
        with fresh_inductor_cache():
            # Renaming only the kernel definition leaves the launch in `saxpy`
            # pointing at an undefined symbol, so nvcc must fail.
            error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1)
            with self.assertRaises(CUDACompileError):
                CUDACodeCache.compile(error_source_code, "o")

    @unittest.skipIf(not torch.cuda.is_available(), "requires a CUDA device")
    def test_async_compile(self):
        with fresh_inductor_cache():
            async_compile = AsyncCompile()
            compiled_res = async_compile.cuda(_SOURCE_CODE, "so")
            # NOTE(review): wait() is given the module globals, which do not
            # hold `compiled_res`. The compiled library is fetched explicitly
            # through .result() below.
            async_compile.wait(globals())
            # Test load and call functions in .so.
            self._check_saxpy(compiled_res.result().saxpy, n=5, a=2.0)
if __name__ == "__main__":
    # Imported only when this file is executed directly as a script.
    from torch._inductor.test_case import run_tests

    # The suite compiles CUDA sources, so run it only when nvcc_exist()
    # reports that the nvcc compiler is available.
    have_nvcc = nvcc_exist()
    if have_nvcc:
        run_tests("cuda")
|