File: test_execution_planner.cpp

package info (click to toggle)
pytorch 2.9.1+dfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (47 lines) | stat: -rw-r--r-- 1,226 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <cstdint>
#include <set>

#include <gtest/gtest.h>
#include <torch/nativert/executor/ExecutionPlanner.h>

namespace torch::nativert {

// Builds a small graph with two independent ops (foo, foo1) feeding a third
// (foo2), then checks two things:
//   1. ExecutionPlanner::createPlan()'s valuesToFree: %a and %b are expected
//      to be released at index 3 and nothing is released anywhere else.
//   2. ExecutionPlanner::staticValues(): expected to be exactly the graph
//      inputs (%x, %y) and the graph output (%c).
TEST(ExecutionPlannerTest, CreatePlan) {
  auto graph = stringToGraph(R"(
    graph(%x, %y):
  %a = foo(a=%x, b=%y)
  %b = foo1(a=%x, b=%y)
  %c = foo2(c=%a, d=%b)
  return(%c)
  )");

  // tryGetValue() may return null. The results are dereferenced
  // unconditionally below, so stop here with a clear message instead of
  // crashing on a null pointer.
  for (const char* name : {"x", "y", "a", "b", "c"}) {
    ASSERT_NE(graph->tryGetValue(name), nullptr) << "missing value %" << name;
  }

  {
    auto plan = ExecutionPlanner{*graph}.createPlan();

    auto& values_to_free = plan->valuesToFree;
    // Five entries are expected for three op nodes. Presumably this is one
    // entry per node, including the graph's input and output nodes; confirm
    // against ExecutionPlanner. ASSERT (not EXPECT) is used because indices
    // 0..4 are read below, and a shorter vector would be out-of-bounds UB.
    ASSERT_EQ(values_to_free.size(), 5u);

    for (const auto i : c10::irange(3)) {
      EXPECT_TRUE(values_to_free[i].empty());
    }

    // Index 3 presumably corresponds to foo2, the last consumer of %a and %b.
    EXPECT_EQ(values_to_free[3].size(), 2u);
    std::set<int64_t> ids{values_to_free[3].begin(), values_to_free[3].end()};
    EXPECT_EQ(
        ids,
        std::set<int64_t>(
            {graph->tryGetValue("a")->id(), graph->tryGetValue("b")->id()}));

    EXPECT_TRUE(values_to_free[4].empty());
  }

  {
    auto static_values = ExecutionPlanner::staticValues(*graph);
    std::set<int64_t> static_ids{static_values.begin(), static_values.end()};
    EXPECT_EQ(
        static_ids,
        std::set<int64_t>(
            {graph->tryGetValue("x")->id(),
             graph->tryGetValue("y")->id(),
             graph->tryGetValue("c")->id()}));
  }
}

} // namespace torch::nativert