1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
|
#pragma once
#include <test/cpp/common/support.h>
#include <gtest/gtest.h>
#include <ATen/TensorIndexing.h>
#include <c10/util/Exception.h>
#include <torch/nn/cloneable.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <string>
#include <utility>
namespace torch {
namespace test {
// Lets you use a container without making a new class,
// for experimental implementations
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
public:
void reset() override {}
template <typename ModuleHolder>
ModuleHolder add(
ModuleHolder module_holder,
std::string name = std::string()) {
return Module::register_module(std::move(name), module_holder);
}
};
struct SeedingFixture : public ::testing::Test {
SeedingFixture() {
torch::manual_seed(0);
}
};
struct WarningCapture : public WarningHandler {
WarningCapture() : prev_(Warning::get_warning_handler()) {
Warning::set_warning_handler(this);
}
~WarningCapture() {
Warning::set_warning_handler(prev_);
}
const std::vector<std::string>& messages() {
return messages_;
}
std::string str() {
return c10::Join("\n", messages_);
}
void process(
const SourceLocation& source_location,
const std::string& msg,
const bool /*verbatim*/) override {
messages_.push_back(msg);
}
private:
WarningHandler* prev_;
std::vector<std::string> messages_;
};
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
return first.data_ptr() == second.data_ptr();
}
// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
// torch.Tensor)` branch in `TestCase.assertEqual` in
// torch/testing/_internal/common_utils.py
inline void assert_tensor_equal(
at::Tensor a,
at::Tensor b,
bool allow_inf = false) {
ASSERT_TRUE(a.sizes() == b.sizes());
if (a.numel() > 0) {
if (a.device().type() == torch::kCPU &&
(a.scalar_type() == torch::kFloat16 ||
a.scalar_type() == torch::kBFloat16)) {
// CPU half and bfloat16 tensors don't have the methods we need below
a = a.to(torch::kFloat32);
}
if (a.device().type() == torch::kCUDA &&
a.scalar_type() == torch::kBFloat16) {
// CUDA bfloat16 tensors don't have the methods we need below
a = a.to(torch::kFloat32);
}
b = b.to(a);
if ((a.scalar_type() == torch::kBool) !=
(b.scalar_type() == torch::kBool)) {
TORCH_CHECK(false, "Was expecting both tensors to be bool type.");
} else {
if (a.scalar_type() == torch::kBool && b.scalar_type() == torch::kBool) {
// we want to respect precision but as bool doesn't support subtraction,
// boolean tensor has to be converted to int
a = a.to(torch::kInt);
b = b.to(torch::kInt);
}
auto diff = a - b;
if (a.is_floating_point()) {
// check that NaNs are in the same locations
auto nan_mask = torch::isnan(a);
ASSERT_TRUE(torch::equal(nan_mask, torch::isnan(b)));
diff.index_put_({nan_mask}, 0);
// inf check if allow_inf=true
if (allow_inf) {
auto inf_mask = torch::isinf(a);
auto inf_sign = inf_mask.sign();
ASSERT_TRUE(torch::equal(inf_sign, torch::isinf(b).sign()));
diff.index_put_({inf_mask}, 0);
}
}
// TODO: implement abs on CharTensor (int8)
if (diff.is_signed() && diff.scalar_type() != torch::kInt8) {
diff = diff.abs();
}
auto max_err = diff.max().item<double>();
ASSERT_LE(max_err, 1e-5);
}
}
}
// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
// torch.Tensor)` branch in `TestCase.assertNotEqual` in
// torch/testing/_internal/common_utils.py
inline void assert_tensor_not_equal(at::Tensor x, at::Tensor y) {
if (x.sizes() != y.sizes()) {
return;
}
ASSERT_GT(x.numel(), 0);
y = y.type_as(x);
y = x.is_cuda() ? y.to({torch::kCUDA, x.get_device()}) : y.cpu();
auto nan_mask = x != x;
if (torch::equal(nan_mask, y != y)) {
auto diff = x - y;
if (diff.is_signed()) {
diff = diff.abs();
}
diff.index_put_({nan_mask}, 0);
// Use `item()` to work around:
// https://github.com/pytorch/pytorch/issues/22301
auto max_err = diff.max().item<double>();
ASSERT_GE(max_err, 1e-5);
}
}
inline int count_substr_occurrences(
const std::string& str,
const std::string& substr) {
int count = 0;
size_t pos = str.find(substr);
while (pos != std::string::npos) {
count++;
pos = str.find(substr, pos + substr.size());
}
return count;
}
// A RAII, thread local (!) guard that changes default dtype upon
// construction, and sets it back to the original dtype upon destruction.
//
// Usage of this guard is synchronized across threads, so that at any given
// time, only one guard can take effect.
struct AutoDefaultDtypeMode {
static std::mutex default_dtype_mutex;
AutoDefaultDtypeMode(c10::ScalarType default_dtype)
: prev_default_dtype(
torch::typeMetaToScalarType(torch::get_default_dtype())) {
default_dtype_mutex.lock();
torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype));
}
~AutoDefaultDtypeMode() {
default_dtype_mutex.unlock();
torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype));
}
c10::ScalarType prev_default_dtype;
};
inline void assert_tensor_creation_meta(
torch::Tensor& x,
torch::autograd::CreationMeta creation_meta) {
auto autograd_meta = x.unsafeGetTensorImpl()->autograd_meta();
TORCH_CHECK(autograd_meta);
auto view_meta =
static_cast<torch::autograd::DifferentiableViewMeta*>(autograd_meta);
TORCH_CHECK(view_meta->has_bw_view());
ASSERT_EQ(view_meta->get_creation_meta(), creation_meta);
}
} // namespace test
} // namespace torch
|