File: test_misc.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (85 lines) | stat: -rw-r--r-- 3,198 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <gtest/gtest.h>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include <c10/util/int128.h>
#include <torch/csrc/lazy/core/hash.h>

namespace torch {
namespace lazy {

// Shared checker for the hashing helpers.
//
// For each of Hash(x), MHash(x) and MHash(x, y) it verifies two properties:
//   * repeatable: hashing `example_a` twice yields the same value;
//   * sensitive:  substituting `example_b` for `example_a` changes the value.
// In the two-argument MHash case only the second argument is swapped.
template <typename T>
void test_hash_repeatable_sensitive(const T& example_a, const T& example_b) {
  const T& lhs = example_a;
  const T& rhs = example_b;

  // Hash() on a single value.
  EXPECT_EQ(Hash(lhs), Hash(lhs));
  EXPECT_NE(Hash(lhs), Hash(rhs));

  // MHash() called with one argument.
  EXPECT_EQ(MHash(lhs), MHash(lhs));
  EXPECT_NE(MHash(lhs), MHash(rhs));

  // MHash() called with two arguments; the trailing argument is the one varied.
  EXPECT_EQ(MHash(lhs, lhs), MHash(lhs, lhs));
  EXPECT_NE(MHash(lhs, lhs), MHash(lhs, rhs));
}

// Intended check: Hash/MHash of a c10::Scalar depend only on its logical
// value, not on stray bytes elsewhere in the object's storage.
// Currently skipped; see the linked issue.
TEST(HashTest, Scalar) {
  GTEST_SKIP()
      << "Broken test. See https://github.com/pytorch/pytorch/issues/99883";
  c10::Scalar a(0);
  c10::Scalar b(0);

  // Simulate garbage in the unused bits of the tagged union that is
  // c10::Scalar, which is bigger than the int64_t it currently holds.
  // NOTE(review): this overwrites the *first* byte of the object. Whether that
  // byte is part of the union payload or of the tag depends on c10::Scalar's
  // private member layout, which is not visible here — confirm before
  // re-enabling this test.
  *reinterpret_cast<uint8_t*>(&b) = 1;
  // The Scalar's value as a 64-bit int should be unaffected by the garbage...
  EXPECT_EQ(a.toLong(), b.toLong());
  // ...and so should every hash of it.
  EXPECT_EQ(Hash(a), Hash(b));
  EXPECT_EQ(MHash(a), MHash(b));
  EXPECT_EQ(MHash(a, a), MHash(a, b));
}

// Sanity-checks the hashing helpers on a range of input types: strings,
// bool and fixed-width integers, c10 types, std::optional and containers.
// Every case goes through test_hash_repeatable_sensitive(), which checks that
// hashing the first example is repeatable and that the two examples hash
// differently.
TEST(HashTest, Sanity) {
  // String: the two inputs differ in a single character.
  test_hash_repeatable_sensitive(
      std::string(
          "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus."),
      std::string(
          "Lorem Jpsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus."));

  // Number types. Some literals (e.g. 0xfa for int8_t) exceed the signed
  // range and convert to negative values; that conversion is modular on
  // two's-complement targets (guaranteed since C++20).
  test_hash_repeatable_sensitive(true, false);
  test_hash_repeatable_sensitive(
      static_cast<int8_t>(0xfa), static_cast<int8_t>(0xfb));
  test_hash_repeatable_sensitive(
      static_cast<int16_t>(0xface), static_cast<int16_t>(0xfade));
  test_hash_repeatable_sensitive(
      static_cast<int32_t>(0xfaceb000), static_cast<int32_t>(0xfadeb000));
  test_hash_repeatable_sensitive(
      static_cast<int64_t>(0x1faceb000), static_cast<int64_t>(0x1fadeb000));
  test_hash_repeatable_sensitive(
      static_cast<uint8_t>(0xfa), static_cast<uint8_t>(0xfb));
  test_hash_repeatable_sensitive(
      static_cast<uint16_t>(0xface), static_cast<uint16_t>(0xfade));
  test_hash_repeatable_sensitive(
      static_cast<uint32_t>(0xfaceb000), static_cast<uint32_t>(0xfadeb000));
  test_hash_repeatable_sensitive(
      static_cast<uint64_t>(0x1faceb000), static_cast<uint64_t>(0x1fadeb000));

  // c10 types
  test_hash_repeatable_sensitive(c10::ScalarType::Bool, c10::ScalarType::Byte);
  test_hash_repeatable_sensitive(c10::Scalar(1.334), c10::Scalar(1.335));
  test_hash_repeatable_sensitive(c10::Scalar(true), c10::Scalar(false));
  test_hash_repeatable_sensitive(c10::Scalar(12345), c10::Scalar(12354));

  // std::optional: an engaged value must hash differently from nullopt.
  test_hash_repeatable_sensitive(
      std::optional<std::string>("I have value!"),
      std::optional<std::string>(std::nullopt));

  // Containers, both owning (std::vector) and non-owning (c10::ArrayRef)
  // views of the same data.
  const std::vector<int32_t> a{0, 1, 1, 2, 3, 5, 8};
  const std::vector<int32_t> b{1, 1, 2, 3, 5, 8, 12};
  test_hash_repeatable_sensitive(a, b);
  test_hash_repeatable_sensitive(
      c10::ArrayRef<int32_t>(a), c10::ArrayRef<int32_t>(b));

  // std::vector<bool> is a special case because it is a packed-bit
  // specialization rather than a plain array of bool.
  const std::vector<bool> bool_a{true, false, false, true};
  const std::vector<bool> bool_b{true, true, false, true};
  test_hash_repeatable_sensitive(bool_a, bool_b);
}

} // namespace lazy
} // namespace torch