1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
|
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {
// Prepare binary ops for LLGA
//
// The pass does the following:
//
// - Convert scalar input of aten::add and aten::mul into Float tensor with
// dimension [1]
//
// - Decompose fused add into aten::mul + aten::add when alpha != 1.0
//
// - Eliminate identity add/mul, i.e., tensor + 0, tensor * 1
//
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph);
} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch
|