1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
|
import unittest
from torchgen.executorch.api.types import ExecutorchCppSignature
from torchgen.local import parametrize
from torchgen.model import Location, NativeFunction
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
{"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"},
loc=Location(__file__, 1),
valid_tags=set(),
)
class ExecutorchCppSignatureTest(unittest.TestCase):
def setUp(self) -> None:
self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION)
def test_runtime_signature_contains_runtime_context(self) -> None:
# test if `KernelRuntimeContext` argument exists in `RuntimeSignature`
with parametrize(
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
args = self.sig.arguments(include_context=True)
self.assertEqual(len(args), 3)
self.assertTrue(any(a.name == "context" for a in args))
def test_runtime_signature_does_not_contain_runtime_context(self) -> None:
# test if `KernelRuntimeContext` argument is missing in `RuntimeSignature`
with parametrize(
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
args = self.sig.arguments(include_context=False)
self.assertEqual(len(args), 2)
self.assertFalse(any(a.name == "context" for a in args))
def test_runtime_signature_declaration_correct(self) -> None:
with parametrize(
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
):
decl = self.sig.decl(include_context=True)
self.assertEqual(
decl,
(
"torch::executor::Tensor & foo_outf("
"torch::executor::KernelRuntimeContext & context, "
"const torch::executor::Tensor & input, "
"torch::executor::Tensor & out)"
),
)
no_context_decl = self.sig.decl(include_context=False)
self.assertEqual(
no_context_decl,
(
"torch::executor::Tensor & foo_outf("
"const torch::executor::Tensor & input, "
"torch::executor::Tensor & out)"
),
)
|