File: test_c10_kernel.cpp

package info (click to toggle)
pytorch 2.9.1+dfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (74 lines) | stat: -rw-r--r-- 1,887 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <ATen/core/op_registration/op_registration.h>
#include <gtest/gtest.h>
#include <torch/nativert/executor/ExecutionFrame.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/C10Kernel.h>
#include <torch/torch.h>

namespace torch::nativert {

// Kernel backing the test-only `test::foo` operator that C10KernelTest
// registers: it simply returns `a + b`.
//
// Declared `static` so this test helper has internal linkage: it is only
// referenced from this file (via `&foo_kernel` in the registration), and
// internal linkage keeps it from clashing with any other `foo_kernel` symbol
// in namespace torch::nativert and satisfies -Wmissing-prototypes.
static at::Tensor foo_kernel(const at::Tensor& a, const at::Tensor& b) {
  return a + b;
}

// End-to-end check of C10Kernel. The test:
//   1. registers a throwaway `test::foo` operator;
//   2. parses a graph containing a single call to it;
//   3. runs the kernel against an ExecutionFrame;
//   4. compares the frame's output with `a + b` computed directly.
TEST(C10KernelTest, computeInternal) {
  // Hold on to the registration handle for the whole test body so that
  // `test::foo` is available while the kernel runs (presumably the operator
  // is deregistered when this handle is destroyed — verify in
  // op_registration.h).
  auto fooRegistration = c10::RegisterOperators().op(
      "test::foo(Tensor a, Tensor b) -> Tensor", &foo_kernel);

  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = test.foo.default(a=%a, b=%b)
return (%x)
)";

  auto graph = stringToGraph(source);

  // Step past the first node (presumably the graph's input node) to reach
  // the `test.foo` call under test.
  const auto& graphNodes = graph->nodes();
  const Node& fooNode = *std::next(graphNodes.begin());

  auto lhs = at::randn({6, 6, 6});
  auto rhs = at::randn({6, 6, 6});

  ExecutionFrame frame(*graph);
  frame.setIValue(graph->getValue("a")->id(), lhs);
  frame.setIValue(graph->getValue("b")->id(), rhs);

  C10Kernel kernel(&fooNode);
  kernel.computeInternal(frame);

  const at::Tensor expected = lhs + rhs;
  const auto& actual = frame.getTensor(graph->getValue("x")->id());
  EXPECT_TRUE(torch::equal(actual, expected));
}

// Checks ScalarBinaryOpKernel on a one-node graph that calls
// `_operator.add` with two integer scalar inputs. The value written to `%x`
// must equal their integer sum.
TEST(ScalarBinaryOpKernelTest, computeInternal) {
  static constexpr std::string_view source =
      R"(graph(%a, %b):
%x = _operator.add(a=%a, b=%b)
return (%x)
)";

  auto graph = stringToGraph(source);

  // Step past the first node (presumably the graph's input node) to reach
  // the `_operator.add` call under test.
  const auto& graphNodes = graph->nodes();
  const Node& addNode = *std::next(graphNodes.begin());

  int lhs = 1;
  int rhs = 2;

  ExecutionFrame frame(*graph);
  frame.setIValue(graph->getValue("a")->id(), lhs);
  frame.setIValue(graph->getValue("b")->id(), rhs);

  ScalarBinaryOpKernel kernel(&addNode);
  kernel.computeInternal(frame);

  const int sum = lhs + rhs;
  EXPECT_EQ(frame.getIValue(graph->getValue("x")->id()).toInt(), sum);
}

} // namespace torch::nativert