File: test_cudacodecache.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (99 lines) | stat: -rw-r--r-- 3,088 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Owner(s): ["module: inductor"]

import ctypes
import unittest

import torch
from torch._inductor import config
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.codecache import CUDACodeCache
from torch._inductor.codegen.cuda.cuda_env import nvcc_exist
from torch._inductor.exc import CUDACompileError
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import fresh_inductor_cache


# CUDA source text compiled by the tests in this file.
# It defines a `saxpy_device` kernel that updates y[i] = a*x[i] + y[i] in place
# for every i < n, and an extern "C" entry point `saxpy` (default visibility)
# that launches the kernel with 256 threads per block and returns 0.
_SOURCE_CODE = r"""

#include <stdio.h>

__global__
void saxpy_device(int n, float a, float *x, float *y)
{
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if (i < n) y[i] = a*x[i] + y[i];
}

extern "C" {

__attribute__((__visibility__("default")))
int saxpy(int n, float a, float *x, float *y) {
  // Perform SAXPY
  saxpy_device<<<(n+255)/256, 256>>>(n, a, x, y);
  return 0;
}

}
"""


@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup")
class TestCUDACodeCache(InductorTestCase):
    """Compile, load and run ``_SOURCE_CODE`` through ``CUDACodeCache``.

    Every test runs inside ``fresh_inductor_cache()``.
    """

    def _check_saxpy(self, lib, n, a):
        """Call ``lib.saxpy`` on random CUDA vectors and verify the output.

        The call's return value is not inspected; the check is made on the
        contents of ``y`` after the call.

        Args:
            lib: Object exposing the compiled ``saxpy`` entry point as an
                attribute.
            n: Number of elements in each input vector.
            a: Python float scale factor forwarded to ``saxpy``.

        Raises:
            AssertionError: If ``y`` is not close to ``a * x + y`` (computed
                from the values held before the call).
        """
        x = torch.rand(n).float().cuda()
        y = torch.rand(n).float().cuda()
        # Computed before the call because the kernel overwrites ``y`` in place.
        expected_y = a * x + y
        lib.saxpy(
            ctypes.c_int(n),
            ctypes.c_float(a),
            ctypes.c_void_p(x.data_ptr()),
            ctypes.c_void_p(y.data_ptr()),
        )
        torch.testing.assert_close(y, expected_y)

    def test_cuda_load(self):
        """Compile to "o", load as "so", then run ``saxpy`` from the loaded library."""
        with fresh_inductor_cache():
            # Test both .o and .so compilation.
            (
                object_file_path,
                object_hash_key,
                source_code_path0,
            ) = CUDACodeCache.compile(_SOURCE_CODE, "o")
            dll_wrapper, so_hash_key, source_code_path1 = CUDACodeCache.load(
                _SOURCE_CODE, "so"
            )
            # The same source built for two destination types must yield
            # different source paths and different hash keys.
            self.assertNotEqual(source_code_path0, source_code_path1)
            self.assertNotEqual(object_hash_key, so_hash_key)

            # Test load and call functions in .so.
            self._check_saxpy(dll_wrapper, 10, 5.0)

    def test_compilation_error(self):
        """Source that fails to compile must raise ``CUDACompileError``."""
        with fresh_inductor_cache():
            # Only the first occurrence (the kernel definition) is renamed, so
            # the launch inside ``saxpy`` still refers to ``saxpy_device``,
            # which is no longer defined.
            error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1)
            with self.assertRaises(CUDACompileError):
                CUDACodeCache.compile(error_source_code, "o")

    def test_async_compile(self):
        """Build the "so" through ``AsyncCompile`` and run ``saxpy`` from its result."""
        with fresh_inductor_cache():
            async_compile = AsyncCompile()
            compiled_res = async_compile.cuda(_SOURCE_CODE, "so")
            async_compile.wait(globals())

            # Test load and call functions in .so.
            self._check_saxpy(compiled_res.result(), 5, 2.0)


# Script entry point: the tests are started only when nvcc_exist() returns a
# truthy value (presumably meaning an nvcc compiler is available; confirm in
# torch._inductor.codegen.cuda.cuda_env). Otherwise the script does nothing.
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # NOTE(review): the "cuda" argument presumably tells run_tests that the
    # suite needs CUDA; verify against torch._inductor.test_case.run_tests.
    if nvcc_exist():
        run_tests("cuda")