File: test_upgrader_utils.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (99 lines) | stat: -rw-r--r-- 3,198 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <gtest/gtest.h>
#include <torch/csrc/jit/operator_upgraders/utils.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>

#include <test/cpp/jit/test_utils.h>

#include <vector>

namespace torch {
namespace jit {

// Checks which upgrader findUpgrader picks for a given serialized version.
// The list has two entries, bumped at versions 4 and 8: version 1 should map
// to "foo__0_3", version 6 to "foo__4_7", and version 10 to no upgrader.
TEST(UpgraderUtils, FindCorrectUpgrader) {
  std::vector<UpgraderEntry> dummy_entry = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.bar()"},
  };

  auto upgrader_at_6 = findUpgrader(dummy_entry, 6);
  // ASSERT (not EXPECT) so .value() is never called on an empty optional.
  ASSERT_TRUE(upgrader_at_6.has_value());
  EXPECT_EQ(upgrader_at_6.value().upgrader_name, "foo__4_7");

  auto upgrader_at_1 = findUpgrader(dummy_entry, 1);
  ASSERT_TRUE(upgrader_at_1.has_value());
  EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3");

  // Version 10 is past the last bump (8), so nothing should need upgrading.
  // NOTE(review): expected-empty is inferred from the other tests in this
  // file (version 8 is treated as current) -- confirm against findUpgrader.
  auto upgrader_at_10 = findUpgrader(dummy_entry, 10);
  EXPECT_FALSE(upgrader_at_10.has_value());
}

// Verifies that, for every entry of the map returned by
// get_operator_version_map(), the associated list of upgrader entries is in
// non-decreasing order of its bumped_at_version field.
TEST(UpgraderUtils, IsVersionMapSorted) {
  auto version_map = get_operator_version_map();
  for (const auto& name_and_upgraders : version_map) {
    const auto& upgraders = name_and_upgraders.second;
    const bool ascending = std::is_sorted(
        upgraders.begin(),
        upgraders.end(),
        [](const auto& lhs, const auto& rhs) {
          // Compare the versions as plain ints.
          const int lhs_version = lhs.bumped_at_version;
          const int rhs_version = rhs.bumped_at_version;
          return lhs_version < rhs_version;
        });
    EXPECT_TRUE(ascending);
  }
}

// Exercises the "is the op current at version N" queries: first directly on a
// list of upgrader entries, then by symbol name after the same entries have
// been registered with test_only_add_entry.
TEST(UpgraderUtils, FindIfOpIsCurrent) {
  std::vector<UpgraderEntry> upgraders = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.bar()"},
  };

  // entry-list based look up: version 6 is expected to be stale, 8 current
  const bool current_at_6 = isOpCurrentBasedOnUpgraderEntries(upgraders, 6);
  const bool current_at_8 = isOpCurrentBasedOnUpgraderEntries(upgraders, 8);
  EXPECT_FALSE(current_at_6);
  EXPECT_TRUE(current_at_8);

  // symbol based look up: register both entries under "foo", query, clean up
  for (const auto& upgrader : upgraders) {
    test_only_add_entry("foo", upgrader);
  }
  EXPECT_FALSE(isOpSymbolCurrent("foo", 6));
  EXPECT_TRUE(isOpSymbolCurrent("foo", 8));
  test_only_remove_entry("foo");
}

// Checks loadPossibleHistoricOps against entries registered through
// test_only_add_entry: two overloads of one op name ("<name>.first" and
// "<name>.second"), then an op name registered without an overload suffix.
// Every entry added here is removed again before the test ends.
TEST(UpgraderUtils, CanLoadHistoricOp) {
  std::vector<UpgraderEntry> dummy_entry = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.foo()"},
  };

  std::vector<std::string> schemas = {"foo.bar()", "foo.foo()"};

  // symbol based look up
  test_only_add_entry("old_op_not_exist.first", dummy_entry[0]);
  test_only_add_entry("old_op_not_exist.second", dummy_entry[1]);

  // At version 2 both registered schemas are expected to be reported.
  auto oldSchemas = loadPossibleHistoricOps("old_op_not_exist", 2);
  EXPECT_EQ(oldSchemas.size(), 2);
  for (const auto& entry : oldSchemas) {
    EXPECT_TRUE(
        std::find(schemas.begin(), schemas.end(), entry) != schemas.end());
  }

  // At version 9 (past both bumps) no historic schema is expected.
  auto oldSchemasWithCurrentVersion =
      loadPossibleHistoricOps("old_op_not_exist", 9);
  EXPECT_EQ(oldSchemasWithCurrentVersion.size(), 0);

  // Remove BOTH overloads so neither stays registered for later tests.
  test_only_remove_entry("old_op_not_exist.first");
  test_only_remove_entry("old_op_not_exist.second");

  // it is ok to have old schemas without overload
  test_only_add_entry("old_op_not_exist_no_overload", dummy_entry[0]);
  auto oldSchemasNoOverload =
      loadPossibleHistoricOps("old_op_not_exist_no_overload", 2);
  // Clean up before the fatal assertion below so the entry is always removed.
  test_only_remove_entry("old_op_not_exist_no_overload");
  // ASSERT so that element 0 is only read when it exists.
  ASSERT_EQ(oldSchemasNoOverload.size(), 1);
  EXPECT_EQ(oldSchemasNoOverload[0], "foo.bar()");
}

} // namespace jit
} // namespace torch