File: test_function_schema.cpp

package info (click to toggle)
pytorch 2.9.1+dfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (70 lines) | stat: -rw-r--r-- 2,272 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <gtest/gtest.h>

#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/nativert/executor/memory/FunctionSchema.h>

using namespace ::testing;

int64_t increment_kernel(const at::Tensor& tensor, int64_t input) {
  return input + 1;
}

// Test kernel backing "_test::my_op(Tensor(a) dummy, int input) -> Tensor(a)".
// Returns `tensor.slice(dim)` with default start/end/step, i.e. the full extent
// of `tensor` along `dim`. Tensor::slice is a view op, so the result shares
// storage with the input -- matching the `(a)` alias annotation in the schema.
// `static` gives this helper internal linkage, so it cannot clash with an
// identically named function in another test file linked into the same binary.
static at::Tensor slice_kernel(const at::Tensor& tensor, int64_t dim) {
  return tensor.slice(dim);
}

// A schema with no alias annotations must report no input/output aliasing.
TEST(TestFunctionSchema, testNoAlias) {
  // The op registration is tied to `registrar`; keep it alive for the test.
  auto registrar = c10::RegisterOperators().op(
      "_test::my_op(Tensor dummy, int input) -> int", &increment_kernel);
  auto handle = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

  // ASSERT rather than EXPECT: if the lookup failed, the `handle->` calls
  // below would dereference an empty optional (undefined behavior).
  ASSERT_TRUE(handle.has_value());
  ASSERT_TRUE(handle->hasSchema());

  auto nativert_schema = torch::nativert::FunctionSchema(handle->schema());

  // alias(i, o) is queried with (input index, output index) -- inferred from
  // the bounds checks below for a schema with 2 inputs and 1 output.
  EXPECT_FALSE(nativert_schema.alias(0, 0));
  EXPECT_FALSE(nativert_schema.alias(1, 0));

  // bounds check: input index 2 and output index 1 are out of range.
  EXPECT_THROW(nativert_schema.alias(2, 0), c10::Error);
  EXPECT_THROW(nativert_schema.alias(1, 1), c10::Error);
}

// A schema with no alias annotations, constructed with an explicit alias
// override, must report aliasing only for the overridden (input, output) pair.
TEST(TestFunctionSchema, testAliasOverride) {
  // The op registration is tied to `registrar`; keep it alive for the test.
  auto registrar = c10::RegisterOperators().op(
      "_test::my_op(Tensor dummy, int input) -> int", &increment_kernel);
  auto handle = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

  // ASSERT rather than EXPECT: if the lookup failed, the `handle->` calls
  // below would dereference an empty optional (undefined behavior).
  ASSERT_TRUE(handle.has_value());
  ASSERT_TRUE(handle->hasSchema());

  // {{0, 0}} presumably marks input 0 as aliasing output 0 -- consistent with
  // the assertions below; confirm against the FunctionSchema constructor.
  auto nativert_schema =
      torch::nativert::FunctionSchema(handle->schema(), {{0, 0}});

  EXPECT_TRUE(nativert_schema.alias(0, 0));
  EXPECT_FALSE(nativert_schema.alias(1, 0));

  // bounds check: the override must not widen the schema's index range.
  EXPECT_THROW(nativert_schema.alias(2, 0), c10::Error);
  EXPECT_THROW(nativert_schema.alias(1, 1), c10::Error);
}

// A schema whose annotations tie the first input and the output to the same
// alias set `(a)` must report that input as aliasing output 0.
TEST(TestFunctionSchema, testAlias) {
  // The op registration is tied to `registrar`; keep it alive for the test.
  auto registrar = c10::RegisterOperators().op(
      "_test::my_op(Tensor(a) dummy, int input) -> Tensor(a)", &slice_kernel);
  auto handle = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

  // ASSERT rather than EXPECT: if the lookup failed, the `handle->` calls
  // below would dereference an empty optional (undefined behavior).
  ASSERT_TRUE(handle.has_value());
  ASSERT_TRUE(handle->hasSchema());

  auto nativert_schema = torch::nativert::FunctionSchema(handle->schema());

  // Only the annotated Tensor(a) input aliases the Tensor(a) output.
  EXPECT_TRUE(nativert_schema.alias(0, 0));
  EXPECT_FALSE(nativert_schema.alias(1, 0));

  // bounds check: 2 inputs and 1 output, so input 2 / output 1 are invalid.
  EXPECT_THROW(nativert_schema.alias(2, 0), c10::Error);
  EXPECT_THROW(nativert_schema.alias(1, 1), c10::Error);
}