File: test_triton_extension_backend.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (118 lines) | stat: -rw-r--r-- 3,902 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Owner(s): ["module: inductor"]
import random
import string
import sys
import unittest

import torch
import torch._dynamo
import torch.utils.cpp_extension


try:
    from extension_backends.triton.device_interface import (  # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend  # noqa: B950
        DeviceInterface,
    )
    from extension_backends.triton.extension_codegen_backend import (  # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend  # noqa: B950
        CPUDeviceOpOverrides,
        ExtensionScheduling,
        ExtensionWrapperCodegen,
    )
except ImportError:
    from .extension_backends.triton.device_interface import DeviceInterface
    from .extension_backends.triton.extension_codegen_backend import (
        CPUDeviceOpOverrides,
        ExtensionScheduling,
        ExtensionWrapperCodegen,
    )

from torch._C import FileCheck
from torch._dynamo import device_interface
from torch._inductor import metrics
from torch._inductor.codegen.common import (
    get_scheduling_for_device,
    get_wrapper_codegen_for_device,
    register_backend_for_device,
    register_device_op_overrides,
)
from torch._inductor.utils import get_triton_code
from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS


try:
    from .test_extension_backend import BaseExtensionBackendTests
except ImportError:
    from test_extension_backend import BaseExtensionBackendTests

# Import the shared inductor test module, trying a package-relative import
# first and falling back to a plain top-level import.
# If the import raises unittest.SkipTest (presumably because the environment
# cannot run inductor tests), exit with status 0 when this file is executed
# directly; otherwise re-raise so the test runner records the skip.
try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise


# Module-level alias for the TestCase class exposed by test_torchinductor.
TestCase = test_torchinductor.TestCase


def mock_triton_hash_with_backend(*args, **kwargs):
    """Stand-in for ``triton_hash_with_backend`` when no Triton backend exists.

    All positional and keyword arguments are accepted and ignored. The result
    is a fresh 64-character string of uppercase ASCII letters and digits,
    drawn from the module-level ``random`` generator on every call.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = random.choices(alphabet, k=64)
    return "".join(picked)


@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
class TritonExtensionBackendTests(BaseExtensionBackendTests):
    """
    Test creating a backend for inductor with Triton scheduling.
    """

    def test_open_device_registration(self):
        """Register an extension backend for ``privateuseone`` and compile with it.

        Registers the device module, the Triton-based scheduling/wrapper codegen,
        the device op overrides and the device interface for ``privateuseone``.
        Checks that the registry lookups return those objects, then compiles a
        small function and verifies the generated code is a Triton kernel
        targeting ``privateuseone``.
        """
        # ``import unittest`` does not guarantee the ``unittest.mock`` submodule
        # has been loaded, so import it explicitly before using ``patch``.
        from unittest import mock

        # self.module is presumably the C++ extension module prepared by
        # BaseExtensionBackendTests -- confirm against the base class.
        torch._register_device_module("privateuseone", self.module)
        register_backend_for_device(
            "privateuseone", ExtensionScheduling, ExtensionWrapperCodegen
        )
        register_device_op_overrides("privateuseone", CPUDeviceOpOverrides())
        device_interface.register_interface_for_device("privateuseone", DeviceInterface)

        # The registry lookups must hand back exactly what was registered above.
        self.assertEqual(
            get_scheduling_for_device("privateuseone"), ExtensionScheduling
        )
        self.assertEqual(
            get_wrapper_codegen_for_device("privateuseone"), ExtensionWrapperCodegen
        )
        self.assertEqual(
            device_interface.get_interface_for_device("privateuseone"), DeviceInterface
        )

        device = torch.device("privateuseone")
        x = torch.ones(2, 16).to(device)

        def foo(x):
            return torch.sin(x) + x.min()

        metrics.reset()
        opt_fn = torch.compile(foo)

        # Since we don't have a triton backend, we need to mock the
        # triton_hash_with_backend function while code is generated.
        with mock.patch(
            "torch.utils._triton.triton_hash_with_backend",
            new=mock_triton_hash_with_backend,
        ):
            code = get_triton_code(opt_fn, x)

        # The generated source must import triton, define a @triton.jit kernel
        # that uses tl_math.sin, and record the privateuseone device string.
        FileCheck().check("import triton").check("@triton.jit").check(
            "tl_math.sin"
        ).check("device_str='privateuseone'").run(code)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    # Script entry point: run the suite only where an inductor CPU backend is
    # available, and never on macOS.
    if not IS_MACOS and HAS_CPU:
        run_tests()