#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/jit_log.h>
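
// Layout propagation for LLGA (oneDNN Graph) fusion groups: outputs of a
// fusion group default to the strided layout, and an output is marked opaque
// when all of its consumers are also LLGA fusion groups, so the tensor can
// stay in oneDNN's internal layout between adjacent groups.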
namespace torch::jit::fuser::onednn {

static void LayoutPropagation(Node* n) {
  if (!LlgaGraphHelper::isLlgaSubgraph(n))
    return;

  // Initialize attr::output_layouts to strided if it is not defined yet
  if (!n->hasAttribute(attr::output_layouts)) {
    const auto num_output = n->outputs().size();
    GRAPH_DEBUG("Initial output_layouts of size ", num_output);
    std::vector<int64_t> layouts(num_output, STRIDED_LAYOUT);
    n->is_(attr::output_layouts, layouts);
  }

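  // A value produced by another LLGA subgraph may stay in oneDNN's opaque
  // layout only if every one of its uses is also an LLGA subgraph; any
  // non-LLGA consumer requires a strided tensor.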
  for (auto input : n->inputs()) {
    auto prev = input->node();
    auto offset = input->offset();
    if (LlgaGraphHelper::isLlgaSubgraph(prev)) {
      bool useOpaqueLayout = true;
      for (auto& use : input->uses()) {
        if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) {
          useOpaqueLayout = false;
          break;
        }
      }
      if (useOpaqueLayout) {
        LlgaNodeWrapper(prev).setOpaqueLayout(offset);
      }
    }
  }
}

static void LayoutPropagation(at::ArrayRef<Block*> blocks) {
  for (Block* block : blocks)
    for (Node* node : block->nodes())
      LayoutPropagation(node);
}

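// Entry point of the pass: run layout propagation over the nodes of the
// graph's top-level block.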
void PropagateLayout(const std::shared_ptr<Graph>& graph) {
  LayoutPropagation(graph->block());
}

} // namespace torch::jit::fuser::onednn