File: test_executorch_custom_ops.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (147 lines) | stat: -rw-r--r-- 5,320 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from __future__ import annotations

import tempfile
import unittest
from typing import Any
from unittest.mock import ANY, Mock, patch

import expecttest

import torchgen
from torchgen.executorch.api.custom_ops import ComputeNativeFunctionStub
from torchgen.executorch.model import ETKernelIndex
from torchgen.gen_executorch import gen_headers
from torchgen.model import Location, NativeFunction
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager


# Four spaces of indentation, matching the otherwise-empty body line of a void
# stub kernel in the expected outputs below. Presumably kept as a named
# constant so editors/linters don't strip a whitespace-only line -- confirm.
SPACES = " " * 4


def _get_native_function_from_yaml(yaml_obj: dict[str, object]) -> NativeFunction:
    """Parse one native_functions.yaml-style entry into a ``NativeFunction``.

    ``yaml_obj`` is the mapping for a single entry (these tests pass
    ``{"func": "<schema>"}``). The source location is a placeholder pointing
    at line 1 of this file, and the parse is done with an empty set of valid
    tags. Only the first element of ``NativeFunction.from_yaml``'s result is
    returned; the second is discarded.
    """
    location = Location(__file__, 1)
    parsed_function, _ = NativeFunction.from_yaml(
        yaml_obj, loc=location, valid_tags=set()
    )
    return parsed_function


class TestComputeNativeFunctionStub(expecttest.TestCase):
    """
    Tests for the C++ kernel stubs produced by ``ComputeNativeFunctionStub``.

    The cases below are spelled out by hand instead of being parametrized via
    torch.testing._internal.common_utils: the GH CI job that runs the tools
    unit tests does not build torch first, so that helper is unavailable.
    """

    def _test_function_schema_generates_correct_kernel(
        self, obj: dict[str, Any], expected: str
    ) -> None:
        """Generate a stub for the YAML entry ``obj`` and compare it to ``expected``.

        ``obj`` is a native_functions.yaml-style mapping (these tests only set
        its ``"func"`` key); ``expected`` is the exact C++ text the stub
        generator is required to emit for it.
        """
        native_fn = _get_native_function_from_yaml(obj)
        stub_source = ComputeNativeFunctionStub()(native_fn)
        self.assertIsNotNone(stub_source)
        self.assertExpectedInline(str(stub_source), expected)

    def test_function_schema_generates_correct_kernel_tensor_out(self) -> None:
        # An out= variant hands back its ``out`` argument.
        schema = "custom::foo.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
        expected = """
at::Tensor & wrapper_CPU_out_foo_out(const at::Tensor & self, at::Tensor & out) {
    return out;
}
    """
        self._test_function_schema_generates_correct_kernel({"func": schema}, expected)

    def test_function_schema_generates_correct_kernel_no_out(self) -> None:
        # A Tensor -> Tensor op with no out= argument hands back ``self``.
        schema = "custom::foo.Tensor(Tensor self) -> Tensor"
        expected = """
at::Tensor wrapper_CPU_Tensor_foo(const at::Tensor & self) {
    return self;
}
    """
        self._test_function_schema_generates_correct_kernel({"func": schema}, expected)

    def test_function_schema_generates_correct_kernel_no_return(self) -> None:
        # A void kernel's body is a single whitespace-only line.
        schema = "custom::foo.out(Tensor self, *, Tensor(a!)[] out) -> ()"
        expected = f"""
void wrapper_CPU_out_foo_out(const at::Tensor & self, at::TensorList out) {{
{SPACES}
}}
    """
        self._test_function_schema_generates_correct_kernel({"func": schema}, expected)

    def test_function_schema_generates_correct_kernel_3_returns(self) -> None:
        # Multiple returns are default-constructed and packed into a std::tuple.
        schema = (
            "custom::foo(Tensor self, Tensor[] other) -> (Tensor, Tensor, Tensor)"
        )
        expected = """
::std::tuple<at::Tensor,at::Tensor,at::Tensor> wrapper_CPU__foo(const at::Tensor & self, at::TensorList other) {
    return ::std::tuple<at::Tensor, at::Tensor, at::Tensor>(
                at::Tensor(), at::Tensor(), at::Tensor()
            );
}
    """
        self._test_function_schema_generates_correct_kernel({"func": schema}, expected)

    def test_function_schema_generates_correct_kernel_1_return_no_out(self) -> None:
        # A single Tensor return with only a TensorList input yields at::Tensor().
        schema = "custom::foo(Tensor[] a) -> Tensor"
        expected = """
at::Tensor wrapper_CPU__foo(at::TensorList a) {
    return at::Tensor();
}
    """
        self._test_function_schema_generates_correct_kernel({"func": schema}, expected)

    def test_schema_has_no_return_type_argument_throws(self) -> None:
        # A ``bool`` return is not something the stub generator can produce.
        yaml_entry = {"func": "custom::foo.bool(Tensor self) -> bool"}
        native_fn = _get_native_function_from_yaml(yaml_entry)

        stub_generator = ComputeNativeFunctionStub()
        with self.assertRaisesRegex(Exception, "Can't handle this return type"):
            stub_generator(native_fn)


class TestGenCustomOpsHeader(unittest.TestCase):
    """Check whether ``gen_headers`` requests ``CustomOpsNativeFunctions.h``.

    In each test both ``FileManager.write`` and ``FileManager.write_with_template``
    are replaced with mocks. The assertions inspect only the
    ``write_with_template`` mock.
    """

    @staticmethod
    def _run_gen_headers(gen_custom_ops_header: bool) -> None:
        """Run ``gen_headers`` on empty inputs with a temp-dir ``FileManager``."""
        with tempfile.TemporaryDirectory() as tempdir:
            file_manager = FileManager(tempdir, tempdir, False)
            gen_headers(
                native_functions=[],
                gen_custom_ops_header=gen_custom_ops_header,
                custom_ops_native_functions=[],
                selector=SelectiveBuilder.get_nop_selector(),
                kernel_index=ETKernelIndex(index={}),
                cpu_fm=file_manager,
                use_aten_lib=False,
            )

    # Stacked patch decorators pass their mocks bottom-up, so the ``write``
    # mock arrives first and the ``write_with_template`` mock second.
    @patch.object(torchgen.utils.FileManager, "write_with_template")
    @patch.object(torchgen.utils.FileManager, "write")
    def test_fm_writes_custom_ops_header_when_boolean_is_true(
        self, _mock_write: Mock, mock_write_with_template: Mock
    ) -> None:
        self._run_gen_headers(gen_custom_ops_header=True)
        mock_write_with_template.assert_called_once_with(
            "CustomOpsNativeFunctions.h", "NativeFunctions.h", ANY
        )

    @patch.object(torchgen.utils.FileManager, "write_with_template")
    @patch.object(torchgen.utils.FileManager, "write")
    def test_fm_doesnot_writes_custom_ops_header_when_boolean_is_false(
        self, _mock_write: Mock, mock_write_with_template: Mock
    ) -> None:
        self._run_gen_headers(gen_custom_ops_header=False)
        mock_write_with_template.assert_not_called()