1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
#include <gtest/gtest.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
#include <torch/torch.h>
#include <test/cpp/api/support.h>
TEST(NestedIntTest, Comparisons) {
  // Wraps a freshly allocated NestedIntSymNodeImpl, built from the two given
  // constructor arguments, in a SymInt.
  auto make_nested = [](auto val, auto coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(val, coeff)));
  };
  // j1 and j1_twin are two separate nodes created from identical (1, 1)
  // arguments; j2 differs only in its first argument; three is a plain
  // concrete integer.
  auto j1 = make_nested(1, 1);
  auto j1_twin = make_nested(1, 1);
  auto j2 = make_nested(2, 1);
  auto three = c10::SymInt(3);

  // eq / ne: only nodes built from identical arguments compare equal; a node
  // never equals a different node or a concrete integer, in either order.
  ASSERT_TRUE(j1 == j1);
  ASSERT_TRUE(j1 == j1_twin);
  ASSERT_FALSE(j1 != j1);
  ASSERT_FALSE(j1 != j1_twin);
  ASSERT_FALSE(j1 == j2);
  ASSERT_TRUE(j1 != j2);
  ASSERT_FALSE(j1 == three);
  ASSERT_TRUE(j1 != three);
  ASSERT_FALSE(three == j1);
  ASSERT_TRUE(three != j1);

  // ge: node-vs-different-node and node-vs-3 must raise c10::Error, while
  // comparisons against 1 and 2 have a definite answer.
  ASSERT_TRUE(j1 >= j1);
  ASSERT_TRUE(j1 >= j1_twin);
  ASSERT_TRUE(j1_twin >= j1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 >= j2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 >= j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 >= 3), c10::Error);
  ASSERT_TRUE(j2 >= 2);
  ASSERT_TRUE(j2 >= 1);
  ASSERT_FALSE(1 >= j2);

  // lt: strict comparisons against 2 or 3 must raise; 1 < node holds.
  ASSERT_FALSE(j1 < j1);
  ASSERT_FALSE(j1 < j1_twin);
  ASSERT_FALSE(j1_twin < j1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 < j2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 < j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(3 < j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(2 < j1), c10::Error);
  ASSERT_TRUE(1 < j1);

  // le: mirror image of ge.
  ASSERT_TRUE(j1 <= j1);
  ASSERT_TRUE(j1_twin <= j1);
  ASSERT_TRUE(j1 <= j1_twin);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 <= j2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 <= j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(3 <= j2), c10::Error);
  ASSERT_TRUE(2 <= j2);
  ASSERT_TRUE(1 <= j2);
  ASSERT_FALSE(j2 <= 1);

  // gt: mirror image of lt.
  ASSERT_FALSE(j1 > j1);
  ASSERT_FALSE(j1_twin > j1);
  ASSERT_FALSE(j1 > j1_twin);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 > j2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j2 > j1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 > 3), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW((void)(j1 > 2), c10::Error);
  ASSERT_TRUE(j1 > 1);
}
TEST(NestedIntTest, WithFactor) {
  // Wraps a freshly allocated NestedIntSymNodeImpl, built from the two given
  // constructor arguments, in a SymInt.
  auto make_nested = [](auto val, auto coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(val, coeff)));
  };
  // Same first argument; second argument (the factor) is 5 vs. 10.
  auto j1_x5 = make_nested(1, 5);
  auto j1_x10 = make_nested(1, 10);

  // eq / ge / le: different factors are unequal, and the factor-10 node
  // orders strictly above the factor-5 node.
  ASSERT_FALSE(j1_x5 == j1_x10);
  ASSERT_FALSE(j1_x5 >= j1_x10);
  ASSERT_TRUE(j1_x10 >= j1_x5);
  ASSERT_TRUE(j1_x5 <= j1_x10);
  ASSERT_FALSE(j1_x10 <= j1_x5);
  // ne
  ASSERT_TRUE(j1_x5 != j1_x10);
  // mul: multiplying by a concrete integer scales the factor (5 * 2 == 10),
  // and the product is the same whichever side the integer is on.
  ASSERT_TRUE(j1_x5 * 2 == j1_x10);
  ASSERT_TRUE(j1_x5 * 3 >= j1_x10);
  ASSERT_TRUE(j1_x5 * 2 == 2 * j1_x5);
}
|