1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
|
#pragma once
#include <memory>
#include <vector>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
namespace jit {
// Header-scope using-directive: every file that includes this header also gets
// the torch::jit::tensorexpr names visible inside torch::jit. The macros below
// use names such as `to`, `NodePtr`, `Var`, `Cast`, `Intrinsics` and `kRand`
// unqualified; presumably they resolve through this directive — confirm before
// narrowing or removing it.
using namespace torch::jit::tensorexpr;
// IS_NODE(T, node): asserts that `to<T>(node)` is not null.
//   T    - node type passed as the template argument of `to<>`.
//   node - expression handed to `to<T>(...)`; it is evaluated once.
// The expansion is a braced block, so the temporary `node_` does not leak into
// the caller's scope. `to<>` and ASSERT_NE are not defined in this header; if
// ASSERT_NE is the gtest fatal assertion, a failure returns from the enclosing
// function, so the macro should only be used in void test bodies — confirm.
#define IS_NODE(T, node) \
{ \
auto node_ = to<T>(node); \
ASSERT_NE(nullptr, node_); \
}
// IS_NODE_WITH_NAME(T, node, name): declares a local variable `name` holding
// `to<T>(node)` and asserts that it is not null.
//   T    - node type passed as the template argument of `to<>`.
//   node - expression handed to `to<T>(...)`.
//   name - identifier of the variable to declare.
// There are deliberately no surrounding braces: `name` stays visible in the
// caller's scope after the macro so later checks can use it. As a consequence
// the expansion is two separate statements (do not use it as the body of an
// unbraced if/for), and `name` must not already exist in that scope.
#define IS_NODE_WITH_NAME(T, node, name) \
auto name = to<T>(node); \
ASSERT_NE(nullptr, name);
// IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type): checks that `node` is a
// Cast to the given scalar type wrapping a T, and exposes the wrapped node.
//   T    - expected type of the node under the cast.
//   node - expression handed to `to<Cast>(...)`.
//   name - identifier of the `NodePtr<T>` variable declared in the caller's
//          scope; it is initialised to nullptr and assigned inside the block.
//   Type - substituted after `ScalarType::`, so pass the bare enumerator name.
// Order of checks: `to<Cast>(node)` is non-null, then its
// `dtype().scalar_type()` equals `ScalarType::Type`, then `name` is set to
// `to<T>(node_->src_value())` and asserted non-null. The final assertion sits
// outside the braces, so the expansion is several statements and cannot be the
// body of an unbraced if/for.
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type) \
NodePtr<T> name = nullptr; \
{ \
auto node_ = to<Cast>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \
name = to<T>(node_->src_value()); \
} \
ASSERT_NE(nullptr, name);
// IS_IMM_WITH_VAL(T, node, val): asserts that `node` is an immediate of kind
// T whose value equals `val`.
//   T    - prefix that is token-pasted with `Imm` (`T##Imm`); for example
//          `Int` makes the check use `to<IntImm>(node)`.
//   node - expression handed to `to<...>(...)`.
//   val  - expected value, compared against `node_->value()` with ASSERT_EQ.
// The braced block keeps `node_` out of the caller's scope.
#define IS_IMM_WITH_VAL(T, node, val) \
{ \
auto node_ = to<T##Imm>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->value(), val); \
}
// IS_VAR_WITH_NAME(node, name): asserts that `to<Var>(node)` is not null and
// that its `name_hint()` equals `name`.
//   node - expression handed to `to<Var>(...)`.
//   name - expected name hint, compared with ASSERT_EQ (here `name` is a value,
//          not an identifier to declare, unlike the *_WITH_NAME macros above).
// The braced block keeps `node_` out of the caller's scope.
#define IS_VAR_WITH_NAME(node, name) \
{ \
auto node_ = to<Var>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->name_hint(), name); \
}
// IS_BINOP_W_VARS(T, node, name, v1, v2): checks that `node` is a T whose two
// operands are variables with the given name hints.
//   T      - expected node type; it must provide `lhs()` and `rhs()`.
//   node   - expression handed to `to<T>(...)`.
//   name   - identifier of the `NodePtr<T>` variable declared in the caller's
//            scope, so it remains usable after the macro.
//   v1, v2 - expected name hints of `name->lhs()` and `name->rhs()`, each
//            verified through IS_VAR_WITH_NAME.
// The declaration of `name` is outside the braces, so the expansion is more
// than one statement and cannot be the body of an unbraced if/for.
#define IS_BINOP_W_VARS(T, node, name, v1, v2) \
NodePtr<T> name = nullptr; \
{ \
name = to<T>(node); \
ASSERT_NE(nullptr, name); \
IS_VAR_WITH_NAME(name->lhs(), v1); \
IS_VAR_WITH_NAME(name->rhs(), v2); \
}
// IS_BINOP_W_CONST(T, node, name, v, c): checks that `node` is a T whose left
// operand is a variable and whose right operand is an integer immediate.
//   T    - expected node type; it must provide `lhs()` and `rhs()`.
//   node - expression handed to `to<T>(...)`.
//   name - identifier of the `NodePtr<T>` variable declared in the caller's
//          scope, so it remains usable after the macro.
//   v    - expected name hint of `name->lhs()` (via IS_VAR_WITH_NAME).
//   c    - expected value of `name->rhs()`.
// The right operand is always checked with `Int` (i.e. against `IntImm`);
// other immediate kinds are not supported by this macro.
#define IS_BINOP_W_CONST(T, node, name, v, c) \
NodePtr<T> name = nullptr; \
{ \
name = to<T>(node); \
ASSERT_NE(nullptr, name); \
IS_VAR_WITH_NAME(name->lhs(), v); \
IS_IMM_WITH_VAL(Int, name->rhs(), c); \
}
// IS_RAND(node): asserts that `to<Intrinsics>(node)` is not null and that its
// `op_type()` equals `kRand`.
//   node - expression handed to `to<Intrinsics>(...)`.
// The braced block keeps `node_` out of the caller's scope.
#define IS_RAND(node) \
{ \
auto node_ = to<Intrinsics>(node); \
ASSERT_NE(nullptr, node_); \
ASSERT_EQ(node_->op_type(), kRand); \
}
// IR-checking helpers; only declared here, the definitions live elsewhere.
// NOTE(review): given the file_check.h include above, these presumably print
// the statement/expression and match the text against `pattern` using
// FileCheck — confirm against the implementation.

// Checks the statement `s` against `pattern`.
void checkIR(StmtPtr s, const std::string& pattern);
// Checks the expression `e` (given as an ExprPtr) against `pattern`.
void checkExprIR(ExprPtr e, const std::string& pattern);
// Overload of checkExprIR taking an ExprHandle instead of an ExprPtr.
void checkExprIR(const ExprHandle& e, const std::string& pattern);
} // namespace jit
} // namespace torch
|