1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
|
#ifdef TORCH_ENABLE_LLVM
#include <gtest/gtest.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
#include <cstring>
using namespace torch::indexing;
namespace te = torch::jit::tensorexpr;
// Splits the first loop returned by getLoopStmtsFor(target) by `width` using
// splitWithTail, then vectorizes the resulting inner loop.
// The tail loop produced by the split is not otherwise touched here.
// Failure to vectorize is reported with ASSERT_TRUE. That only returns from
// this helper, so callers that must stop should wrap the call in
// ASSERT_NO_FATAL_FAILURE.
static void vectorize(te::LoopNest* ln, te::Tensor target, int width) {
  te::ForPtr inner;
  te::ForPtr remainder;
  auto target_loops = ln->getLoopStmtsFor(target);
  ln->splitWithTail(target_loops.front(), width, &inner, &remainder);
  ASSERT_TRUE(te::LoopNest::vectorize(inner));
}
// Builds a one-line, human-readable summary of how two tensors differ, for
// use in test failure messages. Both tensors are flattened and compared
// element-wise.
//
// Positions where BOTH inputs are NaN are treated as equal, which matches the
// equal_nan=true comparison used by callers in this file. Positions where
// exactly one input is NaN are counted as unequal.
//
// Returns "Found <count> unequal element(s). The greatest difference was <d>
// at index <i>", or just "Found 0 unequal element(s)." for empty inputs
// (argmax is not defined on an empty tensor).
std::string diffs(const at::Tensor& a, const at::Tensor& b) {
  auto lhs = a.flatten();
  auto rhs = b.flatten();
  // |a - b| is NaN wherever either side is NaN. Zero it where both sides are
  // NaN so matching NaNs do not register as a difference. One-sided NaNs stay
  // NaN and are still reported.
  auto both_nan = torch::isnan(lhs) & torch::isnan(rhs);
  auto diff = torch::abs(lhs - rhs).masked_fill(both_nan, 0);
  if (diff.numel() == 0) {
    return "Found 0 unequal element(s).";
  }
  // ne(0) is true for NaN (unlike `> 0`), so one-sided NaNs are counted.
  auto count_diffs = diff.ne(0).sum().item<int64_t>();
  auto greatest_diff_index = diff.argmax().item<int64_t>();
  // Print plain scalars rather than 0-dim tensors so the message stays on one
  // line without tensor type banners.
  std::stringstream ss;
  ss << "Found " << count_diffs << " unequal element(s). "
     << "The greatest difference was "
     << diff[greatest_diff_index].item<double>() << " at index "
     << greatest_diff_index;
  return ss.str();
}
// Checks the vectorized log_vml kernel (compiled with LLVMCodeGen) against
// ATen's at::log over:
//  - every float in [1, 2), plus that range scaled by selected powers of two
//    between 2^-126 and 2^126, requiring |test - ref| <= eps * |ref|;
//  - +inf and NaN inputs;
//  - every value in [0, FLT_MIN) in denormal_min steps (including 0), using
//    allclose's default tolerances.
TEST(Approx, log_vml) {
  te::VarHandle N("N", te::kInt);
  te::BufHandle A("A", {N}, te::kFloat);
  te::Tensor B = te::Compute(
      "B", {N}, [&](const te::VarHandle& i) { return log_vml(A.load(i)); });
  te::LoopNest ln({B});
  ln.prepareForCodegen();
  // vectorize() reports failure with ASSERT_TRUE, which only returns from the
  // helper. Propagate that fatal failure so we stop here instead of going on
  // to compile a loop nest that was not vectorized.
  ASSERT_NO_FATAL_FAILURE(vectorize(&ln, B, 8));
  te::StmtPtr s = ln.root_stmt();
  s = te::IRSimplifier::simplify(s);
  te::LLVMCodeGen cg(s, {A, B, N});
  auto eps = std::numeric_limits<float>::epsilon();
  // Runs the generated kernel on A_t and compares the result with at::log.
  // ASSERT_TRUE inside the lambda only returns from the lambda, so each call
  // is checked independently and later calls still run after a failure.
  auto test = [&](const at::Tensor& A_t) {
    at::Tensor B_ref = at::log(A_t);
    at::Tensor B_t = at::empty_like(A_t);
    auto ap = A_t.data_ptr<float>();
    auto bp = B_t.data_ptr<float>();
    cg.call({ap, bp, A_t.numel()});
    // Results must agree to within one relative epsilon
    // (|test - ref| <= eps * |ref|, atol = 0), with NaNs in matching positions
    // considered equal. This is tight, but not a bit-exact comparison.
    ASSERT_TRUE(torch::allclose(
        B_t, B_ref, /*rtol=*/eps, /*atol=*/0.0f, /*equal_nan=*/true))
        << "Input[:8]\n"
        << A_t.index({Slice(0, 8)}) << "\n"
        << "Test[:8]\n"
        << B_t.index({Slice(0, 8)}) << "\n"
        << "Ref[:8]\n"
        << B_ref.index({Slice(0, 8)}) << "\n"
        << diffs(B_t, B_ref);
  };
  // Generate every single-precision FP value in [1.0, 2.0).
  at::Tensor A_t = torch::arange(1.0f, 2.0f, eps);
  ASSERT_EQ(A_t.numel(), 1 << 23);
  test(A_t);
  test(A_t * 2.0f);
  test(A_t * 0.5f);
  test(A_t * 4.0f);
  test(A_t * 0.25f);
  test(A_t * powf(2.0f, 16));
  test(A_t * powf(2.0f, -16));
  test(A_t * powf(2.0f, 126));
  test(A_t * powf(2.0f, -126));
  test(torch::full({32}, INFINITY));
  test(torch::full({32}, NAN));
  auto min = std::numeric_limits<float>::min();
  auto denorm_min = std::numeric_limits<float>::denorm_min();
  // Denormals are checked with allclose's default tolerances rather than
  // rtol = eps, because sleef isn't bit-precise on them either.
  A_t = torch::arange(0.0f, min, denorm_min);
  ASSERT_EQ(A_t.numel(), 1 << 23);
  auto B_ref = at::log(A_t);
  auto B_t = at::empty_like(B_ref);
  cg.call({A_t.data_ptr<float>(), B_t.data_ptr<float>(), A_t.numel()});
  ASSERT_TRUE(torch::allclose(B_t, B_ref)) << diffs(B_t, B_ref);
}
#endif // TORCH_ENABLE_LLVM
|