//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses the ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
/// File header and includes.
/// {0} is the dialect namespace.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir
try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None
import builtins
)Py";
/// Template for dialect class:
/// {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass
)Py";
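/// Template for a dialect extension module, which reuses the dialect class
/// generated for the extended dialect:
/// {0} is the dialect namespace.
constexpr const char *dialectExtensionTemplate = R"Py(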
from ._{0}_ops_gen import _Dialect
)Py";
/// Template for operation class:
/// {0} is the Python class name;
/// {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
/// Template for class level declarations of operand and result
/// segment specs.
/// {0} is either "OPERAND" or "RESULT"
/// {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non-sequence operand/result)
///   0 = optional element (expect a value or None)
///  -1 = operand/result is a sequence corresponding to a variadic
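///
/// For illustration, a hypothetical op with a single operand, an optional
/// operand and a variadic operand (in that order) would be described by the
/// operand segment spec [1, 0, -1].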
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";
/// Template for class level declarations of the _ODS_REGIONS spec:
/// {0} is the minimum number of regions
/// {1} is the Python bool literal for hasNoVariadicRegions
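///
/// For illustration, an op with exactly two regions and no variadic regions
/// would be described by `_ODS_REGIONS = (2, True)`.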
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
/// Template for single-element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";
/// Template for single-element accessor after a variable-length group:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
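///
/// For illustration, take a hypothetical op with operands `a`, variadic `b`
/// and `c` (three groups). The accessor for `c` is emitted with {2} = 3 and
/// {3} = 2. If an instance has 5 operands, the variadic group has
/// 5 - 3 + 1 = 3 elements, so `c` is read from index 2 + 3 - 1 = 4.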
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
/// Template for an optional element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
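///
/// For illustration, take a hypothetical op with operands `a`, optional `b`
/// and `c` (three groups). The accessor for `b` is emitted with {2} = 3 and
/// {3} = 1. With 2 actual operands, `b` is absent and the accessor returns
/// None. With 3 operands, it returns operands[1].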
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
/// Template for the variadic group accessor in the single variadic group case:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
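///
/// For illustration, take a hypothetical op with operands `a`, variadic `b`
/// and `c`. The accessor for `b` uses {2} = 3 and {3} = 1. With 5 actual
/// operands, the group length is 5 - 3 + 1 = 3, so `b` is the slice
/// operands[1:4].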
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
/// First part of the template for equally-sized variadic group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of variadic groups;
/// {3} is the number of non-variadic groups preceding the current group;
/// {4} is the number of variadic groups preceding the current group.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
/// Second part of the template for equally-sized case, accessing a single
/// element:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";
/// Second part of the template for equally-sized case, accessing a variadic
/// group:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
/// Template for an attribute-sized group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position of the group in the group list;
/// {3} is a return suffix (expected [0] for single-element, empty for
/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
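///
/// For illustration, assume `_ods_segmented_accessor` returns the slice of
/// elements that belongs to the requested segment, and that a hypothetical op
/// instance has `operandSegmentSizes` = [1, 2, 0]. The accessor for group 0
/// (single) appends [0] and returns operands[0]. The accessor for group 1
/// (variadic) returns operands[1:3]. The accessor for group 2 (optional) gets
/// an empty range and returns None.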
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";
/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
/// {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
R"Py([0] if len({0}_range) > 0 else None)Py";
/// Template for an operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";
/// Template for an optional operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";
/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";
/// Template for an operation attribute setter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";
/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
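///
/// For illustration, if `op` is an instance of the generated class for a
/// hypothetical op with a unit attribute `nowrap`, then:
///   op.nowrap = True   # sets "nowrap" to a UnitAttr
///   op.nowrap = False  # removes "nowrap" if it is present
///   op.nowrap          # evaluates to False afterwards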
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";
/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
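/// Template for a region accessor:
/// {0} is the name of the accessor;
/// {1} is the position of the region in the region list.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]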
)PY";
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
static llvm::cl::opt<std::string>
clDialectName("bind-dialect",
llvm::cl::desc("The dialect to run the generator for"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
static llvm::cl::opt<std::string> clDialectExtensionName(
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
using AttributeClasses = DenseMap<StringRef, StringRef>;
/// Checks whether `str` is a Python keyword or would shadow a builtin
/// function.
static bool isPythonReserved(StringRef str) {
static llvm::StringSet<> reserved(
{"and", "as", "assert", "break", "callable", "class",
"continue", "def", "del", "elif", "else", "except",
"finally", "for", "from", "global", "if", "import",
"in", "is", "lambda", "nonlocal", "not", "or",
"pass", "raise", "return", "issubclass", "try", "type",
"while", "with", "yield"});
return reserved.contains(str);
}
/// Checks whether `str` would shadow a generated variable or attribute
/// part of the OpView API.
static bool isODSReserved(StringRef str) {
static llvm::StringSet<> reserved(
{"attributes", "create", "context", "ip", "operands", "print", "get_asm",
"loc", "verify", "regions", "results", "self", "operation",
"DIALECT_NAMESPACE", "OPERATION_NAME"});
return str.startswith("_ods_") || str.endswith("_ods") ||
reserved.contains(str);
}
/// Modifies the `name` in a way that it becomes suitable for Python bindings
/// (does not change the `name` if it already is suitable) and returns the
/// modified version.
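///
/// For example, "class" becomes "class_" (a Python keyword), "operands"
/// becomes "operands_" (part of the OpView API), "_ods_x" becomes "_ods_x_"
/// (reserved prefix), and a hypothetical "lhs" is returned unchanged.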
static std::string sanitizeName(StringRef name) {
if (isPythonReserved(name) || isODSReserved(name))
return (name + "_").str();
return name.str();
}
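/// Returns the name of the C++ trait indicating that the sizes of the `kind`
/// element groups are given by an attribute, e.g.
/// "::mlir::OpTrait::AttrSizedOperandSegments" for "operand".
static std::string attrSizedTraitForKind(const char *kind) {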
return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
llvm::StringRef(kind).take_front().upper(),
llvm::StringRef(kind).drop_front());
}
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
assert(llvm::is_contained(
llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
"unsupported kind");
// Traits indicating how to process variadic elements.
std::string sameSizeTrait =
llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
llvm::StringRef(kind).take_front().upper(),
llvm::StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
unsigned numVariableLength = getNumVariableLength(op);
// If there is only one variable-length element group, its size can be
// inferred from the total number of elements. If there are none, the
// generation is straightforward.
if (numVariableLength <= 1) {
bool seenVariableLength = false;
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.isVariableLength())
seenVariableLength = true;
if (element.name.empty())
continue;
if (element.isVariableLength()) {
os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
: opOneVariadicTemplate,
sanitizeName(element.name), kind,
getNumElements(op), i);
} else if (seenVariableLength) {
os << llvm::formatv(opSingleAfterVariableTemplate,
sanitizeName(element.name), kind,
getNumElements(op), i);
} else {
os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
i);
}
}
return;
}
// Handle the operations where variadic groups have the same size.
if (op.getTrait(sameSizeTrait)) {
int numPrecedingSimple = 0;
int numPrecedingVariadic = 0;
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
os << llvm::formatv(opVariadicEqualPrefixTemplate,
sanitizeName(element.name), kind, numVariableLength,
numPrecedingSimple, numPrecedingVariadic);
os << llvm::formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
kind);
}
if (element.isVariableLength())
++numPrecedingVariadic;
else
++numPrecedingSimple;
}
return;
}
// Handle the operations where the size of groups (variadic or not) is
// provided as an attribute. For non-variadic elements, make sure to return
// an element rather than a singleton container.
if (op.getTrait(attrSizedTrait)) {
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.name.empty())
continue;
std::string trailing;
if (!element.isVariableLength())
trailing = "[0]";
else if (element.isOptional())
trailing = std::string(
llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
kind, i, trailing);
}
return;
}
llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
}
/// Free function helpers accessing Operator components.
static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
return op.getOperand(i);
}
static int getNumResults(const Operator &op) { return op.getNumResults(); }
static const NamedTypeConstraint &getResult(const Operator &op, int i) {
return op.getResult(i);
}
/// Emits accessors to Op operands.
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
auto getNumVariableLengthOperands = [](const Operator &oper) {
return oper.getNumVariableLengthOperands();
};
emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
getNumOperands, getOperand);
}
/// Emits accessors to Op results.
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
auto getNumVariableLengthResults = [](const Operator &oper) {
return oper.getNumVariableLengthResults();
};
emitElementAccessors(op, os, "result", getNumVariableLengthResults,
getNumResults, getResult);
}
/// Emits accessors to Op attributes.
static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
for (const auto &namedAttr : op.getAttributes()) {
// Skip "derived" attributes because they are just C++ functions that we
// don't currently expose.
if (namedAttr.attr.isDerivedAttr())
continue;
if (namedAttr.name.empty())
continue;
std::string sanitizedName = sanitizeName(namedAttr.name);
// Unit attributes are handled specially.
if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
namedAttr.name);
continue;
}
if (namedAttr.attr.isOptional()) {
os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
namedAttr.name);
} else {
os << llvm::formatv(attributeGetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeSetterTemplate, sanitizedName,
namedAttr.name);
// Non-optional attributes cannot be deleted.
}
}
}
/// Template for the default auto-generated builder.
/// {0} is a comma-separated list of builder arguments, including the trailing
/// `loc` and `ip`;
/// {1} is the code populating `operands`, `results`, `attributes` and
/// `_ods_successors`;
/// {2} is the argument list forwarded to `build_generic`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic({2}))
)Py";
/// Template for appending a single element to the operand/result list.
/// {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
"operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// Template for appending an optional element to the operand/result list.
/// {0} is the field name.
constexpr const char *optionalAppendOperandTemplate =
"if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
"operands.append(_get_op_result_or_value({0}) if {0} is not None else "
"None)";
constexpr const char *optionalAppendResultTemplate =
"if {0} is not None: results.append({0})";
/// Template for appending a list of elements to the operand/result list.
/// {0} is the field name.
constexpr const char *multiOperandAppendTemplate =
"operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
"operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute name;
/// {2} is the name of the attribute's ODS definition, used as the
/// `AttrBuilder` key.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
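///
/// For illustration, for a hypothetical attribute `value` defined by
/// `I32Attr` and passed as the builder argument `value`, this expands to:
///   attributes["value"] = (value if (
///       issubclass(type(value), _ods_ir.Attribute) or
///       not _ods_ir.AttrBuilder.contains('I32Attr')) else
///         _ods_ir.AttrBuilder.get('I32Attr')(value, context=_ods_context))
/// An existing Attribute is stored unchanged. A raw Python value is passed
/// to the builder registered under 'I32Attr', if there is one.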
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
    issubclass(type({0}), _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute name;
/// {2} is the name of the attribute's ODS definition, used as the
/// `AttrBuilder` key.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
    issubclass(type({0}), _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
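/// Template for setting a unit attribute in the operation builder:
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";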
/// Template to initialize the successors list in the builder if there are any
/// successors.
/// {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";
/// Template to append or extend the list of successors in the builder.
/// {0} is the list method ('append' or 'extend');
/// {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
/// Returns true if the SameOperandsAndResultType trait can be used to infer
/// result types of the given operation.
static bool hasSameArgumentAndResultTypes(const Operator &op) {
return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
op.getNumVariableLengthResults() == 0;
}
/// Returns true if the FirstAttrDerivedResultType trait can be used to infer
/// result types of the given operation.
static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
op.getNumVariableLengthResults() == 0;
}
/// Returns true if the InferTypeOpInterface can be used to infer result types
/// of the given operation.
static bool hasInferTypeInterface(const Operator &op) {
return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
op.getNumRegions() == 0;
}
/// Returns true if there is a trait or interface that can be used to infer
/// result types of the given operation.
static bool canInferType(const Operator &op) {
return hasSameArgumentAndResultTypes(op) ||
hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
}
/// Populates `builderArgs` with result names if the builder is expected to
/// accept them as arguments.
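///
/// For example, an op with a single unnamed result gets the argument name
/// "result", while an op with two unnamed results gets "_gen_res_0" and
/// "_gen_res_1". No result arguments are added when canInferType(op) holds.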
static void
populateBuilderArgsResults(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs) {
if (canInferType(op))
return;
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
std::string name = op.getResultName(i).str();
if (name.empty()) {
if (op.getNumResults() == 1) {
// Special case for one result, make the default name be 'result'
// to properly match the built-in result accessor.
name = "result";
} else {
name = llvm::formatv("_gen_res_{0}", i);
}
}
name = sanitizeName(name);
builderArgs.push_back(name);
}
}
/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments using intermixed attributes and operands in the same order as they
/// appear in the `arguments` field of the op definition. Additionally,
/// `operandNames` is populated with names of operands in their order of
/// appearance.
static void
populateBuilderArgs(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &operandNames,
llvm::SmallVectorImpl<std::string> &successorArgNames) {
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
std::string name = op.getArgName(i).str();
if (name.empty())
name = llvm::formatv("_gen_arg_{0}", i);
name = sanitizeName(name);
builderArgs.push_back(name);
if (!op.getArg(i).is<NamedAttribute *>())
operandNames.push_back(name);
}
}
/// Populates `builderArgs` with the Python-compatible names of builder function
/// successor arguments. Additionally, `successorArgNames` is also populated.
static void populateBuilderArgsSuccessors(
const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &successorArgNames) {
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
NamedSuccessor successor = op.getSuccessor(i);
std::string name = std::string(successor.name);
if (name.empty())
name = llvm::formatv("_gen_successor_{0}", i);
name = sanitizeName(name);
builderArgs.push_back(name);
successorArgNames.push_back(name);
}
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up operation attributes. `argNames` is expected to contain
/// the names of builder arguments that correspond to op arguments, i.e. to the
/// operands and attributes in the same order as they appear in the `arguments`
/// field.
static void
populateBuilderLinesAttr(const Operator &op,
llvm::ArrayRef<std::string> argNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
if (!attribute)
continue;
// Unit attributes are handled specially.
if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
attribute->name, argNames[i]));
continue;
}
builderLines.push_back(llvm::formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
}
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up successors. successorArgNames is expected to correspond
/// to the Python argument name for each successor on the op.
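///
/// For illustration, a hypothetical op with a single successor `dest`
/// followed by a variadic successor `others` gets the lines:
///   _ods_successors = []
///   _ods_successors.append(dest)
///   _ods_successors.extend(others)
/// An op without successors gets `_ods_successors = None`.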
static void populateBuilderLinesSuccessors(
const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
if (successorArgNames.empty()) {
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
return;
}
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
auto &argName = successorArgNames[i];
const NamedSuccessor &successor = op.getSuccessor(i);
builderLines.push_back(
llvm::formatv(addSuccessorTemplate,
successor.isVariadic() ? "extend" : "append", argName));
}
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op operands.
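///
/// For illustration, take hypothetical operands `a` (single) and `c`
/// (variadic) on an op without the AttrSizedOperandSegments trait. This
/// emits:
///   operands.append(_get_op_result_or_value(a))
///   operands.extend(_get_op_results_or_values(c))
/// With the trait, an optional `b` and a variadic `c` are each appended as one
/// entry, which keeps one entry per operand group:
///   operands.append(_get_op_result_or_value(b) if b is not None else None)
///   operands.append(_get_op_results_or_values(c))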
static void
populateBuilderLinesOperand(const Operator &op,
llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
// For each element, find or generate a name.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
const NamedTypeConstraint &element = op.getOperand(i);
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleOperandAppendTemplate;
} else if (element.isOptional()) {
if (sizedSegments) {
formatString = optionalAppendAttrSizedOperandsTemplate;
} else {
formatString = optionalAppendOperandTemplate;
}
} else {
assert(element.isVariadic() && "unhandled element group type");
// If emitting with sizedSegments, then we add the actual list-typed
// element. Otherwise, we extend the actual operands.
if (sizedSegments) {
formatString = multiOperandAppendPackTemplate;
} else {
formatString = multiOperandAppendTemplate;
}
}
builderLines.push_back(llvm::formatv(formatString.data(), name));
}
}
/// Python code template for deriving the operation result types from its
/// attribute:
/// - {0} is the name of the attribute from which to derive the types.
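///
/// For example, if that attribute is a TypeAttr holding `i32`, the derived
/// result type is `i32`. If it is a typed attribute such as `42 : i64`, the
/// derived result type is that attribute's type, `i64`.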
constexpr const char *deriveTypeFromAttrTemplate =
R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";
/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
/// Appends the given multiline string as individual strings into
/// `builderLines`.
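///
/// For example, the string "x = 1\ny = 2" is appended as the two entries
/// "x = 1" and "y = 2".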
static void appendLineByLine(StringRef string,
                             llvm::SmallVectorImpl<std::string> &builderLines) {
  std::pair<StringRef, StringRef> split = std::make_pair(string, string);
  do {
    split = split.second.split('\n');
    builderLines.push_back(split.first.str());
  } while (!split.second.empty());
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results.
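/// The cases are checked in order: if hasSameArgumentAndResultTypes(op) holds,
/// every result takes the type of the first operand (for two results, the
/// emitted line is `results.extend([operands[0].type] * 2)`); if the result
/// types are derived from the first attribute, they are computed from it via
/// deriveTypeFromAttrTemplate; if hasInferTypeInterface(op) holds, no lines
/// are emitted; otherwise each result type argument is appended individually.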
static void
populateBuilderLinesResult(const Operator &op,
                           llvm::ArrayRef<std::string> names,
                           llvm::SmallVectorImpl<std::string> &builderLines) {
  bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;

  if (hasSameArgumentAndResultTypes(op)) {
    builderLines.push_back(llvm::formatv(
        appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
    return;
  }

  if (hasFirstAttrDerivedResultTypes(op)) {
    const NamedAttribute &firstAttr = op.getAttribute(0);
    assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
                                      "from which the type is derived");
    appendLineByLine(
        llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
        builderLines);
    builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
                                         "_ods_derived_result_type",
                                         op.getNumResults()));
    return;
  }

  if (hasInferTypeInterface(op))
    return;

  // For each result, emit a line that adds the corresponding builder argument
  // from `names` to `results`.
  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
    const NamedTypeConstraint &element = op.getResult(i);
    std::string name = names[i];

    // Choose the formatting string based on the element kind.
    llvm::StringRef formatString;
    if (!element.isVariableLength()) {
      formatString = singleResultAppendTemplate;
    } else if (element.isOptional()) {
      formatString = optionalAppendResultTemplate;
    } else {
      assert(element.isVariadic() && "unhandled element group type");
      // If emitting with sizedSegments, then we add the actual list-typed
      // element. Otherwise, we extend the actual results.
      if (sizedSegments) {
        formatString = singleResultAppendTemplate;
      } else {
        formatString = multiResultAppendTemplate;
      }
    }
    builderLines.push_back(llvm::formatv(formatString.data(), name));
  }
}
/// If the operation has variadic regions, adds a builder argument to specify
/// the number of those regions and builder lines to forward it to the generic
/// constructor.
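/// For example, for a hypothetical op with two regions whose last region is
/// variadic and named `caseRegions`, this adds the builder argument
/// `num_caseRegions` and the builder line `regions = 1 + num_caseRegions`.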
static void
populateBuilderRegions(const Operator &op,
                       llvm::SmallVectorImpl<std::string> &builderArgs,
                       llvm::SmallVectorImpl<std::string> &builderLines) {
  if (op.hasNoVariadicRegions())
    return;

  // This is currently enforced when Operator is constructed.
  assert(op.getNumVariadicRegions() == 1 &&
         op.getRegion(op.getNumRegions() - 1).isVariadic() &&
         "expected the last region to be variadic");

  const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
  std::string name =
      ("num_" + region.name.take_front().lower() + region.name.drop_front())
          .str();
  builderArgs.push_back(name);
  builderLines.push_back(
      llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
}
/// Emits a default builder constructing an operation from the list of its
/// result types, followed by a list of its operands.
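///
/// For example, for a hypothetical op with one result and the arguments
/// (lhs, rhs, flags), where `flags` is an optional attribute, the generated
/// Python builder signature would look roughly like
///   def __init__(self, result, lhs, rhs, *, flags=None, loc=None, ip=None)
/// assuming the result argument is named `result` and initTemplate emits
/// `def __init__(self, ...)`. Result, successor, and variadic-region-count
/// arguments are always positional; optional or default-valued attributes and
/// optional operands become keyword-only arguments defaulting to None.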
static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
  // If we are asked to skip default builders, comply.
  if (op.skipDefaultBuilders())
    return;

  llvm::SmallVector<std::string> builderArgs;
  llvm::SmallVector<std::string> builderLines;
  llvm::SmallVector<std::string> operandArgNames;
  llvm::SmallVector<std::string> successorArgNames;
  builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
                      op.getNumNativeAttributes() + op.getNumSuccessors());
  populateBuilderArgsResults(op, builderArgs);
  size_t numResultArgs = builderArgs.size();
  populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames);
  size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
  populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);

  populateBuilderLinesOperand(op, operandArgNames, builderLines);
  populateBuilderLinesAttr(
      op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines);
  populateBuilderLinesResult(
      op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
  populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
  populateBuilderRegions(op, builderArgs, builderLines);

  // Layout of builderArgs vector elements:
  // [ result_args operand_attr_args successor_args regions ]

  // Determine whether the argument corresponding to a given index into the
  // builderArgs vector is a python keyword argument or not.
  auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
    // All result, successor, and region arguments are positional arguments.
    if ((builderArgIndex < numResultArgs) ||
        (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
      return false;
    // Keyword arguments:
    // - optional named attributes (including unit attributes)
    // - default-valued named attributes
    // - optional operands
    Argument a = op.getArg(builderArgIndex - numResultArgs);
    if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
      return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
    if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
      return ntype->isOptional();
    return false;
  };

  // StringRefs in functionArgs refer to strings allocated by builderArgs.
  llvm::SmallVector<llvm::StringRef> functionArgs;

  // Add positional arguments.
  for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
    if (!isKeywordArgFn(i))
      functionArgs.push_back(builderArgs[i]);
  }

  // Add a bare '*' to indicate that all following arguments must be keyword
  // arguments.
  functionArgs.push_back("*");

  // Add a default 'None' value to each keyword arg string, and then add to the
  // function args list.
  for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
    if (isKeywordArgFn(i)) {
      builderArgs[i].append("=None");
      functionArgs.push_back(builderArgs[i]);
    }
  }
  functionArgs.push_back("loc=None");
  functionArgs.push_back("ip=None");

  SmallVector<std::string> initArgs;
  initArgs.push_back("attributes=attributes");
  if (!hasInferTypeInterface(op))
    initArgs.push_back("results=results");
  initArgs.push_back("operands=operands");
  initArgs.push_back("successors=_ods_successors");
  initArgs.push_back("regions=regions");
  initArgs.push_back("loc=loc");
  initArgs.push_back("ip=ip");
  os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
                      llvm::join(builderLines, "\n    "),
                      llvm::join(initArgs, ", "));
}
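/// Emits the segment-size specification for the operands or results (`kind`)
/// of an op with attribute-sized segments: each single element contributes 1,
/// each optional element 0, and each variadic element -1. For example, a
/// hypothetical op whose operands are, in order, a single operand, an optional
/// operand, and a variadic operand gets the spec "[1,0,-1,]" (the trailing
/// comma is valid in a Python list literal).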
static void emitSegmentSpec(
    const Operator &op, const char *kind,
    llvm::function_ref<int(const Operator &)> getNumElements,
    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
        getElement,
    raw_ostream &os) {
  std::string segmentSpec("[");
  for (int i = 0, e = getNumElements(op); i < e; ++i) {
    const NamedTypeConstraint &element = getElement(op, i);
    if (element.isOptional()) {
      segmentSpec.append("0,");
    } else if (element.isVariadic()) {
      segmentSpec.append("-1,");
    } else {
      segmentSpec.append("1,");
    }
  }
  segmentSpec.append("]");

  os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
}
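/// Emits the region specification of an op. For example, a hypothetical op
/// with three regions, the last of which is variadic, gets a minimum region
/// count of 2 and has_no_variadic_regions set to False.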
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
  // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
  // Note that the base OpView class defines this as (0, True).
  unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
  os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
                      op.hasNoVariadicRegions() ? "True" : "False");
}
/// Emits named accessors to regions.
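/// Unnamed regions get no accessor. The index string passed to
/// regionAccessorTemplate is the region's position (e.g. "0"); for the
/// trailing variadic region it becomes an open-ended slice such as "2:", so
/// that, if the template uses it as a subscript into the op's regions, the
/// accessor covers every remaining region.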
static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
  for (const auto &en : llvm::enumerate(op.getRegions())) {
    const NamedRegion &region = en.value();
    if (region.name.empty())
      continue;

    assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
           "expected only the last region to be variadic");
    os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
                        std::to_string(en.index()) +
                            (region.isVariadic() ? ":" : ""));
  }
}
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
  os << llvm::formatv(opClassTemplate, op.getCppClassName(),
                      op.getOperationName());

  // Sized segments.
  if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
    emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
  }
  if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
    emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
  }

  emitRegionAttributes(op, os);
  emitDefaultOpBuilder(op, os);
  emitOperandAccessors(op, os);
  emitAttributeAccessors(op, os);
  emitResultAccessors(op, os);
  emitRegionAccessors(op, os);
}
/// Emits bindings for the dialect specified on the command line, including
/// file headers and utilities. Returns `false` on success to comply with
/// TableGen registration requirements.
static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
  if (clDialectName.empty())
    llvm::PrintFatalError("dialect name not provided");

  bool isExtension = !clDialectExtensionName.empty();
  os << llvm::formatv(fileHeader, isExtension
                                      ? clDialectExtensionName.getValue()
                                      : clDialectName.getValue());
  if (isExtension)
    os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
  else
    os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());

  for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
    Operator op(rec);
    if (op.getDialectName() == clDialectName.getValue())
      emitOpBindings(op, os);
  }
  return false;
}
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);