1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
|
#pragma once
#if defined(USE_GTEST)
#include <gtest/gtest.h>
#include <test/cpp/common/support.h>
#else
#include <cmath>
#include "c10/util/Exception.h"
#include "test/cpp/tensorexpr/gtest_assert_float_eq.h"
// Fallback assertion macros for builds without gtest. Each macro forwards its
// condition, followed by any optional trailing message arguments, to
// TORCH_INTERNAL_ASSERT.
#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
// Floating-point equality as decided by AlmostEquals.
#define ASSERT_FLOAT_EQ(x, y, ...) \
  TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)
#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)
#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)
// Passes when |x - y| does not exceed `a`. The bound is inclusive, matching
// gtest's ASSERT_NEAR, so an exact match checked with a zero tolerance
// succeeds in both build configurations.
#define ASSERT_NEAR(x, y, a, ...) \
  TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) <= (a), __VA_ARGS__)
// ASSERT_TRUE is a plain alias, so it accepts the same arguments as
// TORCH_INTERNAL_ASSERT; ASSERT_FALSE takes the condition only.
#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
#define ASSERT_FALSE(x) ASSERT_TRUE(!(x))
// Asserts that evaluating `statement` throws a std::exception whose what()
// text contains `substring`.
//
// Whether anything was thrown is recorded in a flag and checked only after
// the try/catch. If the "did not throw" assertion were raised inside the try
// block and signalled failure with a std::exception-derived error (as
// TORCH_INTERNAL_ASSERT is understood to do), the handler below would catch
// it and compare the assertion's own message against `substring`, hiding the
// real failure. Exceptions not derived from std::exception propagate to the
// caller unchanged.
#define ASSERT_THROWS_WITH(statement, substring)                           \
  {                                                                        \
    bool assert_throws_with_threw_ = false;                                \
    try {                                                                  \
      (void)statement;                                                     \
    } catch (const std::exception& e) {                                    \
      assert_throws_with_threw_ = true;                                    \
      ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
    }                                                                      \
    ASSERT_TRUE(assert_throws_with_threw_);                                \
  }
// Asserts that evaluating `statement` throws an exception of any type.
//
// The handler is `catch (...)` so that exceptions not derived from
// std::exception also count, in line with gtest's ASSERT_ANY_THROW. The
// assertion itself runs after the try/catch so its failure is never swallowed
// by the handler.
#define ASSERT_ANY_THROW(statement) \
  {                                 \
    bool threw = false;             \
    try {                           \
      (void)statement;              \
    } catch (...) {                 \
      threw = true;                 \
    }                               \
    ASSERT_TRUE(threw);             \
  }
#endif // defined(USE_GTEST)
namespace torch {
namespace jit {
namespace tensorexpr {
// Asserts that `v1` and `v2` have the same number of elements and that each
// pair of elements at the same index is within `threshold` of one another, as
// judged by ASSERT_NEAR.
//
// v1, v2     vectors compared element by element.
// threshold  tolerance handed to ASSERT_NEAR for every pair.
// name       accepted for callers' convenience; not used by this function.
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& v1,
    const std::vector<U>& v2,
    V threshold,
    const std::string& name = "") {
  ASSERT_EQ(v1.size(), v2.size());
  const size_t count = v1.size();
  for (size_t i = 0; i != count; ++i) {
    ASSERT_NEAR(v1[i], v2[i], threshold);
  }
}
// Asserts that every element of `vec` is within `threshold` of the single
// reference value `val`, as judged by ASSERT_NEAR. An empty vector performs
// no checks.
//
// vec        elements to check.
// val        value each element is compared against.
// threshold  tolerance handed to ASSERT_NEAR for every element.
// name       accepted for callers' convenience; not used by this function.
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& vec,
    const U& val,
    V threshold,
    const std::string& name = "") {
  const size_t count = vec.size();
  for (size_t i = 0; i != count; ++i) {
    ASSERT_NEAR(vec[i], val, threshold);
  }
}
// Asserts, via ASSERT_EQ, that every element of `vec` compares equal to
// `val`. An empty vector performs no checks.
template <typename T>
static void assertAllEqual(const std::vector<T>& vec, const T& val) {
  for (const T& elt : vec) {
    ASSERT_EQ(elt, val);
  }
}
// Asserts, via ASSERT_EQ, that `v1` and `v2` have the same number of elements
// and that the elements at each index compare equal.
template <typename T>
static void assertAllEqual(
    const std::vector<T>& v1,
    const std::vector<T>& v2) {
  ASSERT_EQ(v1.size(), v2.size());
  const size_t count = v1.size();
  for (size_t i = 0; i != count; i++) {
    ASSERT_EQ(v1[i], v2[i]);
  }
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
|