// Declares IRPrinter (an IRVisitor that writes to a std::ostream) together
// with the stream-insertion, print() and std::to_string() declarations for
// tensor-expression IR types.
#pragma once
#include <iostream>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/unique_name_manager.h>
namespace torch {
namespace jit {
namespace tensorexpr {
class Tensor;
// IRVisitor subclass that is bound to an output stream at construction and
// overrides visit() for the expression and statement node types listed below.
// All member functions that are only declared here are defined out of line;
// presumably each visit() writes a textual form of its node to os() -- confirm
// against the implementation file.
class TORCH_API IRPrinter : public IRVisitor {
public:
// Wraps `os` in the internal PrinterStream. The wrapper is constructed from
// os.rdbuf(), so anything written through this printer goes to the same
// stream buffer as `os`.
explicit IRPrinter(std::ostream& os) : printer_os_(this, os) {}
// Entry points for printing a whole expression or statement (defined
// elsewhere).
void print(ExprHandle);
void print(Expr&);
void print(Stmt&);
// Binary operator nodes.
void visit(AddPtr v) override;
void visit(SubPtr v) override;
void visit(MulPtr v) override;
void visit(DivPtr v) override;
void visit(ModPtr v) override;
void visit(MaxPtr v) override;
void visit(MinPtr v) override;
void visit(AndPtr v) override;
void visit(OrPtr v) override;
void visit(XorPtr v) override;
void visit(LshiftPtr v) override;
void visit(RshiftPtr v) override;
void visit(CompareSelectPtr v) override;
// Declares one `visit(<Name>ImmPtr)` override per entry produced by
// AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...). The helper macro is
// undefined again immediately so it does not leak out of this header.
#define IMM_PRINT_VISIT(Type, Name) void visit(Name##ImmPtr v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
// Remaining expression node types.
void visit(CastPtr v) override;
void visit(BitCastPtr v) override;
void visit(VarPtr v) override;
void visit(BufPtr v) override;
void visit(RampPtr v) override;
void visit(LoadPtr v) override;
void visit(BroadcastPtr v) override;
void visit(IfThenElsePtr v) override;
void visit(IntrinsicsPtr v) override;
void visit(TermPtr v) override;
void visit(PolynomialPtr v) override;
void visit(RoundOffPtr v) override;
void visit(MaxTermPtr v) override;
void visit(MinTermPtr v) override;
void visit(ReduceOpPtr v) override;
// Statement node types.
void visit(AtomicAddPtr v) override;
void visit(SyncThreadsPtr v) override;
void visit(ExternalCallPtr v) override;
void visit(ExternalCallWithAllocPtr v) override;
void visit(StorePtr v) override;
void visit(ForPtr v) override;
void visit(CondPtr v) override;
void visit(BlockPtr v) override;
void visit(AllocatePtr v) override;
void visit(FreePtr v) override;
void visit(FreeExtPtr v) override;
void visit(PlacementAllocatePtr v) override;
void visit(LetPtr v) override;
// A child class may have a different rule for generating the dtype
// string, e.g. CUDA needs int64_t to be generated as long long.
virtual std::string dtypeToCppString(const Dtype& dtype);
// Returns the printer's own stream (the PrinterStream member) as a plain
// std::ostream reference.
std::ostream& os() {
return printer_os_;
}
// std::ostream that shares the stream buffer of the stream it was created
// from and additionally remembers which IRPrinter owns it.
// NOTE(review): presumably this lets stream-insertion code recover the
// active printer from a std::ostream -- confirm in the implementation.
class PrinterStream : public std::ostream {
public:
// `printer` is stored as-is (not owned); only the rdbuf() of `os` is used.
PrinterStream(IRPrinter* printer, std::ostream& os)
: std::ostream(os.rdbuf()), printer_(printer) {}
// Returns the IRPrinter pointer supplied at construction.
IRPrinter* printer() {
return printer_;
}
private:
IRPrinter* printer_ = nullptr;
};
protected:
// Returns a string for the given compare-select operation (defined
// elsewhere).
std::string to_string(CompareSelectOperation op);
// Gives subclasses access to the UniqueNameManager owned by this printer.
UniqueNameManager* name_manager() {
return &name_manager_;
}
// Defined elsewhere; presumably writes leading whitespace according to
// indent_ -- confirm in the implementation.
void emitIndent();
// Indentation state, starting at 0; exposed to subclasses.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
int indent_ = 0;
private:
// Stream returned by os(); initialised from the constructor's `os`.
PrinterStream printer_os_;
// Backing object for name_manager().
UniqueNameManager name_manager_;
};
// Stream-insertion operators for tensor-expression IR objects. Only the
// declarations live here; the definitions are in the implementation file.
TORCH_API std::ostream& operator<<(std::ostream& stream, const Expr&);
TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&);
TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&);
TORCH_API std::ostream& operator<<(std::ostream& stream, const Tensor&);
// Free-function print helpers for an expression, a statement, or a Tensor.
// NOTE(review): the output destination is not visible from this header --
// confirm in the implementation before relying on it.
TORCH_API void print(ExprPtr expr);
TORCH_API void print(StmtPtr stmt);
TORCH_API void print(const Tensor& t);
} // namespace tensorexpr
} // namespace jit
} // namespace torch
// to_string overloads for tensor-expression IR types, declared directly in
// namespace std so that callers can write std::to_string(expr).
// NOTE(review): the C++ standard only permits a limited set of additions to
// namespace std; new function overloads such as these are formally undefined
// behaviour. They are kept as-is because existing callers depend on the
// std::to_string spelling.
namespace std {
// The using-declarations below also make these names visible inside
// namespace std for every translation unit that includes this header.
using torch::jit::tensorexpr::Expr;
using torch::jit::tensorexpr::ExprPtr;
using torch::jit::tensorexpr::Stmt;
using torch::jit::tensorexpr::StmtPtr;
using torch::jit::tensorexpr::Tensor;
// Each overload returns a std::string for its argument (defined elsewhere).
TORCH_API std::string to_string(ExprPtr expr);
TORCH_API std::string to_string(StmtPtr stmt);
TORCH_API std::string to_string(const Tensor& t);
} // namespace std
// End of IR printer declarations.