1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187
|
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/integer_value_refinement.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
#include <torch/csrc/jit/passes/peephole_non_tensor.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_cache.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/utils/memory.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <unordered_map>
#include <vector>
/*
XXX: this is still in prototype phase and has much work left to do, including
but not limited to:
- Refactor APIs
- Add decent coverage of common ops
- Add shape analysis pass on Graph that handles Loops
- Allow concurrent reads to the operator map
- Supporting returning partially evaluated shape compute graph
*/
// Process-wide test-mode flag. When true, tensorShapeArg does not turn complete
// or constant tensor shapes into concrete IValues, so tests exercise the
// partial-evaluation pipeline instead. Written by
// setSymbolicShapeAnalysisTestMode and read by
// symbolicShapeAnalysisTestModeEnabled. It is a plain bool with no
// synchronization; presumably it is only flipped from single-threaded test
// setup (TODO confirm).
static bool symbolic_shape_analysis_test_mode = false;
namespace torch {
namespace jit {
// This is similar to c10::SymbolicShape, but instead of either having
// a concrete dimension or a symbolic dimension, an argument may be:
// - A Symbolic Dimension
// - A Constant Integer
// - Neither of the above. The third case can occur due to inputs to
// ops like view that accept negative values. Maintaining the distinction
// between an unknown symbolic dimension and an unknown integer allows
// us to optimize out comparisons to values < 0 (symbolic shapes are always >=
// 0) For example, a call like graph(%y: Tensor(SS(-1), 10, 10), %inp: int):
// %five: int = prim::Constant[value=5]()
// %zero: int = prim::Constant[value=0]()
// %1 : int = aten::size(%y, %zero)
// %2 : int[] = prim::ListConstruct(%five, %1, %inp)
// %y.2: Tensor(5, SS(-1), (New Symbolic Shape)) = aten::view(%y, %2)
//
// x.view([5, y.size(0), inp])
// will have inputs equal to [5, SS(-1), c10::nullopt]
struct ShapeArg
: public std::
pair<c10::optional<c10::ShapeSymbol>, c10::optional<int64_t>> {
using pair::pair;
static ShapeArg unknownInteger() {
return ShapeArg();
}
ShapeArg(int64_t int_value) {
this->first = c10::nullopt;
this->second = int_value;
}
ShapeArg(c10::ShapeSymbol ss) {
if (ss.is_static()) {
this->first = c10::nullopt;
this->second = ss.value();
} else {
this->first = ss;
this->second = c10::nullopt;
}
}
c10::optional<int64_t> asConstantInt() const {
return this->second;
}
c10::optional<c10::ShapeSymbol> asShapeSymbol() const {
return this->first;
}
private:
ShapeArg() {
this->first = c10::nullopt;
this->second = c10::nullopt;
}
};
// Debug printer for ShapeArg. It prints the constant if one is known,
// otherwise the shape symbol, otherwise "UNK" for an unknown integer.
std::ostream& operator<<(std::ostream& out, const ShapeArg& sa) {
  const auto constant = sa.asConstantInt();
  if (constant) {
    return out << *constant;
  }
  const auto symbol = sa.asShapeSymbol();
  if (symbol) {
    return out << *symbol;
  }
  return out << "UNK";
}
struct ShapeArguments {
// Superset of SymbolicShape, with additional support for unknown, nonsymbolic
// vals
public:
ShapeArguments(const c10::SymbolicShape& ss) {
has_dim_ = ss.rank().has_value();
if (has_dim_) {
for (size_t i = 0; i < *ss.rank(); ++i) {
maybe_shape_symbols_.emplace_back(ss.at(i));
}
}
}
ShapeArguments(std::vector<ShapeArg> ss)
: has_dim_(true), maybe_shape_symbols_(std::move(ss)) {}
bool has_dim() const {
return has_dim_;
}
int64_t len() const {
TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim")
return (int64_t)maybe_shape_symbols_.size();
}
const ShapeArg at(size_t i) const {
TORCH_INTERNAL_ASSERT(has_dim_, "ShapeArguments has no known dim")
return maybe_shape_symbols_.at(i);
}
private:
bool has_dim_;
std::vector<ShapeArg> maybe_shape_symbols_;
};
// Debug printer for ShapeArguments. It prints "(UNKNOWN DIM)" when the rank is
// unknown; otherwise it prints the elements comma-separated inside
// parentheses, e.g. "(5, 10, UNK)".
std::ostream& operator<<(std::ostream& os, const ShapeArguments& sa) {
  if (!sa.has_dim()) {
    os << "(UNKNOWN DIM)";
    return os;
  }
  os << "(";
  const int64_t length = sa.len();
  for (int64_t i = 0; i < length; ++i) {
    if (i > 0) {
      // Without a separator, adjacent elements run together (e.g. "(510)").
      os << ", ";
    }
    os << sa.at(static_cast<size_t>(i));
  }
  os << ")";
  return os;
}
// Enables or disables symbolic shape analysis test mode. Returns the previous
// setting, so that callers can restore it afterwards.
bool setSymbolicShapeAnalysisTestMode(bool value) {
  const bool previous = symbolic_shape_analysis_test_mode;
  symbolic_shape_analysis_test_mode = value;
  return previous;
}
// Reports whether symbolic shape analysis test mode is currently enabled.
bool symbolicShapeAnalysisTestModeEnabled() {
  const bool enabled = symbolic_shape_analysis_test_mode;
  return enabled;
}
using SSArgument = c10::variant<ShapeArguments, IValue>;
std::ostream& operator<<(std::ostream& out, const SSArgument& sa) {
if (const IValue* iv = c10::get_if<IValue>(&sa)) {
out << *iv;
} else {
out << c10::get<ShapeArguments>(sa);
}
return out;
}
namespace {
bool isListOfInts(const TypePtr& type) {
return type->cast<ListType>() &&
type->cast<ListType>()->getElementType()->cast<IntType>();
}
bool isListOfListOfInts(const TypePtr& type) {
// Allows List[Optional[List[Int]]]
if (!type->cast<ListType>()) {
return false;
}
TypePtr element_type = type->cast<ListType>()->getElementType();
if (element_type->cast<OptionalType>()) {
element_type = element_type->cast<OptionalType>()->getElementType();
}
return isListOfInts(element_type);
}
bool isListOfTensors(const TypePtr& type) {
return type->cast<ListType>() &&
type->cast<ListType>()->getElementType()->cast<TensorType>();
}
// Normalizes a Python-style index (negative values count from the end) into
// the range [0, len). Returns nullopt when the index is still out of range
// after wrapping.
c10::optional<size_t> normIndex(int64_t index, size_t len) {
  const int64_t signed_len = static_cast<int64_t>(len);
  const int64_t wrapped = index < 0 ? index + signed_len : index;
  if (wrapped < 0 || wrapped >= signed_len) {
    return c10::nullopt;
  }
  return static_cast<size_t>(wrapped);
}
// Runs one round of generic simplification passes over a shape compute graph.
// Returns true if any pass that reports changes modified the graph;
// ConstantPooling and EliminateDeadCode run for their effect only and do not
// contribute to the result. The pass order is preserved as written: constant
// propagation runs again after the list/integer refinements, presumably to
// fold whatever they expose.
bool shapeGraphCleanupPasses(std::shared_ptr<Graph> graph) {
  // TODO: lower simple tuples ?
  bool changed = false;
  changed |= RemoveListMutation(graph);
  changed |= UnrollConstantLoops(graph);
  changed |= ConstantPropagation(graph);
  changed |= PeepholeOptimizeNonTensor(graph);
  changed |= PeepholeOptimizeListIdioms(graph, /*refine_list_len*/ true);
  changed |= RefineIntegerValues(graph);
  changed |= ConstantPropagation(graph);
  // todo add return change for constant pooling
  ConstantPooling(graph);
  changed |= EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
  return changed;
}
// Replaces every use of `v` with a new constant holding `val`. The constant is
// inserted at the start of v's owning block, before its first node.
void replaceWithIValue(Value* v, IValue val) {
  Block* block = v->node()->owningBlock();
  Node* first_node = *block->nodes().begin();
  WithInsertPoint guard(first_node);
  Value* constant = v->owningGraph()->insertConstant(val);
  v->replaceAllUsesWith(constant);
}
// Reads the (possibly symbolic) shape described by the int list `list`.
// - A prim::Constant list yields a fully concrete shape.
// - A prim::ListConstruct with no writers yields one entry per element. The
//   entry is the symbolic value recorded in `symbolic_shape_values` if
//   present, otherwise the element's constant value, otherwise nullopt
//   (unknown).
// - Anything else yields an unranked SymbolicShape.
c10::SymbolicShape extractListShape(
    Value* list,
    std::unordered_map<Value*, int64_t>& symbolic_shape_values,
    const AliasDb& db) {
  Node* producer = list->node();
  if (producer->kind() == prim::Constant) {
    const std::vector<int64_t> int_list = toIValue(list)->toIntVector();
    return c10::SymbolicShape(int_list);
  }
  // We need a list construct or a constant output that is not written to in
  // order to analyze the output shape.
  const bool analyzable =
      producer->kind() == prim::ListConstruct && !db.hasWriters(list);
  if (!analyzable) {
    GRAPH_DEBUG("Could not extract shape");
    return c10::SymbolicShape();
  }
  std::vector<c10::optional<int64_t>> dims;
  dims.reserve(producer->inputs().size());
  for (Value* element : producer->inputs()) {
    const auto found = symbolic_shape_values.find(element);
    if (found != symbolic_shape_values.end()) {
      dims.emplace_back(found->second);
    } else {
      dims.push_back(constant_as<int64_t>(element));
    }
  }
  return c10::SymbolicShape(dims);
}
// Symbolic Shape Analysis works through iteratively partially evaluating
// a TorchScript shape compute graph by inputing properties from input
// Tensors. We can substitute in properties like `len(x)` and `x[1]`
// if they are statically on the input Tensors. We can also use
// assertions like `assert len(x) == 4` in order to refine the input
// length and unroll loops over its elements. We iteratively optimize and
// substitute in properties until we are unable to make any further
// optimizations. Finally, we try to extract Tensor properties from the output.
// For instance `return [1, 2, inp[2] + 1, inp[3]]` we know that the ouptut
// will be length 4 with first two dimensions equal to 1 and 2. We can also
// deduce that the 4th dimension has the same symbolic shape as inp[3], which
// means that we do know its concrete value statically but we can asssign sets
// of tensor dimensions which must be equal at runtime.
struct SymbolicShapeOpAnalyzer {
std::shared_ptr<Graph> shape_compute_graph_;
const FunctionSchema* schema_;
std::vector<SSArgument> inputs_;
// For the case where we have a JIT graph,
// subsititute optional types for their component types
// if the type is known. This doesn't need to be done
// for known IValues.
void refineInputUnionTypes(const Node* parent_graph_node) {
for (size_t op_in_index = 0;
op_in_index < shape_compute_graph_->inputs().size();
op_in_index++) {
auto type = parent_graph_node->input(op_in_index)->type();
if (auto opt_type = shape_compute_graph_->inputs()
.at(op_in_index)
->type()
->cast<OptionalType>()) {
// None will get handled with constant substitution later
if (!type->cast<OptionalType>() &&
!NoneType::get()->isSubtypeOf(*type)) {
shape_compute_graph_->inputs()
.at(op_in_index)
->setType(opt_type->getElementType());
}
} else if (shape_compute_graph_->inputs()
.at(op_in_index)
->type()
->cast<NumberType>()) {
shape_compute_graph_->inputs().at(op_in_index)->setType(type);
}
}
}
// We handle non-constant values in the shape propagation step
void substituteConstantInputs() {
if (shape_compute_graph_->inputs().size() == 0) {
return;
}
bool seen_tensor_list = false;
size_t op_in_index = 0;
while (op_in_index < shape_compute_graph_->inputs().size()) {
Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);
if (!isListOfListOfInts(graph_in_var->type())) {
op_in_index++;
continue;
}
// Modifying the graph where _node is part of to not use the tensor
// construct
// When we have partially evaluate a list of Tensors like cat(tensor[])
// We have a few problems:
// - optimizing out calls to the length of the list: len(tensors)
// - resolving accesses of the list to the tensor symbolic sizes the
// corresponding list element We can solve both of these problems by
// replacing the partial evaluation of cat([x, y]) def cat(tensors:
// List[List[int]], dim: int)
// body
// with
// def cat(x, y, dim: int)
// tensors = [x, y]
// body
TORCH_INTERNAL_ASSERT(
!seen_tensor_list,
"SSA doesn't handle case with multiple tensor lists")
seen_tensor_list = true;
uint64_t li_length = inputs_.size() - (schema_->arguments().size() - 1);
std::vector<Value*> li_inputs;
TypePtr element_type =
graph_in_var->type()->cast<ListType>()->getElementType();
for (size_t j = op_in_index; j < op_in_index + li_length; ++j) {
auto new_inp = shape_compute_graph_->insertInput(op_in_index + j);
new_inp->setType(element_type);
li_inputs.push_back(new_inp);
}
WithInsertPoint guard(*shape_compute_graph_->block()->nodes().begin());
auto new_li = shape_compute_graph_->insertNode(
shape_compute_graph_->createList(element_type, li_inputs));
graph_in_var->replaceAllUsesWith(new_li->output());
shape_compute_graph_->eraseInput(op_in_index + li_length);
}
TORCH_INTERNAL_ASSERT(
shape_compute_graph_->inputs().size() <= inputs_.size(),
"Shape Compute Graph expected to have less inputs than actual inputs"); //?
for (size_t op_in_index = 0;
op_in_index < shape_compute_graph_->inputs().size();
op_in_index++) {
SSArgument& argument = inputs_[op_in_index];
Value* graph_in_var = shape_compute_graph_->inputs().at(op_in_index);
if (IValue* cur_val = c10::get_if<IValue>(&argument)) {
GRAPH_DEBUG("Substituting constant input ", *cur_val);
replaceWithIValue(graph_in_var, *cur_val);
} else {
auto cur_arg = c10::get<ShapeArguments>(argument);
if (cur_arg.has_dim()) {
graph_in_var->setType(ListType::ofInts());
}
}
}
}
void substituteSymbolicProperties(
std::unordered_map<Value*, int64_t>* symbolic_shape_values) {
// clang-format off
// here we iteratively substitute properties of the node's input tensors
// into the shape compute graph. we can substitute constants into the
// like len(inp) or inp[0] if the tensor has a fixed length or a fixed
// first dimension. we also try to resolve symbolic shapes of the same
// symbolic value to the same Value * in the shape compute graph.
// for the shape logic:
// dim1 = inp1[0]
// dim2 = inp2[0]
// return dim1 if dim2 == 1 else dim2
// if we see that inp1[0] and inp2[0] both have the same symbolic shape
// value, then it is a valid transformation to replace dim2 with dim1 or
// vice versa. to do this we collect all Value * for a particular symbolic
// shape. Then, we replace all Value * within that set with their dominator.
// In the example above, this allows us to infer that the output will be the
// symbolic dimension value of dim1.
// if `symbolic_shape_values` is not null, record list accesses
// which resolve to symbolic dimension values with their concrete symbolic
// shape value. Because symbolic dimensions are represented as negative numbers and
// are not real values, inserting them as constants in the graph would invalidate
// the graph for further use. Instead, we keep track of what their value would be
// for extracting output shapes.
// clang-format on
std::unordered_map<int64_t, std::vector<Value*>> symbolic_shape_map;
TORCH_INTERNAL_ASSERT(
inputs_.size() >= shape_compute_graph_->inputs().size(),
"Missing Arg for Shape Graph");
for (int64_t index = 0; index < shape_compute_graph_->inputs().size();
index++) {
auto shape_arguments = c10::get_if<ShapeArguments>(&inputs_[index]);
if (!shape_arguments || !shape_arguments->has_dim()) {
continue;
}
// Add support for testing symbolic shapes with dynamic dims
for (const Use& use : shape_compute_graph_->inputs().at(index)->uses()) {
// TODO: either decompose composite ops like slice or add handling here
switch (use.user->kind()) {
case aten::len: {
size_t len = shape_arguments->len();
replaceWithIValue(use.user->output(), static_cast<int64_t>(len));
} break;
case aten::__getitem__: {
auto index = constant_as<int64_t>(use.user->inputs().at(1));
if (!index) {
continue;
}
auto norm_index = normIndex(*index, shape_arguments->len());
if (!norm_index) {
continue;
}
auto shape_arg = shape_arguments->at(*norm_index);
if (auto const_int = shape_arg.asConstantInt()) {
replaceWithIValue(use.user->output(), const_int);
continue;
}
auto maybe_shape_symbol = shape_arg.asShapeSymbol();
if (!maybe_shape_symbol) {
continue;
}
auto shape_symbol = *maybe_shape_symbol;
if (symbolic_shape_values) {
symbolic_shape_values->emplace(
use.user->output(), shape_symbol.value());
} else {
int64_t symbolic_index = shape_symbol.value();
symbolic_shape_map[symbolic_index].push_back(use.user->output());
}
for (const auto& sym_uses : use.user->output()->uses()) {
auto k = sym_uses.user->kind();
if (k != aten::ge && k != aten::le && k != aten::ne &&
k != aten::eq && k != aten::lt && k != aten::gt) {
break;
}
auto other_index = 1 - sym_uses.offset;
auto other_value =
constant_as<int64_t>(sym_uses.user->input(other_index));
if (!other_value) {
continue;
}
// check for dim >= 0, 0 <= dim
// dim >= 0
if (k == aten::ge && *other_value == 0 && other_index == 1) {
replaceWithIValue(sym_uses.user->output(), true);
continue;
}
// 0 <= dim
if (k == aten::le && *other_value == 0 && other_index == 0) {
replaceWithIValue(sym_uses.user->output(), true);
continue;
}
// check for dim comparisons to negative number
if (*other_value >= 0) {
continue;
}
if (k == aten::eq || k == aten::ne) {
// True if:
// -2 != {Positive}
replaceWithIValue(sym_uses.user->output(), k == aten::ne);
} else {
// True if:
// -2 <= / < {Positive}
// {Positive} >= / > {-2}
bool true_val =
((other_index == 0 && (k == aten::le || k == aten::lt)) ||
(other_index == 1 && (k == aten::ge || k == aten::gt)));
replaceWithIValue(sym_uses.user->output(), true_val);
}
}
}
}
}
for (const auto& symbolic_set : symbolic_shape_map) {
mergeSymbolicShapeSets(symbolic_set.second);
}
}
}
void mergeSymbolicShapeSets(const std::vector<Value*>& symbolic_set) {
// `symbolic_set` represents a set of Value * which are all equal
// to each other. Here, we optimize the graph by replacing values
// in the set with other dominating values.
// in the following example, where a, b and c are all in the same
// symbolic set:
// if cond:
// a = li[0]
// b = li[1]
// return [a, b]
// else:
// c = li[0]
// return [c, c]
// we can replace `b` with `a` because it is dominated by `a`,
// but we cannot replace `c` with another dominating value
// there are ways to compute this more efficiently but typically number of
// Values for each symbolic set is low and this is cheap to run
for (const auto i : c10::irange(symbolic_set.size())) {
Value* v = symbolic_set[i];
Value* dominating_value = v;
for (const auto& sym_set : symbolic_set) {
if (dominating_value->node()->isDominatedBy(sym_set->node())) {
dominating_value = sym_set;
}
}
if (dominating_value != v) {
v->replaceAllUsesWith(dominating_value);
}
}
}
std::vector<c10::SymbolicShape> propagateShapesInGraph() {
bool made_change = true;
constexpr size_t MAX_ATTEMPTS = 8;
for (int attempt_num = 0; made_change && attempt_num < MAX_ATTEMPTS;
attempt_num++) {
// symbolic shape concrete values are only used in final shape extraction
GRAPH_DUMP("Before substitution: ", shape_compute_graph_);
substituteSymbolicProperties(/*symbolic_shape_values*/ nullptr);
GRAPH_DUMP("Before Opt: ", shape_compute_graph_);
made_change = shapeGraphCleanupPasses(shape_compute_graph_);
}
std::unordered_map<Value*, int64_t> symbolic_shape_values;
substituteSymbolicProperties(&symbolic_shape_values);
GRAPH_DUMP("Done with partial evaluation", shape_compute_graph_);
return extractOutputShape(symbolic_shape_values);
}
std::vector<c10::SymbolicShape> extractOutputShape(
std::unordered_map<Value*, int64_t>& symbolic_shape_values) {
TORCH_INTERNAL_ASSERT(
shape_compute_graph_->outputs().size() == schema_->returns().size());
// TODO: would be nice if there were easy facility to look at uses and see
// if they are all pure instead of instanting db.
auto res = std::vector<c10::SymbolicShape>();
AliasDb db(shape_compute_graph_);
for (size_t i = 0; i < shape_compute_graph_->outputs().size(); ++i) {
auto output = shape_compute_graph_->outputs().at(i);
auto type = output->type();
TORCH_INTERNAL_ASSERT(isListOfInts(type));
c10::SymbolicShape ss =
extractListShape(output, symbolic_shape_values, db);
GRAPH_DEBUG("Extracted Output: ", ss);
res.push_back(ss);
}
return res;
}
public:
SymbolicShapeOpAnalyzer(const FunctionSchema* schema) : schema_(schema) {
shape_compute_graph_ = nullptr;
if (!schema_) {
return;
}
auto maybe_graph = shapeComputeGraphForSchema(*schema_);
if (!maybe_graph) {
return;
}
shape_compute_graph_ = (*maybe_graph)->copy();
}
SymbolicShapeOpAnalyzer(
const FunctionSchema* schema,
std::shared_ptr<Graph> graph)
: schema_(schema) {
shape_compute_graph_ = graph->copy();
}
c10::optional<std::vector<c10::SymbolicShape>> run(
std::vector<SSArgument>& inputs) {
if (!shape_compute_graph_) {
return c10::nullopt;
}
inputs_ = inputs;
substituteConstantInputs();
GRAPH_DEBUG(inputs_)
return propagateShapesInGraph();
}
std::shared_ptr<Graph> getShapeComputeGraph() {
return shape_compute_graph_;
}
};
// Describes a Tensor-typed Value as an SSArgument for shape analysis:
// - a complete static shape becomes a concrete int-list IValue;
// - a constant tensor becomes its sizes (as an IValue, or as a SymbolicShape
//   in test mode);
// - otherwise, the tensor type's (possibly partial) symbolic sizes are used.
// In test mode, complete tensor shapes are deliberately not inserted as
// IValues. The analysis then relies on the partial evaluation pipeline to
// propagate information, which is a good proxy for its ability to propagate
// incomplete shape information.
SSArgument tensorShapeArg(Value* tensor_v) {
  auto tt = tensor_v->type()->expect<TensorType>();
  c10::SymbolicShape symbolic_shapes = tt->symbolic_sizes();
  if (symbolic_shapes.isComplete() && !symbolic_shape_analysis_test_mode) {
    return IValue(tt->sizes().concrete_sizes());
  }
  if (auto constant_tensor = constant_as<at::Tensor>(tensor_v)) {
    // Keep the tensor in a named local, so that the IntArrayRef returned by
    // sizes() refers to a tensor that is alive for as long as `size` is used.
    // Previously sizes() was taken from a temporary optional<Tensor>.
    const at::Tensor& tensor = *constant_tensor;
    auto size = tensor.sizes();
    if (!symbolic_shape_analysis_test_mode) {
      return IValue(size);
    }
    return c10::SymbolicShape(size);
  }
  return symbolic_shapes;
}
std::vector<SSArgument> getNodeInputShapes(Node* n, const AliasDb& db) {
// TODO: fix the List of integers implementation, and
// extract out the shape changes, otherwise this is complete
// NB: shape compute graphs may have less inputs than their node
// counterparts to allow e.g. sharing one single unary definition
// so iterate on # of shape inputs
// We make lists of Tensor inputs variadic, which results in
// offset between a node index and its corresponding graph index
std::vector<SSArgument> input_shapes = std::vector<SSArgument>();
for (size_t node_index = 0; node_index < n->inputs().size(); ++node_index) {
auto type = n->input(node_index)->type();
if (type->castRaw<TensorType>()) {
input_shapes.push_back(tensorShapeArg(n->input(node_index)));
continue;
}
if (isListOfTensors(type)) {
// waiting for more use cases to decide on best generalization
if (n->input(node_index)->node()->kind() == prim::Constant) {
auto ival = toIValue(n->input(node_index));
for (const auto& ten : ival->toTensorVector()) {
input_shapes.emplace_back(c10::List<int64_t>(ten.sizes()));
}
} else if (
n->input(node_index)->node()->kind() == prim::ListConstruct &&
!db.hasWriters(n->input(node_index))) {
auto li_construct_node = n->input(node_index)->node();
for (size_t j = 0; j < li_construct_node->inputs().size(); ++j) {
input_shapes.push_back(tensorShapeArg(li_construct_node->input(j)));
}
} else {
TORCH_INTERNAL_ASSERT(false, "Unhandled List, we shouldn't get here");
}
continue;
}
if (auto ival = toIValue(n->input(node_index))) {
input_shapes.emplace_back(*ival);
continue;
}
if (type->cast<ListType>() &&
type->cast<ListType>()->getElementType()->cast<IntType>()) {
auto input_src_node = n->input(node_index)->node();
if (input_src_node->kind() == prim::ListConstruct &&
!db.hasWriters(n->input(node_index))) {
// it is a very common in graphs to see patterns like:
// z = x.view(y.size())
// or:
// z = x.view(1, 10, y.size(0), y.size(1))
// We want to propagate symbolic dimensions and concrete sizes
// from y to z. To do this we try to associate symbolic dimensions
// or concrete sizes with the integer list inputs that have a
// constructor taken from constants or y.size() or y.size(0)
auto list_construct = n->input(node_index)->node();
std::vector<ShapeArg> shape;
for (Value* v : list_construct->inputs()) {
if (auto constant = constant_as<int64_t>(v)) {
shape.emplace_back(*constant);
} else if (v->node()->kind() == aten::size) {
auto const_index = constant_as<int64_t>(v->node()->input(1));
auto tt = v->node()->input(0)->type()->expect<TensorType>();
auto ss = tt->symbolic_sizes();
if (!ss.rank() || !const_index) {
// if we are getting a size of a tensor, it is an unknown
// symbolic dimension instead of an unknown integer (must be
// >=0)
shape.emplace_back(at::ShapeSymbol::newSymbol());
continue;
}
auto norm_index = normIndex(*const_index, *ss.rank());
if (!norm_index) {
shape.emplace_back(at::ShapeSymbol::newSymbol());
continue;
}
shape.emplace_back(ss[*norm_index]);
} else {
shape.emplace_back(ShapeArg::unknownInteger());
}
}
input_shapes.emplace_back(ShapeArguments(shape));
continue;
}
if (input_src_node->kind() == aten::size &&
!db.hasWriters(n->input(node_index))) {
auto ten_inp = input_src_node->input();
auto ss = ten_inp->type()->expect<TensorType>()->symbolic_sizes();
input_shapes.emplace_back(ss);
continue;
}
}
GRAPH_DEBUG(
"Unhandled input: ",
n->kind().toDisplayString(),
" arg num: ",
node_index);
input_shapes.emplace_back(c10::SymbolicShape());
}
TORCH_INTERNAL_ASSERT(
input_shapes.size() >= n->inputs().size(),
"input_shapes size: ",
input_shapes.size(),
" n inputs size: ",
n->inputs().size());
return input_shapes;
}
// Writes the inferred symbolic shapes onto `node`'s outputs, one shape per
// output, by replacing each output's TensorType with
// TensorType::withSymbolicShapes(shape). Asserts that the counts match.
void applyOutputShapeToGraph(
    Node* node,
    const std::vector<c10::SymbolicShape>& output_shapes) {
  TORCH_INTERNAL_ASSERT(
      node->outputs().size() == output_shapes.size(),
      "Output shape size mismatch");
  for (const auto i : c10::irange(output_shapes.size())) {
    Value* out = node->output(i);
    auto tensor_type = out->type()->expect<TensorType>();
    out->setType(tensor_type->withSymbolicShapes(output_shapes.at(i)));
  }
}
// Runs the registered shape function for `n`'s schema against the shapes of
// `n`'s inputs and, when shapes are produced, writes them onto `n`'s outputs.
// Returns the partially evaluated shape compute graph. Returns nullptr when
// `n` has no schema or no shape compute graph is available for it.
std::shared_ptr<Graph> PropagateShapesWithShapeFunction(
    Node* n,
    const AliasDb& db) {
  const FunctionSchema* func_schema = n->maybeSchema();
  if (func_schema == nullptr) {
    return nullptr;
  }
  SymbolicShapeOpAnalyzer analyzer(func_schema);
  if (analyzer.getShapeComputeGraph() == nullptr) {
    return nullptr;
  }
  auto input_shapes = getNodeInputShapes(n, db);
  // Refine types while the shape graph's inputs still match n's inputs.
  analyzer.refineInputUnionTypes(n);
  auto output_shapes = analyzer.run(input_shapes);
  if (output_shapes) {
    applyOutputShapeToGraph(n, *output_shapes);
  }
  return analyzer.getShapeComputeGraph();
}
// Merges two shapes of equal rank dimension by dimension. Dimensions that are
// equal in both shapes are kept; differing ones become a fresh shape symbol.
// Returns an unranked shape when the (shared) rank is unknown, and asserts
// that both ranks agree.
c10::SymbolicShape combine_bounds(
    const c10::SymbolicShape& lower_bound,
    const c10::SymbolicShape& upper_bound) {
  // TODO: At some point we might want to add support for dynamic dims
  TORCH_INTERNAL_ASSERT(lower_bound.rank() == upper_bound.rank());
  if (!lower_bound.rank()) {
    return c10::SymbolicShape();
  }
  // Compare against the unwrapped size_t rank rather than comparing an int
  // index against the optional itself.
  const size_t rank = *lower_bound.rank();
  std::vector<c10::ShapeSymbol> merged_shapes;
  merged_shapes.reserve(rank);
  for (size_t i = 0; i < rank; ++i) {
    // TODO: Merge equivalent expressions (not needed for current use case)
    if (lower_bound[i] == upper_bound[i]) {
      merged_shapes.push_back(lower_bound[i]);
    } else {
      merged_shapes.push_back(c10::ShapeSymbol::newSymbol());
    }
  }
  return c10::SymbolicShape(merged_shapes);
}
struct SymbolicShapeGraphAnalyzer {
SymbolicShapeGraphAnalyzer(
std::shared_ptr<Graph>& graph,
Node* beg,
Node* end)
: graph_(graph), beg_(beg), end_(end) {
TORCH_INTERNAL_ASSERT(
beg_->owningBlock() == end_->owningBlock() && end_->isAfter(beg_));
}
c10::optional<ShapeComputeGraphMapping> run() {
AliasDb db(graph_);
std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs =
propagateShapesAndGatherPartialEvalShapeGraphs(db);
auto stitched_shape_compute_graph = std::make_shared<Graph>();
// We want to build up a computational graph which computes all shapes
// we dont know statically - that is, all symbolic shapes within
// the region [beg, end). it must be executable before beg.
// TODO: dont require dimensions of tensors to be set AOT ?
for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
auto curr = *it;
if (curr->kind() == prim::Constant) {
continue;
}
// TODO: generalize logic to for other tensor input ops when they are
// added
if (curr->kind() == prim::ListConstruct) {
auto uses = curr->output()->uses();
if (!std::all_of(uses.begin(), uses.end(), [](const Use& use) {
return use.user->kind() == aten::cat;
})) {
GRAPH_DEBUG("Non cat list use ", getHeader(curr));
return c10::nullopt;
}
continue;
}
if (!partial_evaluated_graphs.count(curr)) {
GRAPH_DEBUG("No graph ", getHeader(curr));
return c10::nullopt;
}
auto outputs = curr->outputs();
for (Value* v : outputs) {
auto tt = v->type()->cast<TensorType>();
if (!tt) {
GRAPH_DEBUG("Non tensor node", getHeader(curr));
return c10::nullopt;
}
auto symbolic_sizes = tt->symbolic_sizes();
// TODO: dont require # of dimensions of tensors set ?
if (!symbolic_sizes.rank()) {
GRAPH_DEBUG("No rank on output ", getHeader(curr));
return c10::nullopt;
}
}
auto partial_eval_graph = partial_evaluated_graphs[curr];
joinPartialEvaluatedShapeGraphToLargeShapeGraph(
curr, partial_eval_graph, stitched_shape_compute_graph);
}
size_t MAX_ITER = 8;
bool made_change = true;
size_t i = 0;
while (i < MAX_ITER && made_change) {
i++;
made_change = shapeGraphCleanupPasses(stitched_shape_compute_graph);
}
// for any output that is duplicated, the symbolic shape must be equal
// take the symbolic shape that is generated first and get equivalent ones
std::unordered_map<int64_t, int64_t> discovered_sym_shape_equalities;
std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim;
std::vector<size_t> erase_indices;
for (size_t i = 0; i < stitched_shape_compute_graph->outputs().size();
++i) {
Value* output = stitched_shape_compute_graph->outputs().at(i);
// this Value is already contained, so the symbolic shape for i must be
// equal to the symbolic shape at the existing index
if (graph_output_to_symbolic_shape_dim.count(output)) {
auto curr_sym_shape = output_index_to_symbolic_shape_[i];
auto existing_sym_shape = graph_output_to_symbolic_shape_dim[output];
discovered_sym_shape_equalities[curr_sym_shape] = existing_sym_shape;
erase_indices.push_back(i);
} else {
graph_output_to_symbolic_shape_dim[output] =
output_index_to_symbolic_shape_[i];
}
}
for (int64_t i = erase_indices.size() - 1; i >= 0; i--) {
stitched_shape_compute_graph->eraseOutput(erase_indices[i]);
}
for (size_t i = 0; i < stitched_shape_compute_graph->inputs().size();) {
if (!stitched_shape_compute_graph->inputs().at(i)->hasUses()) {
enclosing_graph_value_to_shape_graph_input_.erase(
stitched_shape_compute_graph->inputs().at(i));
stitched_shape_compute_graph->eraseInput(i);
} else {
++i;
}
}
updateGraphWithSymbolicShapeEqualities(discovered_sym_shape_equalities);
return ShapeComputeGraphMapping(
stitched_shape_compute_graph,
enclosing_graph_value_to_shape_graph_input_,
graph_output_to_symbolic_shape_dim);
}
void updateGraphWithSymbolicShapeEqualities(
std::unordered_map<int64_t, int64_t>& sym_shape_equalities) {
for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
auto curr = *it;
for (size_t i = 0; i < curr->outputs().size(); ++i) {
auto output = curr->output(i);
auto tt = output->type()->cast<TensorType>();
if (!tt || !tt->symbolic_sizes().rank()) {
continue;
}
bool changed = false;
std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
auto new_sizes =
c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
auto value = shape.value();
if (sym_shape_equalities.count(value)) {
changed = true;
return sym_shape_equalities[value];
}
return value;
});
if (changed) {
output->setType(
tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
}
}
}
}
void registerStitchedComputeOutput(
std::shared_ptr<Graph> stitched_shape_compute_graph,
Value* output,
int64_t symbolic_shape) {
stitched_shape_compute_graph->registerOutput(output);
output_index_to_symbolic_shape_
[stitched_shape_compute_graph->outputs().size() - 1] = symbolic_shape;
symbolic_shape_value_to_graph_output_[symbolic_shape] =
stitched_shape_compute_graph->outputs().at(
stitched_shape_compute_graph->outputs().size() - 1);
}
void joinPartialEvaluatedShapeGraphToLargeShapeGraph(
Node* curr,
std::shared_ptr<Graph> partial_eval_graph,
std::shared_ptr<Graph> stitched_shape_compute_graph) {
// we are building up the large shape compute graph by iteratively
// combining partially evaluated individual node shape graphs.
// We need to maintain two mappings, one from non-Tensor inputs in the
// enclosing graph to their equivalent mappings within the large shape
// compute graph, and one from symbolic shape dimension to new node output
// When we add a new tensor node, we do two things:
// 1: record a mapping from the tensor node output to its shape in the
// partial eval graph 2: add each symbolic shape dimension that we have
// not already added as a output to the large shape compute graph
// Once we are done stitching together all partial eval'd graphs, we can
// cleanup the graph and remove the unneeded complete shapes as outputs,
// leaving us only compute for calculating the runtime value of symbolic
// dimensions
// leaving us only compute for calculating the runtime value of symbolic
// dimensions
std::vector<Value*> node_inputs;
// TODO: generalize logic
if (curr->kind() == aten::cat) {
TORCH_INTERNAL_ASSERT(
curr->input(0)->node()->kind() == prim::ListConstruct);
for (Value* v : curr->input(0)->node()->inputs()) {
node_inputs.push_back(v);
}
node_inputs.push_back(curr->namedInput("dim"));
} else {
for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
node_inputs.push_back(curr->input(i));
}
}
std::vector<Value*> partial_eval_inputs;
for (size_t i = 0; i < node_inputs.size(); ++i) {
auto node_input = node_inputs[i];
auto existing_graph_mapping =
enclosing_graph_value_to_shape_graph_input_.find(node_input);
if (existing_graph_mapping !=
enclosing_graph_value_to_shape_graph_input_.end()) {
partial_eval_inputs.push_back(existing_graph_mapping->second);
} else {
Value* shape_graph_input =
stitched_shape_compute_graph->addInput()->copyMetadata(
partial_eval_graph->inputs().at(i));
enclosing_graph_value_to_shape_graph_input_[node_input] =
shape_graph_input;
partial_eval_inputs.push_back(shape_graph_input);
}
// make sure all symbolic dimensions in the graph we are creating are
// computed in the partial eval graph
if (auto tt = node_input->type()->cast<TensorType>()) {
if (!tt->symbolic_sizes().rank()) {
continue;
}
auto rank = *tt->symbolic_sizes().rank();
for (size_t j = 0; j < rank; ++j) {
auto shape = tt->symbolic_sizes()[j];
if (shape.is_static() ||
symbolic_shape_value_to_graph_output_.count(shape.value())) {
continue;
}
auto input = enclosing_graph_value_to_shape_graph_input_[node_input];
WithInsertPoint guard(stitched_shape_compute_graph->block());
auto index = stitched_shape_compute_graph->insertConstant(
static_cast<int64_t>(j));
auto li_index = stitched_shape_compute_graph->insert(
aten::__getitem__, {input, index});
registerStitchedComputeOutput(
stitched_shape_compute_graph, li_index, shape.value());
}
}
}
WithInsertPoint guard(stitched_shape_compute_graph->block());
std::unordered_map<Value*, Value*> value_map;
insertGraph(
*stitched_shape_compute_graph,
*partial_eval_graph,
partial_eval_inputs,
value_map);
for (size_t i = 0; i < curr->outputs().size(); ++i) {
Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];
enclosing_graph_value_to_shape_graph_input_[curr->output(i)] =
new_list_output;
TORCH_INTERNAL_ASSERT(
new_list_output->node()->kind() == prim::ListConstruct ||
new_list_output->node()->kind() == prim::Constant);
TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());
auto symbolic_sizes =
curr->output(i)->type()->expect<TensorType>()->symbolic_sizes();
TORCH_INTERNAL_ASSERT(symbolic_sizes.rank());
for (size_t i = 0; i < *symbolic_sizes.rank(); i++) {
if (symbolic_sizes[i].is_static()) {
continue;
}
int64_t symbolic_shape = symbolic_sizes[i].value();
if (symbolic_shape_value_to_graph_output_.count(symbolic_shape)) {
continue;
}
registerStitchedComputeOutput(
stitched_shape_compute_graph,
new_list_output->node()->input(i),
symbolic_shape);
}
}
}
std::unordered_map<Node*, std::shared_ptr<Graph>>
propagateShapesAndGatherPartialEvalShapeGraphs(AliasDb& db) {
std::unordered_map<Node*, std::shared_ptr<Graph>> partial_evaluated_graphs;
for (auto it = beg_->iterator(); it != end_->iterator(); it++) {
auto curr = *it;
if (auto maybe_graph = PropagateShapesWithShapeFunction(curr, db)) {
partial_evaluated_graphs[curr] = maybe_graph;
}
}
return partial_evaluated_graphs;
}
std::unordered_map<Value*, Value*>
enclosing_graph_value_to_shape_graph_input_;
std::unordered_map<int64_t, Value*> symbolic_shape_value_to_graph_output_;
std::unordered_map<size_t, int64_t> output_index_to_symbolic_shape_;
std::shared_ptr<Graph>& graph_;
Node* beg_;
Node* end_;
};
// Propagates shape information over each node of block `b`, in order.
//  - prim::If: recurse into both branches, then merge the branch output
//    types into the If node's outputs.
//  - Any other node with a schema: run its shape function.
//  - prim::TupleConstruct: rebuild the tuple type from the inputs' current
//    types.
// All other nodes are left untouched (loops included; see TODO).
void PropagateShapesOnBlock(Block* b, const AliasDb& db) {
  for (Node* node : b->nodes()) {
    // TODO: handle loop
    if (node->kind() == prim::If) {
      IfView if_view(node);
      PropagateShapesOnBlock(if_view.thenBlock(), db);
      PropagateShapesOnBlock(if_view.elseBlock(), db);
      mergeTypes(
          if_view.thenOutputs(), if_view.elseOutputs(), if_view.outputs());
      continue;
    }
    if (node->maybeSchema()) {
      PropagateShapesWithShapeFunction(node, db);
      continue;
    }
    if (node->kind() != prim::TupleConstruct) {
      continue;
    }
    auto tuple_type = node->output()->type()->expect<TupleType>();
    auto element_types =
        fmap(node->inputs(), [](Value* elem) { return elem->type(); });
    node->output()->setType(
        tuple_type->createWithContained(std::move(element_types)));
  }
}
} // namespace
void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph) {
AliasDb db(graph);
PropagateShapesOnBlock(graph->block(), db);
}
// Builds one shape-compute graph covering the node range [beg, end) of
// `graph`, using SymbolicShapeGraphAnalyzer. Returns nullopt when the region
// contains a node the analyzer does not support.
c10::optional<ShapeComputeGraphMapping>
PropagateShapesAndBuildLargeShapeComputeGraph(
    std::shared_ptr<Graph>& graph,
    Node* beg,
    Node* end) {
  SymbolicShapeGraphAnalyzer analyzer(graph, beg, end);
  return analyzer.run();
}
// Computes the output symbolic shapes of the operator described by `schema`.
// Each element of `inputs` is either a concrete IValue or a SymbolicShape.
//
// Results are looked up in, and stored into, the shape-function cache.
// - If bounded (lower/upper) shape graphs are registered for the schema,
//   both are evaluated. Each output's pair of results is merged with
//   combine_bounds.
// - Otherwise the schema's regular shape-compute graph is used.
//
// Returns nullopt when the schema has no shape function of either kind, or
// when the analysis yields no result.
TORCH_API c10::optional<std::vector<c10::SymbolicShape>>
calculateSymbolicShapesOnOp(
    const FunctionSchema* schema,
    const std::vector<SSAInput>& inputs) {
  auto bounded_graphs = boundedGraphsForSchema(*schema);
  const bool has_shape_compute =
      shapeComputeGraphForSchema(*schema) != c10::nullopt;
  // Avoid doing all this work for functions that don't have a
  // supported schema
  if (!(has_shape_compute || bounded_graphs != c10::nullopt)) {
    return c10::nullopt;
  }

  if (auto cached = get_cached_shape_function(schema, inputs)) {
    return cached;
  }

  // Convert each input into the argument form the op analyzer expects.
  std::vector<SSArgument> ssa_args;
  for (const auto& input : inputs) {
    if (const IValue* ival = c10::get_if<IValue>(&input)) {
      ssa_args.emplace_back(*ival);
      continue;
    }
    const c10::SymbolicShape* sym_shape =
        c10::get_if<c10::SymbolicShape>(&input);
    ssa_args.emplace_back(ShapeArguments(*sym_shape));
  }

  if (bounded_graphs) {
    // Evaluate the lower bound first, then the upper bound.
    SymbolicShapeOpAnalyzer lower_analyzer(schema, bounded_graphs->lower_bound);
    auto lower_res = lower_analyzer.run(ssa_args);
    SymbolicShapeOpAnalyzer upper_analyzer(schema, bounded_graphs->upper_bound);
    auto upper_res = upper_analyzer.run(ssa_args);
    if (!lower_res.has_value() || !upper_res.has_value()) {
      return c10::nullopt;
    }
    TORCH_INTERNAL_ASSERT(lower_res->size() == upper_res->size());
    std::vector<c10::SymbolicShape> merged_res;
    for (size_t idx = 0; idx < lower_res->size(); ++idx) {
      merged_res.push_back(
          combine_bounds(lower_res->at(idx), upper_res->at(idx)));
    }
    cache_shape_function(schema, inputs, merged_res);
    return merged_res;
  }

  SymbolicShapeOpAnalyzer op_analyzer(schema);
  auto res = op_analyzer.run(ssa_args);
  if (res.has_value()) {
    cache_shape_function(schema, inputs, res.value());
  }
  return res;
}
} // namespace jit
} // namespace torch
|