1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234
|
//===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains interfaces and analyses for defining a nested callgraph.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// CallGraphNode
//===----------------------------------------------------------------------===//
/// Reports whether this node is an external node, i.e. one that was created
/// without a callable region.
bool CallGraphNode::isExternal() const {
  return callableRegion == nullptr;
}
/// Accessor for the callable region backing this node.
///
/// Must not be used on an external node: those carry no region, and the
/// assertion below fires in builds with assertions enabled.
Region *CallGraphNode::getCallableRegion() const {
  assert(!isExternal() && "the external node has no callable region");
  return callableRegion;
}
/// Adds a reference (abstract) edge from this node to `node`.
///
/// Abstract edges may only originate from an external node; this is checked
/// by assertion.
void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
  assert(isExternal() && "abstract edges are only valid on external nodes");
  const Edge::Kind edgeKind = Edge::Kind::Abstract;
  addEdge(node, edgeKind);
}
/// Records an outgoing call edge from this node to `node`.
void CallGraphNode::addCallEdge(CallGraphNode *node) {
  const Edge::Kind edgeKind = Edge::Kind::Call;
  addEdge(node, edgeKind);
}
/// Records `child` as a child of this node by adding a child edge to it.
void CallGraphNode::addChildEdge(CallGraphNode *child) {
  const Edge::Kind edgeKind = Edge::Kind::Child;
  addEdge(child, edgeKind);
}
/// Returns true if this node has any child edges.
bool CallGraphNode::hasChildren() const {
return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
}
/// Inserts an edge of the given `kind` that targets `node` into this node's
/// edge set.
void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
  const Edge newEdge{node, kind};
  edges.insert(newEdge);
}
//===----------------------------------------------------------------------===//
// CallGraph
//===----------------------------------------------------------------------===//
/// Recursively compute the callgraph edges for the given operation. Computed
/// edges are placed into the given callgraph object.
///
/// \param op           The operation to visit.
/// \param cg           The callgraph that receives the nodes and edges.
/// \param symbolTable  Symbol table collection forwarded to callee resolution.
/// \param parentNode   The node of the innermost enclosing callable region, or
///                     null if `op` is not nested within a callable.
/// \param resolveCalls If true, call operations contribute call edges to
///                     `parentNode`; if false, call operations are skipped.
///
/// Callable regions are registered through `CallGraph::getOrAddNode` on every
/// invocation, independent of `resolveCalls`.
static void computeCallGraph(Operation *op, CallGraph &cg,
                             SymbolTableCollection &symbolTable,
                             CallGraphNode *parentNode, bool resolveCalls) {
  if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
    // If there is no parent node, we ignore this operation. Even if this
    // operation was a call, there would be no callgraph node to attribute it
    // to.
    if (resolveCalls && parentNode)
      parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
    // Call operations are treated as leaves: their regions are not visited.
    return;
  }

  // Compute the callgraph nodes and edges for each of the nested operations.
  if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
    if (auto *callableRegion = callable.getCallableRegion())
      parentNode = cg.getOrAddNode(callableRegion, parentNode);
    else
      // A callable without a callable region gets no node and nothing nested
      // in it is visited.
      return;
  }

  for (Region &region : op->getRegions())
    for (Operation &nested : region.getOps())
      computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
}
/// Builds the callgraph for everything nested under `op`.
///
/// Both special nodes (external caller, unknown callee) are created without a
/// callable region.
CallGraph::CallGraph(Operation *op)
    : externalCallerNode(/*callableRegion=*/nullptr),
      unknownCalleeNode(/*callableRegion=*/nullptr) {
  SymbolTableCollection symbolTable;

  // One traversal of `op`, starting without a parent node.
  auto traverse = [&](bool resolveCalls) {
    computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
                     resolveCalls);
  };

  // The graph is walked twice. The first walk only registers the callable
  // regions (including nested ones) as nodes; the second walk resolves the
  // calls, which looks callees up among the nodes registered by the first.
  traverse(/*resolveCalls=*/false);
  traverse(/*resolveCalls=*/true);
}
/// Returns the callgraph node registered for `region`, creating and wiring up
/// a new one if none exists yet.
///
/// `region` must be non-null and its parent operation must implement
/// `CallableOpInterface`. `parentNode` is only consulted when a new node is
/// created; an already registered node is returned unchanged.
CallGraphNode *CallGraph::getOrAddNode(Region *region,
                                       CallGraphNode *parentNode) {
  assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
         "expected parent operation to be callable");

  // Fast path: the region already has a node.
  std::unique_ptr<CallGraphNode> &slot = nodes[region];
  if (slot)
    return slot.get();

  slot.reset(new CallGraphNode(region));
  CallGraphNode *created = slot.get();

  // Add this node to the given parent node if necessary.
  if (parentNode) {
    parentNode->addChildEdge(created);
    return created;
  }

  // Otherwise, connect all callable nodes to the external node, this allows
  // for conservatively including all callable nodes within the graph.
  // FIXME This isn't correct, this is only necessary for callable nodes
  // that *could* be called from external sources. This requires extending
  // the interface for callables to check if they may be referenced
  // externally.
  externalCallerNode.addAbstractEdge(created);
  return created;
}
/// Finds the callgraph node registered for `region`.
///
/// Returns nullptr when no node has been registered for that region.
CallGraphNode *CallGraph::lookupNode(Region *region) const {
  const auto found = nodes.find(region);
  if (found == nodes.end())
    return nullptr;
  return found->second.get();
}
/// Maps the callee of `call` onto a node of this callgraph.
///
/// The unknown callee node is returned whenever the callee does not resolve to
/// an operation implementing `CallableOpInterface`, or when no node is
/// registered for its callable region.
CallGraphNode *
CallGraph::resolveCallable(CallOpInterface call,
                           SymbolTableCollection &symbolTable) const {
  Operation *callee = call.resolveCallable(&symbolTable);

  auto calleeOp = dyn_cast_or_null<CallableOpInterface>(callee);
  if (!calleeOp)
    return getUnknownCalleeNode();

  if (CallGraphNode *resolved = lookupNode(calleeOp.getCallableRegion()))
    return resolved;
  return getUnknownCalleeNode();
}
/// Erase the given node from the callgraph, together with every node
/// reachable from it through child edges.
///
/// Once this returns, `node` and its (transitive) children have been removed
/// from `nodes`, and neither the remaining nodes nor the external caller node
/// hold an edge that targets them. `node` must be a non-external node.
void CallGraph::eraseNode(CallGraphNode *node) {
  // Erase any children of this node first. The child targets are snapshotted
  // up front: erasing a child strips the matching child edge from this node's
  // own edge list (this node is itself an entry of `nodes`), so the list must
  // not be iterated while children are being erased.
  llvm::SmallVector<CallGraphNode *, 4> children;
  for (const CallGraphNode::Edge &edge : *node)
    if (edge.isChild())
      children.push_back(edge.getTarget());
  for (CallGraphNode *child : children)
    eraseNode(child);

  // Erase any edges to this node from any other nodes.
  auto targetsNode = [node](const CallGraphNode::Edge &edge) {
    return edge.getTarget() == node;
  };
  for (auto &it : nodes)
    it.second->edges.remove_if(targetsNode);

  // The external caller node is a separate member rather than an entry of
  // `nodes`, yet it receives an abstract edge to every parentless node in
  // getOrAddNode. Drop that edge as well so it does not dangle.
  externalCallerNode.edges.remove_if(targetsNode);

  nodes.erase(node->getCallableRegion());
}
//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//
/// Writes the human readable form of the graph, as produced by print(), to
/// the LLVM error stream.
void CallGraph::dump() const {
  print(llvm::errs());
}
/// Prints the graph to `os` as a block of `//`-prefixed comment lines.
///
/// The output lists every node stored in `nodes` followed by its outgoing
/// edges, and then the strongly connected components of the graph.
void CallGraph::print(raw_ostream &os) const {
  os << "// ---- CallGraph ----\n";

  // Writes a printable name for `node`: a fixed tag for the two special
  // nodes, otherwise the parent operation name, the region number and, when
  // present, the parent operation's attribute dictionary.
  auto emitNodeName = [&](const CallGraphNode *node) {
    if (node == getExternalCallerNode()) {
      os << "<External-Caller-Node>";
      return;
    }
    if (node == getUnknownCalleeNode()) {
      os << "<Unknown-Callee-Node>";
      return;
    }

    auto *callableRegion = node->getCallableRegion();
    auto *parentOp = callableRegion->getParentOp();
    os << "'" << parentOp->getName() << "' - Region #"
       << callableRegion->getRegionNumber();

    auto attrs = parentOp->getAttrDictionary();
    if (!attrs.empty())
      os << " : " << attrs;
  };

  // Node section: one header line per node, then one line per edge.
  for (auto &entry : nodes) {
    const CallGraphNode *current = entry.second.get();

    os << "// - Node : ";
    emitNodeName(current);
    os << "\n";

    for (auto &edge : *current) {
      // Only call and child edges carry a label; any other edge kind is
      // printed with an empty one.
      const char *label = "";
      if (edge.isCall())
        label = "Call";
      else if (edge.isChild())
        label = "Child";

      os << "// -- " << label << "-Edge : ";
      emitNodeName(edge.getTarget());
      os << "\n";
    }
    os << "//\n";
  }

  // SCC section: one group per strongly connected component.
  os << "// -- SCCs --\n";
  for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
    os << "// - SCC : \n";
    for (auto &member : scc) {
      os << "// -- Node :";
      emitNodeName(member);
      os << "\n";
    }
    os << "\n";
  }
  os << "// -------------------\n";
}
|