File: test_base.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (87 lines) | stat: -rw-r--r-- 2,853 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#pragma once

#if defined(USE_GTEST)
#include <gtest/gtest.h>
#include <test/cpp/common/support.h>
#else
#include <cmath>
#include "c10/util/Exception.h"
#include "test/cpp/tensorexpr/gtest_assert_float_eq.h"
// Fallback shims used when the tests are built without GoogleTest
// (USE_GTEST undefined). They keep the gtest ASSERT_* spellings but are
// implemented on top of TORCH_INTERNAL_ASSERT, which reports a failure by
// throwing (c10::Error, a std::exception subclass). Any trailing `...`
// arguments are forwarded to TORCH_INTERNAL_ASSERT as extra message parts.
#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
#define ASSERT_FLOAT_EQ(x, y, ...) \
  TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)
#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)
#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)

// Passes when |x - y| <= a. The bound is inclusive to match gtest's
// ASSERT_NEAR, so exactly-equal values pass even with a tolerance of 0.
#define ASSERT_NEAR(x, y, a, ...) \
  TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) <= (a), __VA_ARGS__)

#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
#define ASSERT_FALSE(x) ASSERT_TRUE(!(x))

// Runs `statement` and requires that it throws a std::exception whose what()
// contains `substring`. Both checks are made *after* the try/catch: an
// assertion failure is itself a std::exception, so asserting inside the try
// block would let the catch below swallow the "did not throw" failure.
#define ASSERT_THROWS_WITH(statement, substring)                         \
  {                                                                      \
    bool assert_throws_threw_ = false;                                   \
    std::string assert_throws_what_;                                     \
    try {                                                                \
      (void)statement;                                                   \
    } catch (const std::exception& e) {                                  \
      assert_throws_threw_ = true;                                       \
      assert_throws_what_ = e.what();                                    \
    }                                                                    \
    ASSERT_TRUE(assert_throws_threw_);                                   \
    ASSERT_NE(assert_throws_what_.find(substring), std::string::npos);   \
  }

// Runs `statement` and requires that it throws. As with gtest's
// ASSERT_ANY_THROW, any thrown type counts, not only std::exception.
#define ASSERT_ANY_THROW(statement)         \
  {                                         \
    bool assert_any_throw_threw_ = false;   \
    try {                                   \
      (void)statement;                      \
    } catch (...) {                         \
      assert_any_throw_threw_ = true;       \
    }                                       \
    ASSERT_TRUE(assert_any_throw_threw_);   \
  }

#endif // defined(USE_GTEST)

namespace torch {
namespace jit {
namespace tensorexpr {

/// Element-wise closeness check between two vectors.
/// First asserts that `v1` and `v2` have the same length, then applies
/// ASSERT_NEAR to every pair (v1[i], v2[i]) with tolerance `threshold`;
/// the exact comparison rule is whatever the active ASSERT_NEAR uses.
/// `name` is accepted for call-site readability but is not consulted.
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& v1,
    const std::vector<U>& v2,
    V threshold,
    const std::string& name = "") {
  ASSERT_EQ(v1.size(), v2.size());
  const size_t count = v1.size();
  size_t i = 0;
  while (i < count) {
    ASSERT_NEAR(v1[i], v2[i], threshold);
    ++i;
  }
}

/// Closeness check of every element of `vec` against the single value `val`,
/// via ASSERT_NEAR with tolerance `threshold`. An empty `vec` passes.
/// `name` is accepted for call-site readability but is not consulted.
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& vec,
    const U& val,
    V threshold,
    const std::string& name = "") {
  size_t i = 0;
  while (i != vec.size()) {
    ASSERT_NEAR(vec[i], val, threshold);
    i++;
  }
}

/// Asserts (via ASSERT_EQ) that each element of `vec` equals `val`.
template <typename T>
static void assertAllEqual(const std::vector<T>& vec, const T& val) {
  const size_t count = vec.size();
  for (size_t i = 0; i != count; ++i) {
    const T& elt = vec[i];
    ASSERT_EQ(elt, val);
  }
}

/// Asserts that `v1` and `v2` have the same length and, via ASSERT_EQ,
/// that they are equal position by position.
template <typename T>
static void assertAllEqual(const std::vector<T>& v1, const std::vector<T>& v2) {
  ASSERT_EQ(v1.size(), v2.size());
  size_t i = 0;
  while (i < v1.size()) {
    ASSERT_EQ(v1[i], v2[i]);
    i++;
  }
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch