//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ViewOpGraph.h"
#include "PassDetail.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
using namespace mlir;
/// Return the size limit above which large attributes are elided.
static int64_t getLargeAttributeSizeLimit() {
// Use the default from the printer flags if possible.
if (Optional<int64_t> limit = OpPrintingFlags().getLargeElementsAttrLimit())
return *limit;
return 16;
}
namespace llvm {
// Specialize GraphTraits to treat a Block as a graph whose nodes are its
// Operations and whose edges run from each operation to the users of its
// results.
template <> struct GraphTraits<Block *> {
using GraphType = Block *;
using NodeRef = Operation *;
using ChildIteratorType = Operation::user_iterator;
static ChildIteratorType child_begin(NodeRef n) { return n->user_begin(); }
static ChildIteratorType child_end(NodeRef n) { return n->user_end(); }
// Operation's destructor is private, so nodes are exposed as Operation* by
// wrapping Block::iterator (which yields Operation&) in a mapped_iterator.
static Operation *AddressOf(Operation &op) { return &op; }
using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>;
static nodes_iterator nodes_begin(Block *b) {
return nodes_iterator(b->begin(), &AddressOf);
}
static nodes_iterator nodes_end(Block *b) {
return nodes_iterator(b->end(), &AddressOf);
}
};
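// Example (a sketch, not part of the original file): with the specialization
// above, generic LLVM graph code can walk a block. Assuming `block` is a valid
// mlir::Block, this prints one line per (operation, user) pair. A user appears
// once per use, and users may live outside `block`.
//
//   using GT = llvm::GraphTraits<mlir::Block *>;
//   for (auto it = GT::nodes_begin(&block), e = GT::nodes_end(&block);
//        it != e; ++it) {
//     mlir::Operation *op = *it;
//     for (auto u = GT::child_begin(op), ue = GT::child_end(op); u != ue; ++u)
//       llvm::errs() << op->getName() << " -> " << (*u)->getName() << "\n";
//   }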
// Specialize DOTGraphTraits to produce more readable output.
template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits {
using DefaultDOTGraphTraits::DefaultDOTGraphTraits;
static std::string getNodeLabel(Operation *op, Block *);
};
std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) {
// Reuse the print output for the node labels.
std::string ostr;
raw_string_ostream os(ostr);
os << op->getName() << "\n";
if (!op->getLoc().isa<UnknownLoc>()) {
os << op->getLoc() << "\n";
}
// Print the result types.
llvm::interleaveComma(op->getResultTypes(), os);
os << "\n";
// Size threshold above which container attributes are elided.
int64_t largeAttrLimit = getLargeAttributeSizeLimit();
for (auto attr : op->getAttrs()) {
os << '\n' << attr.first << ": ";
// Always emit splat attributes.
if (attr.second.isa<SplatElementsAttr>()) {
attr.second.print(os);
continue;
}
// Elide "big" elements attributes.
auto elements = attr.second.dyn_cast<ElementsAttr>();
if (elements && elements.getNumElements() > largeAttrLimit) {
os << std::string(elements.getType().getRank(), '[') << "..."
<< std::string(elements.getType().getRank(), ']') << " : "
<< elements.getType();
continue;
}
auto array = attr.second.dyn_cast<ArrayAttr>();
if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
os << "[...]";
continue;
}
// Print all other attributes.
attr.second.print(os);
}
return os.str();
}
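// Illustration (derived from the code above; the attribute name and type are
// hypothetical): an attribute `value` holding a non-splat, rank-2 elements
// attribute with more elements than the limit is rendered on its label line as
//
//   value: [[...]] : tensor<32x32xf32>
//
// Under the same rule, an ArrayAttr longer than the limit is rendered as
// `[...]`, and splat elements attributes are always printed in full.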
} // end namespace llvm
namespace {
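// PrintOpPass is a simple pass that writes one graph per block of each
// top-level operation (e.g. each function) in the module.
// Note: this is a module pass only to avoid interleaved output on the shared
// ostream, which could happen if it ran multi-threaded over functions.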
class PrintOpPass : public ViewOpGraphPassBase<PrintOpPass> {
public:
PrintOpPass(raw_ostream &os, bool shortNames, const Twine &title) : os(os) {
this->shortNames = shortNames;
this->title = title.str();
}
std::string getOpName(Operation &op) {
auto symbolAttr =
op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
if (symbolAttr)
return std::string(symbolAttr.getValue());
++unnamedOpCtr;
return (op.getName().getStringRef() + llvm::utostr(unnamedOpCtr)).str();
}
// Print all the ops in a module.
void processModule(ModuleOp module) {
for (Operation &op : module) {
// Modules may be nested; recurse into nested modules.
if (auto nestedModule = dyn_cast<ModuleOp>(op)) {
processModule(nestedModule);
continue;
}
auto opName = getOpName(op);
for (Region &region : op.getRegions()) {
for (auto indexedBlock : llvm::enumerate(region)) {
// Append the block index as a suffix when the region has more than one
// block.
auto blockName = llvm::hasSingleElement(region)
? ""
: ("__" + llvm::utostr(indexedBlock.index()));
llvm::WriteGraph(os, &indexedBlock.value(), shortNames,
Twine(title) + opName + blockName);
}
}
}
}
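// Naming illustration (derived from the code above; op names hypothetical):
// with title "g_", a top-level op whose symbol name is "foo" and whose region
// has two blocks produces graphs titled "g_foo__0" and "g_foo__1". If the
// first op without a symbol name is a "test.op" with one single-block region,
// its graph is titled "g_test.op1", because unnamedOpCtr is incremented before
// it is appended. Ops without regions emit no graph but still advance the
// counter when unnamed.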
void runOnOperation() override { processModule(getOperation()); }
private:
raw_ostream &os;
int unnamedOpCtr = 0;
};
} // namespace
void mlir::viewGraph(Block &block, const Twine &name, bool shortNames,
const Twine &title, llvm::GraphProgram::Name program) {
llvm::ViewGraph(&block, name, shortNames, title, program);
}
raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames,
const Twine &title) {
return llvm::WriteGraph(os, &block, shortNames, title);
}
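// Usage sketch (not part of the original file): assuming `block` is an
// mlir::Block owned by some operation, the two helpers above can be called as
// follows. writeGraph only streams DOT text. viewGraph writes the graph to a
// temporary file and asks LLVM to display it, which typically requires
// Graphviz (or a compatible viewer) to be installed.
//
//   mlir::writeGraph(llvm::errs(), block, /*shortNames=*/false,
//                    /*title=*/"entry_block");
//   mlir::viewGraph(block, /*name=*/"entry_block", /*shortNames=*/false,
//                   /*title=*/"entry_block", llvm::GraphProgram::DOT);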
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames,
const Twine &title) {
return std::make_unique<PrintOpPass>(os, shortNames, title);
}
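// Usage sketch (not part of the original file): assumes an MLIRContext
// `context`, a loaded mlir::ModuleOp `module`, an include of
// "mlir/Pass/PassManager.h", and the ModuleOp-based PassManager::run API of
// this code's era. It dumps one DOT graph per block of every top-level op.
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createPrintOpGraphPass(llvm::errs(), /*shortNames=*/false,
//                                           /*title=*/""));
//   if (mlir::failed(pm.run(module)))
//     llvm::errs() << "failed to print op graphs\n";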