File: test_executorch_signatures.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (59 lines) | stat: -rw-r--r-- 2,389 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import contextlib
import unittest

from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.local import parametrize
from torchgen.model import Location, NativeFunction


# Minimal out-variant schema parsed once at import time; ``setUp`` below
# builds each test's signature from it. ``from_yaml`` yields a pair and only
# the first element is kept (the second is bound to the throwaway ``_``).
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
    dict(func="foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"),
    valid_tags=set(),
    loc=Location(__file__, 1),
)


class ExecutorchCppSignatureTest(unittest.TestCase):
    """Check the C++ signature ExecuTorch codegen emits for ``foo.out``.

    ``DEFAULT_NATIVE_FUNCTION`` has two schema arguments (``input`` and the
    mutable ``out``). The ``include_context`` flag controls whether an extra
    ``KernelRuntimeContext`` argument named ``context`` is added as well.
    """

    def setUp(self) -> None:
        self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION)

    @staticmethod
    def _codegen_options() -> contextlib.AbstractContextManager[None]:
        """Return the ``parametrize`` context every test in this class uses.

        All expectations below were written with both codegen flags
        disabled, so they are set in one place instead of being repeated
        per test. A new context manager is returned on each call.
        """
        return parametrize(
            use_const_ref_for_mutable_tensors=False,
            use_ilistref_for_tensor_lists=False,
        )

    def test_runtime_signature_contains_runtime_context(self) -> None:
        # test if `KernelRuntimeContext` argument exists in `RuntimeSignature`
        with self._codegen_options():
            args = self.sig.arguments(include_context=True)
            # Two schema arguments plus the runtime context.
            self.assertEqual(len(args), 3)
            # assertIn reports the actual binding names on failure, unlike
            # assertTrue(any(...)), which only says "False is not true".
            self.assertIn("context", [a.name for a in args])

    def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
        # test if `KernelRuntimeContext` argument is missing in `RuntimeSignature`
        with self._codegen_options():
            args = self.sig.arguments(include_context=False)
            # Only the two schema arguments remain.
            self.assertEqual(len(args), 2)
            self.assertNotIn("context", [a.name for a in args])

    def test_runtime_signature_declaration_correct(self) -> None:
        # The declaration should list the context first when requested and
        # otherwise contain only the schema arguments, in schema order.
        with self._codegen_options():
            decl = self.sig.decl(include_context=True)
            self.assertEqual(
                decl,
                (
                    "torch::executor::Tensor & foo_outf("
                    "torch::executor::KernelRuntimeContext & context, "
                    "const torch::executor::Tensor & input, "
                    "torch::executor::Tensor & out)"
                ),
            )
            no_context_decl = self.sig.decl(include_context=False)
            self.assertEqual(
                no_context_decl,
                (
                    "torch::executor::Tensor & foo_outf("
                    "const torch::executor::Tensor & input, "
                    "torch::executor::Tensor & out)"
                ),
            )