1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615
|
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
namespace torch::jit::fuser::onednn {
using opkind = dnnl::graph::op::kind;
// Makes the bias input of a convolution node explicit: when the existing
// "bias" value is not guaranteed to be non-None, input 2 of the node is
// rewired to a freshly created None constant inserted right before the node.
static void fixConvOptionalBias(Node* node) {
  if (node->namedInput("bias")->mustNotBeNone()) {
    // The bias is known to be present; nothing to rewrite.
    return;
  }
  // Replace non-existent optional bias with const None
  auto* noneNode = node->owningGraph()->createNone();
  auto* noneValue = noneNode->insertBefore(node)->output();
  node->replaceInput(2, noneValue);
}
// Returns the number of dimensions recorded in the tensor type of `v`.
// Yields std::nullopt when `v` is not a tensor; otherwise forwards whatever
// the type's sizes().size() reports.
static std::optional<size_t> getDimensions(Value* v) {
  if (!v->type()->isSubtypeOf(TensorType::get())) {
    return std::nullopt;
  }
  return v->type()->cast<TensorType>()->sizes().size();
}
// PyTorch ops that can't otherwise be mapped to oneDNN Graph ops are mapped
// as Wildcards instead. This keeps the integration code with PyTorch simple:
// every op is passed to the oneDNN Graph library in the add_op call, so there
// is no need to check beforehand whether the op is supported by oneDNN Graph.
// oneDNN Graph ops separated by wildcards don't end up in the same partition.
static Operator makeWildcardOp(Node* node) {
  auto wildcard = Operator(node, opkind::Wildcard);
  // A wildcard op carries only topology info: its inputs and outputs.
  // NOTE(review): setInput(0, idx) passes offset 0 on every iteration in
  // addition to idx -- confirm against Operator::setInput that this is
  // intended.
  for (size_t idx = 0; idx < node->inputs().size(); ++idx) {
    wildcard.setInput(0, idx);
  }
  for (size_t idx = 0; idx < node->outputs().size(); ++idx) {
    wildcard.setOutput(idx);
  }
  return wildcard;
}
// If we don't meet a certain condition to map a PyTorch op to a oneDNN Graph
// op, then we create a wildcard op corresponding to that PyTorch op instead.
//
// Usage notes:
//  - The macro expands to a bare `if` statement that logs the stringified
//    condition via GRAPH_DEBUG and then returns makeWildcardOp(node) from the
//    enclosing function.
//  - It therefore can only be used inside a function that returns Operator
//    and has a variable named `node` in scope.
//  - Because the expansion is a plain `if` block (not wrapped in
//    do { } while (0)), avoid using it as the body of an unbraced if/else.
#define REQUIRE(cond) \
if (!(cond)) { \
GRAPH_DEBUG("Unsupported condition " #cond "\n"); \
return makeWildcardOp(node); \
}
// Builds an operator of the given kind that consumes input 0 of `node` and
// produces output 0, registering the output against dnnl_graph_.
Operator LlgaGraphHelper::makeEltwiseOp(Node* node, opkind kind) {
  auto eltwise = Operator(node, kind);
  eltwise.setInput(0);
  eltwise.setOutput(dnnl_graph_, 0);
  return eltwise;
}
// Builds a two-input operator of the given kind from inputs 0 and 1 of
// `node`. When either of those inputs is not a tensor, the node is mapped to
// a wildcard op instead (via REQUIRE).
Operator LlgaGraphHelper::makeBinaryOp(Node* node, opkind kind) {
  REQUIRE(
      node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
      node->input(1)->type()->isSubtypeOf(TensorType::get()));
  auto binary = Operator(node, kind);
  binary.setInput(0, 1);
  binary.setOutput(dnnl_graph_, 0);
  return binary;
}
// Map a PyTorch op to its corresponding oneDNN Graph op.
// If mapping isn't possible, then create a wildcard op instead.
// The mapping is done as per oneDNN Graph op schema defined in
// third_party/ideep/mkl-dnn/src/interface/op_def.hpp.
//
// Numeric arguments passed to setInput / setAttr below are input offsets of
// `node`. toIValue() yields an empty optional when the queried input cannot
// be read as a constant (see the batch_norm case), so every such result is
// checked with REQUIRE(...has_value()) before it is dereferenced; an empty
// result makes the node a wildcard instead of invoking undefined behavior.
Operator LlgaGraphHelper::createOperator(Node* node) {
  auto nodeKind = node->kind();
  // We're using an if-else chain instead of a switch statement
  // because we would soon be adding custom ops with function schemas.
  // We would have to use Symbol::fromQualString at that time anyway,
  // but we are okay with this choice, since this code is not in the hot-path.
  if (nodeKind == Symbol::fromQualString("aten::conv2d")) {
    fixConvOptionalBias(node);
    return Operator(node, opkind::Convolution)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5)
        .setAttr(dnnl::graph::op::attr::groups, Operator::Int, 6)
        .setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX"))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (
      (nodeKind == Symbol::fromQualString("aten::_convolution")) ||
      (nodeKind == Symbol::fromQualString("aten::convolution"))) {
    auto transposedIValue = toIValue(node->namedInput("transposed"));
    REQUIRE(transposedIValue.has_value());
    bool transposed = transposedIValue->toBool();
    // Transposed convolutions are not mapped.
    REQUIRE(!transposed);
    return Operator(node, opkind::Convolution)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5)
        .setAttr(dnnl::graph::op::attr::groups, Operator::Int, 8)
        .setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX"))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::batch_norm")) {
    auto training = toIValue(node->namedInput("training"));
    REQUIRE(training.has_value()); // cannot get training status in script mode
    // Only inference-mode batch norm is mapped; the training case falls
    // through to the wildcard at the end of the function.
    if (!training->toBool()) {
      return Operator(node, opkind::BatchNormInference)
          .setInput(0, 1, 2, 3, 4)
          .setOutput(dnnl_graph_, 0)
          .setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 7)
          .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
    }
  } else if (nodeKind == Symbol::fromQualString("aten::layer_norm")) {
    auto normalized_shape = toIValue(node->namedInput("normalized_shape"));
    REQUIRE(normalized_shape.has_value());
    REQUIRE(normalized_shape->toIntList().size() == 1);
    return Operator(node, opkind::LayerNorm)
        .setInput(0, 2, 3)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 4)
        .setAttr(dnnl::graph::op::attr::keep_stats, false);
  } else if (nodeKind == Symbol::fromQualString("aten::addmm")) {
    auto alpha = toIValue(node->namedInput("alpha"));
    auto beta = toIValue(node->namedInput("beta"));
    // Only alpha == 1 with beta in {0, 1} is mapped; anything else falls
    // through to the wildcard at the end of the function.
    if (alpha.has_value() && beta.has_value()) {
      if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 1.0)) {
        // Inputs 1 and 2 are the matmul operands, input 0 is passed third.
        return Operator(node, opkind::MatMul)
            .setInput(1, 2, 0)
            .setOutput(dnnl_graph_, 0);
      } else if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 0.0)) {
        // beta == 0: input 0 is not passed to the MatMul op.
        return Operator(node, opkind::MatMul)
            .setInput(1, 2)
            .setOutput(dnnl_graph_, 0);
      }
    }
  } else if (nodeKind == Symbol::fromQualString("aten::add")) {
    return makeBinaryOp(node, opkind::Add);
  } else if (nodeKind == Symbol::fromQualString("aten::mul")) {
    return makeBinaryOp(node, opkind::Multiply);
  } else if (nodeKind == Symbol::fromQualString("aten::div")) {
    return makeBinaryOp(node, opkind::Divide);
  } else if (nodeKind == Symbol::fromQualString("aten::tanh")) {
    return makeEltwiseOp(node, opkind::Tanh);
  } else if (nodeKind == Symbol::fromQualString("aten::relu")) {
    return makeEltwiseOp(node, opkind::ReLU);
  } else if (nodeKind == Symbol::fromQualString("aten::elu")) {
    return makeEltwiseOp(node, opkind::Elu)
        .setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1);
  } else if (nodeKind == Symbol::fromQualString("aten::sigmoid")) {
    return makeEltwiseOp(node, opkind::Sigmoid);
  } else if (nodeKind == Symbol::fromQualString("aten::gelu")) {
    return makeEltwiseOp(node, opkind::GELU);
  } else if (nodeKind == Symbol::fromQualString("aten::round")) {
    return makeEltwiseOp(node, opkind::Round);
  } else if (nodeKind == Symbol::fromQualString("aten::exp")) {
    return makeEltwiseOp(node, opkind::Exp);
  } else if (nodeKind == Symbol::fromQualString("aten::sqrt")) {
    return makeEltwiseOp(node, opkind::Sqrt);
  } else if (nodeKind == Symbol::fromQualString("aten::abs")) {
    return makeEltwiseOp(node, opkind::Abs);
  } else if (nodeKind == Symbol::fromQualString("aten::square")) {
    return makeEltwiseOp(node, opkind::Square);
  } else if (nodeKind == Symbol::fromQualString("aten::clamp")) {
    auto clamp_min = toIValue(node->input(1));
    auto clamp_max = toIValue(node->input(2));
    REQUIRE(clamp_min.has_value() && clamp_max.has_value());
    // The PyTorch API already rejects the case where both min & max are
    // None, but we check it nevertheless.
    REQUIRE(!(clamp_max->isNone() && clamp_min->isNone()));
    // A missing bound is expressed as -inf / +inf.
    auto clamp_min_value = (clamp_min->isNone())
        ? -std::numeric_limits<float>::infinity()
        : Operator::ScalarToFloat(node, 1);
    auto clamp_max_value = (clamp_max->isNone())
        ? std::numeric_limits<float>::infinity()
        : Operator::ScalarToFloat(node, 2);
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, clamp_min_value)
        .setAttr(dnnl::graph::op::attr::max, clamp_max_value);
  } else if (nodeKind == Symbol::fromQualString("aten::hardtanh")) {
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, Operator::ScalarToFloat, 1)
        .setAttr(dnnl::graph::op::attr::max, Operator::ScalarToFloat, 2);
  } else if (nodeKind == Symbol::fromQualString("aten::hardswish")) {
    return makeEltwiseOp(node, opkind::HardSwish);
  } else if (nodeKind == Symbol::fromQualString("aten::log")) {
    return makeEltwiseOp(node, opkind::Log);
  } else if (nodeKind == Symbol::fromQualString("aten::leaky_relu")) {
    return makeEltwiseOp(node, opkind::LeakyReLU)
        .setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1);
  } else if (nodeKind == Symbol::fromQualString("aten::relu6")) {
    // relu6 is expressed as a clamp to [0, 6].
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, 0.f)
        .setAttr(dnnl::graph::op::attr::max, 6.f);
  } else if (
      (nodeKind == Symbol::fromQualString("aten::softmax")) ||
      (nodeKind == Symbol::fromQualString("aten::_softmax"))) {
    auto dim = toIValue(node->namedInput("dim"));
    REQUIRE(dim.has_value());
    auto axis = dim->toInt();
    return Operator(node, opkind::SoftMax)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, axis);
  } else if (nodeKind == Symbol::fromQualString("aten::_log_softmax")) {
    auto dim = toIValue(node->namedInput("dim"));
    REQUIRE(dim.has_value());
    auto axis = dim->toInt();
    return Operator(node, opkind::LogSoftmax)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, axis);
  } else if (nodeKind == Symbol::fromQualString("aten::cat")) {
    auto o = Operator(node, opkind::Concat);
    REQUIRE(node->namedInput("tensors")->node()->kind() == prim::ListConstruct);
    REQUIRE(node->namedInput("tensors")->uses().size() == 1);
    REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant);
    // aten::cat needs a special handling since it takes a Tensor[] as input.
    // We set the inputs of ListConstruct as the inputs of cat.
    //
    // Pytorch IR:                              LLGA sees:
    //     %a    %b     %c          %dim              %a    %b    %c
    //      \     |     /             |                \     |    /
    //   prim::ListConstruct   prim::Constant     llga::Concat[axis=%dim]
    //                    \      /
    //                    aten::cat
    auto listConstruct = node->input(0)->node();
    for (auto input : listConstruct->inputs()) {
      o.setInputValue(input);
    }
    return o.setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, Operator::Int, 1);
  } else if (
      (nodeKind == Symbol::fromQualString("aten::max_pool2d")) ||
      (nodeKind == Symbol::fromQualString("aten::max_pool2d_with_indices"))) {
    // Currently, LLGA lacks support to create indices mask.
    // Once it's supported, max_pool2d_with_indices should be mapped differently
    REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
    auto ceil_mode = toIValue(node->namedInput("ceil_mode"));
    REQUIRE(ceil_mode.has_value());
    auto rounding_type = ceil_mode->toBool() ? "ceil" : "floor";
    return Operator(node, opkind::MaxPool)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 4)
        .setAttr(
            dnnl::graph::op::attr::rounding_type, std::string(rounding_type))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::avg_pool2d")) {
    // TODO: do we need add checks for all Constants?
    REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
    auto ceil_mode = toIValue(node->namedInput("ceil_mode"));
    REQUIRE(ceil_mode.has_value());
    auto rounding_type = ceil_mode->toBool() ? "ceil" : "floor";
    auto divisor_override = toIValue(node->namedInput("divisor_override"));
    REQUIRE(divisor_override.has_value());
    // Only the default (None) divisor is mapped.
    REQUIRE(divisor_override->isNone());
    return Operator(node, opkind::AvgPool)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3)
        // exclude_pad is the negation of the boolean read from input 5.
        .setAttr(dnnl::graph::op::attr::exclude_pad, !Operator::Bool(node, 5))
        .setAttr(
            dnnl::graph::op::attr::rounding_type, std::string(rounding_type))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::matmul")) {
    // Unknown ranks become size_t(-1), which never satisfies the check below.
    auto dim0 = getDimensions(node->namedInput("self")).value_or(-1);
    auto dim1 = getDimensions(node->namedInput("other")).value_or(-1);
    // TODO: support all shape combinations
    REQUIRE(
        (dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
        (dim0 == 3 && dim1 == 2));
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } else if (nodeKind == Symbol::fromQualString("aten::mm")) {
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } else if (nodeKind == Symbol::fromQualString("aten::bmm")) {
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } else if (nodeKind == Symbol::fromQualString("aten::linear")) {
    return Operator(node, opkind::MatMul)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::transpose_b, true);
  } else if (nodeKind == Symbol::fromQualString("aten::permute")) {
    REQUIRE(aliasDb_->hasInputWriters(node) == false);
    auto dims = toIValue(node->namedInput("dims"));
    REQUIRE(dims.has_value());
    return Operator(node, opkind::StaticTranspose)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::order, dims->toIntVector());
  } else if (nodeKind == Symbol::fromQualString("aten::contiguous")) {
    // Contiguous should only be mapped to oneDNN Graph if the destination
    // memory-layout is different than the source memory-format
    // Strides would be different, but shape would be same
    auto typeOfInput = node->input(0)->type()->expect<TensorType>();
    auto typeOfOutput = node->output(0)->type()->expect<TensorType>();
    auto inputStrides = typeOfInput->strides().concrete_sizes();
    auto outputStrides = typeOfOutput->strides().concrete_sizes();
    REQUIRE(inputStrides != outputStrides);
    return Operator(node, opkind::Reorder)
        .setInput(0)
        .setOutput(dnnl_graph_, 0);
  }
  // Either the op kind is not handled above, or a handled kind (batch_norm,
  // addmm) did not meet its mapping conditions.
  GRAPH_DEBUG("Making ", nodeKind.toQualString(), " a wildcard");
  return makeWildcardOp(node);
}
// Returns the device type recorded in the tensor type of `v`. Falls back to
// CPU when `v` is not a tensor or its type carries no device information.
static DeviceType inferDeviceFromValue(Value* v) {
  if (auto tensorType = v->type()->cast<TensorType>()) {
    if (auto device = tensorType->device()) {
      return device->type();
    }
  }
  return at::kCPU;
}
// Infers the device type shared by all inputs of `graph`.
//
// The device of the first input is taken as reference and every input must
// agree with it, otherwise TORCH_CHECK raises an error. A graph without any
// inputs is treated as CPU, which matches the fallback that
// inferDeviceFromValue uses when no device information is available; this
// avoids indexing into an empty input list.
static DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
  const auto inputs = graph->inputs();
  if (inputs.empty()) {
    return at::kCPU;
  }
  const auto dt = inferDeviceFromValue(inputs[0]);
  TORCH_CHECK(
      std::all_of(
          inputs.begin(),
          inputs.end(),
          [dt](Value* v) { return inferDeviceFromValue(v) == dt; }),
      "All inputs must have the same device type");
  return dt;
}
// Translates a PyTorch device type into the matching oneDNN engine kind.
// Only CPU is handled; any other device type fails the TORCH_CHECK.
static dnnl::engine::kind getLlgaEngineKind(DeviceType type) {
  TORCH_CHECK(type == DeviceType::CPU, "Not support device type ", type);
  return dnnl::engine::kind::cpu;
}
// Registers the prim::ListConstruct feeding an aten::cat under the same
// partition id as the cat itself. Does nothing for any other node kind, or
// when the cat does not belong to a partition.
static void mayAddListConstructIntoConcatPartition(
    Node* n,
    OpPartitionMap& opToOwningPartition) {
  // Since prim::ListConstruct is not visible to the LLGA,
  // it will not be in any partition returned from the partitioning results.
  // We need to rewrite opToOwningPartition to make the prim::ListConstruct
  // 'virtually' part of the same partition as the aten::cat, so that
  // prim::ListConstruct can be fused into the fusion group by graph fuser.
  // We emphasize on 'virtually' because get_num_ops() for cat's partition
  // would still return 1.
  const bool isPartitionedCat =
      n->kind() == aten::cat && opToOwningPartition.has(n);
  if (!isPartitionedCat) {
    return;
  }
  auto listConstruct = n->namedInput("tensors")->node();
  auto catPartitionId = opToOwningPartition.get(n);
  opToOwningPartition.add(listConstruct, catPartitionId);
}
// Verify that input tensors are compatible with oneDNN Graph.
// Scalars would be converted to 1-D tensors later anyway,
// but they shouldn't be complex-double.
// If this check fails, the caller converts the op to a wildcard.
//
// Returns true only when every input of `node` passes the checks below.
static bool checkInputCompatibility(Node* node) {
  for (auto input : node->inputs()) {
    c10::IValue constant = toIValue(input);

    if (constant.isTensor()) {
      // Constant tensor: must live on CPU and have an accepted dtype.
      const at::Tensor& tensor = constant.toTensor();
      if (tensor.device() != at::kCPU) {
        return false;
      }
      // We've allowed Long dtype here although oneDNN Graph does not support
      // Long dtype because oneDNN Graph will end up not handling the op that
      // has an input with Long dtype, so it'd be handled by PyTorch.
      const auto dtype = tensor.scalar_type();
      const bool accepted = (dtype == at::ScalarType::BFloat16) ||
          (dtype == at::ScalarType::Float) || (dtype == at::ScalarType::Long);
      if (!accepted) {
        return false;
      }
      continue;
    }

    if (constant.isScalar()) {
      // Constant scalar: anything but complex-double is fine.
      if (constant.isComplexDouble()) {
        return false;
      }
      continue;
    }

    if (!input->type()->isSubtypeOf(TensorType::get())) {
      // Neither a constant tensor/scalar nor a tensor-typed value.
      continue;
    }

    // Non-constant tensor: when the dtype is known, it must be Float or
    // BFloat16. An unknown dtype is accepted.
    const auto knownDtype = input->type()->cast<TensorType>()->scalarType();
    if (knownDtype.has_value()) {
      const at::ScalarType dtype = knownDtype.value();
      const bool accepted = (dtype == at::ScalarType::Float) ||
          (dtype == at::ScalarType::BFloat16);
      if (!accepted) {
        return false;
      }
    }
  }
  return true;
}
// Builds the oneDNN Graph (LLGA) mirror of `graph`, asks the library for its
// partitions under `policy`, and records which partition each op belongs to.
//
// Steps:
//  1. Create the oneDNN graph for the engine kind inferred from the graph
//     inputs, plus an AliasDb for `graph`.
//  2. Add one op per top-level node (a wildcard when its inputs are
//     incompatible) and remember each input value under its unique id.
//  3. Finalize, fetch partitions, and keep only the supported ones.
//  4. Fill opToOwningPartition_ and attach ListConstruct nodes to the
//     partition of their aten::cat consumer.
LlgaGraphHelper::LlgaGraphHelper(
    const std::shared_ptr<Graph>& graph,
    dnnl::graph::partition::policy policy) {
  auto engineKind = getLlgaEngineKind(inferDevice(graph));
  dnnl_graph_ = std::make_unique<dnnl::graph::graph>(engineKind);
  aliasDb_ = std::make_unique<torch::jit::AliasDb>(graph);

  GRAPH_DEBUG("Constructing LLGA graph");
  // TODO: select nodes in top-level block for now
  for (auto* node : graph->block()->nodes()) {
    auto kindOfNode = node->kind();
    GRAPH_DEBUG("Trying to add ", kindOfNode.toQualString());
    if (!checkInputCompatibility(node)) {
      GRAPH_DEBUG("Incompatible inputs for ", kindOfNode.toQualString());
      dnnl_graph_->add_op(makeWildcardOp(node).llgaOp());
    } else {
      auto mappedOp = createOperator(node);
      dnnl_graph_->add_op(mappedOp.llgaOp());
      GRAPH_DEBUG(" Added node ", kindOfNode.toQualString());
    }
    // emplace keeps the first value registered for a given unique id.
    for (Value* input : node->inputs()) {
      tensorIdToValue_.emplace(input->unique(), input);
    }
  }
  dnnl_graph_->finalize();

  GRAPH_DEBUG("Get Partitions");
  std::vector<dnnl::graph::partition> candidates =
      dnnl_graph_->get_partitions(policy);
  // excluded unsupported Wildcard partitions
  for (auto& candidate : candidates) {
    if (!candidate.is_supported()) {
      continue;
    }
    partitions_.push_back(candidate);
  }
  GRAPH_DEBUG(" Got #partitions: ", partitions_.size());

  // The index into partitions_ doubles as the partition id.
  for (size_t partId = 0; partId < partitions_.size(); ++partId) {
    for (auto opId : partitions_[partId].get_ops()) {
      opToOwningPartition_.add(opId, partId);
    }
  }

  // Scanning the graph again for post processing
  for (auto* node : graph->block()->nodes()) {
    mayAddListConstructIntoConcatPartition(node, opToOwningPartition_);
  }
}
// True when `node` carries a Subgraph attribute and is of kind
// prim::oneDNNFusionGroup.
bool LlgaGraphHelper::isLlgaSubgraph(const Node* node) {
  if (!node->hasAttribute(attr::Subgraph)) {
    return false;
  }
  return node->kind() == prim::oneDNNFusionGroup;
}
// Decides whether `toMerge` may be merged into the fusion group `subgraph`.
// `subgraph` must be an LLGA fusion group (checked). The merge is allowed
// when `toMerge` is eligible and both nodes map to the same partition.
bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
  TORCH_CHECK(
      isLlgaSubgraph(subgraph),
      "The consumer node does not contain a subgraph");
  if (shouldConsiderForMerge(toMerge)) {
    return opToOwningPartition_.get(toMerge) ==
        opToOwningPartition_.get(subgraph);
  }
  return false;
}
// Except for conv & GEMMs, which should always be handled by oneDNN Graph,
// only use single-op partitions for ops unsupported by NNC, or ops
// that oneDNN executes faster. prim::ListConstruct is an exception, since
// we simply want to fuse it with cat.
//
// Returns true when `kindOfOp` is one of the kinds listed below.
static bool isBetterSuitedForLLGA(NodeKind kindOfOp) {
  const NodeKind preferredKinds[] = {
      aten::layer_norm,
      aten::avg_pool2d,
      aten::matmul,
      aten::max_pool2d,
      aten::conv2d,
      aten::_convolution,
      aten::mm,
      aten::linear,
      aten::cat,
      prim::ListConstruct,
  };
  for (const NodeKind& preferred : preferredKinds) {
    if (kindOfOp == preferred) {
      return true;
    }
  }
  return false;
}
// Decides whether `node` should be handed to LLGA based on its partition:
//  - not in any partition: false
//  - in a multi-op partition: true
//  - in a single-op partition: only if isBetterSuitedForLLGA accepts its kind
bool LlgaGraphHelper::checkForSingleOpPartition(Node* node) {
  if (!opToOwningPartition_.has(node)) {
    // this op isn't present in any partition
    return false;
  }
  auto owningPartition = opToOwningPartition_.get(node);
  if (partitions_[owningPartition].get_ops_num() != 1) {
    // multi-op partition
    return true;
  }
  return isBetterSuitedForLLGA(node->kind());
}
// A node is a merge candidate when it is already an LLGA fusion group (we're
// in the process of merging), or when its partition qualifies according to
// checkForSingleOpPartition.
bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
  return isLlgaSubgraph(node) || checkForSingleOpPartition(node);
}
// Wraps `n` into a new prim::oneDNNFusionGroup node (updating `aliasDb`) and
// registers that group under the partition id `n` was mapped to.
// Returns the newly created group node.
Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) {
  auto owningPartition = opToOwningPartition_.get(n);
  GRAPH_DEBUG(
      "Creating FusionGroup_",
      owningPartition,
      " for ",
      n->kind().toQualString());
  auto fusionGroup = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
      n, prim::oneDNNFusionGroup, aliasDb);
  // The group inherits the partition id of the node it wraps.
  opToOwningPartition_.add(fusionGroup, owningPartition);
  return fusionGroup;
}
// Merges `toMerge` into the fusion group `subgraphNode`, updating `aliasDb`.
// A debug message describing the merge is emitted first; when `toMerge` is
// itself a fusion group, its partition id is included in the message.
void LlgaGraphHelper::mergeNodeIntoSubgraph(
    Node* toMerge,
    Node* subgraphNode,
    AliasDb& aliasDb) {
  if (!isLlgaSubgraph(toMerge)) {
    // Plain node being absorbed by a fusion group.
    GRAPH_DEBUG(
        "Merging ",
        toMerge->kind().toQualString(),
        " into ",
        subgraphNode->kind().toQualString(),
        "_",
        opToOwningPartition_.get(subgraphNode));
  } else {
    // Fusion group being absorbed by another fusion group.
    GRAPH_DEBUG(
        "Merging ",
        toMerge->kind().toQualString(),
        "_",
        opToOwningPartition_.get(toMerge),
        " into ",
        subgraphNode->kind().toQualString(),
        "_",
        opToOwningPartition_.get(subgraphNode));
  }
  SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
      toMerge, subgraphNode, aliasDb);
}
// Dissolves the fusion group `subgraphNode` when the number of ops counted in
// its subgraph differs from the number of ops in its owning partition.
// `subgraphNode` must be an LLGA fusion group (checked).
void LlgaGraphHelper::unmergeIfAnyNodeIsMissing(Node* subgraphNode) {
  TORCH_CHECK(isLlgaSubgraph(subgraphNode), "Cannot unmerge a non-LLGA node");
  auto owningPartition = opToOwningPartition_.get(subgraphNode);
  auto opsInPartition = partitions_[owningPartition].get_ops_num();
  auto opsInSubgraph = countSupportedOps(subgraphNode->g(attr::Subgraph));
  if (opsInPartition == opsInSubgraph) {
    // The group holds every op of its partition; keep it.
    return;
  }
  GRAPH_DEBUG(
      "Unmerging FusionGroup_",
      owningPartition,
      ". Expected ",
      opsInPartition,
      " ops, but got ",
      opsInSubgraph,
      " ops.");
  SubgraphUtils::unmergeSubgraph(subgraphNode);
}
// Counts the top-level nodes of `graph`, skipping prim::Constant and
// prim::ListConstruct nodes.
size_t LlgaGraphHelper::countSupportedOps(
    const std::shared_ptr<Graph>& graph) const {
  // TODO: count nodes in top-level block for now
  size_t supported = 0;
  for (auto* node : graph->block()->nodes()) {
    const auto kind = node->kind();
    if ((kind == prim::Constant) || (kind == prim::ListConstruct)) {
      continue;
    }
    ++supported;
  }
  return supported;
}
// Returns a copy of the partitions kept by the constructor (only those that
// reported is_supported()).
std::vector<dnnl::graph::partition> LlgaGraphHelper::getPartitions() const {
  return partitions_;
}
// Returns a copy of the map from a value's unique id to the Value*, as
// collected from node inputs by the constructor.
std::map<size_t, Value*> LlgaGraphHelper::getTensorIdToValue() const {
  return tensorIdToValue_;
}
// Wraps an LLGA fusion node. The wrapper stores a mutable pointer to the
// node, hence the const_cast. Fails the TORCH_CHECK when `node` is not an
// LLGA fusion group.
LlgaNodeWrapper::LlgaNodeWrapper(const Node* node)
    : n(const_cast<Node*>(node)) { // NOLINT
  const bool wrapsFusionGroup = LlgaGraphHelper::isLlgaSubgraph(n);
  TORCH_CHECK(wrapsFusionGroup, "Cannot wrap a non-LLGA fusion node");
}
// Marks the output at position `offset` as using the opaque layout by
// writing OPAQUE_LAYOUT into the node's attr::output_layouts entry.
// Fails the TORCH_CHECK when `offset` is out of range.
void LlgaNodeWrapper::setOpaqueLayout(size_t offset) {
  // The attribute is only reachable through a const reference; cast the
  // constness away so the entry can be updated in place.
  auto& layouts =
      const_cast<std::vector<int64_t>&>(n->is(attr::output_layouts)); // NOLINT
  const auto num_output = layouts.size();
  TORCH_CHECK(
      offset < num_output,
      "Out of range. (Invalid index ",
      offset,
      " for attr::output_layouts with size ",
      num_output,
      ")");
  // offset was validated above, so plain indexing is safe.
  layouts[offset] = OPAQUE_LAYOUT;
}
// Reports whether the output at position `offset` is marked with
// OPAQUE_LAYOUT in the node's attr::output_layouts.
// Fails the TORCH_CHECK when `offset` is out of range.
bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const {
  const auto& layouts = n->is(attr::output_layouts);
  const auto num_output = layouts.size();
  TORCH_CHECK(
      offset < num_output,
      "Out of range. (Invalid index ",
      offset,
      " for attr::output_layouts with size ",
      num_output,
      ")");
  return layouts[offset] == OPAQUE_LAYOUT;
}
} // namespace torch::jit::fuser::onednn
|