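// simd_op_check test specialized for Arm targets with SVE2 (and NEON).
// Each test case pairs a Halide expression with regex patterns for the
// instructions it is expected to lower to; compile_and_check() below scans
// the generated assembly (.s file) for those patterns.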
#include "simd_op_check.h"
#include "Halide.h"
#include <algorithm>
#include <iomanip>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
using namespace Halide;
using namespace Halide::ConciseCasts;
using namespace std;
namespace {
using CastFuncTy = function<Expr(Expr)>;
class SimdOpCheckArmSve : public SimdOpCheckTest {
public:
SimdOpCheckArmSve(Target t, int w = 384, int h = 32)
: SimdOpCheckTest(t, w, h), debug_mode(Internal::get_env_variable("HL_DEBUG_SIMDOPCHECK")) {
// Determine and hold can_run_the_code
// TODO: Since Arm CPU features cannot be obtained automatically from get_host_target(),
// some features (e.g. "arm_fp16") must be set explicitly in HL_JIT_TARGET.
// Halide throws an error if there is an unacceptable mismatch between jit_target and host_target.
Target host = get_host_target();
Target jit_target = get_jit_target_from_environment();
cout << "host is: " << host.to_string() << endl;
cout << "HL_TARGET is: " << target.to_string() << endl;
cout << "HL_JIT_TARGET is: " << jit_target.to_string() << endl;
auto is_same_triple = [](const Target &t1, const Target &t2) -> bool {
return t1.arch == t2.arch && t1.bits == t2.bits && t1.os == t2.os && t1.vector_bits == t2.vector_bits;
};
can_run_the_code = is_same_triple(host, target) && is_same_triple(jit_target, target);
// A bunch of feature flags also need to match between the
// compiled code and the host in order to run the code.
for (Target::Feature f : {Target::ARMv7s, Target::ARMFp16, Target::NoNEON, Target::SVE2}) {
if (target.has_feature(f) != jit_target.has_feature(f)) {
can_run_the_code = false;
}
}
if (!can_run_the_code) {
cout << "[WARN] To perform verification of realization, "
<< R"(the target triple "arm-<bits>-<os>" and key feature "arm_fp16")"
<< " must be the same between HL_TARGET and HL_JIT_TARGET" << endl;
}
}
bool can_run_code() const override {
// If the target conditions are met, run the error-checking Halide::Func.
return can_run_the_code;
}
void add_tests() override {
check_arm_integer();
check_arm_float();
check_arm_load_store();
check_arm_pairwise();
}
private:
void check_arm_integer() {
// clang-format off
vector<tuple<int, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy,
CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy,
CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy>> test_params{
{8, in_i8, in_u8, in_f16, in_i16, in_u16, i8, i8_sat, i16, i8, i8_sat, u8, u8_sat, u16, u8, u8_sat},
{16, in_i16, in_u16, in_f16, in_i32, in_u32, i16, i16_sat, i32, i8, i8_sat, u16, u16_sat, u32, u8, u8_sat},
{32, in_i32, in_u32, in_f32, in_i64, in_u64, i32, i32_sat, i64, i16, i16_sat, u32, u32_sat, u64, u16, u16_sat},
{64, in_i64, in_u64, in_f64, in_i64, in_u64, i64, i64_sat, i64, i32, i32_sat, u64, u64_sat, u64, u32, u32_sat},
};
// clang-format on
for (const auto &[bits, in_i, in_u, in_f, in_i_wide, in_u_wide,
cast_i, satcast_i, widen_i, narrow_i, satnarrow_i,
cast_u, satcast_u, widen_u, narrow_u, satnarrow_u] : test_params) {
Expr i_1 = in_i(x), i_2 = in_i(x + 16), i_3 = in_i(x + 32);
Expr u_1 = in_u(x), u_2 = in_u(x + 16), u_3 = in_u(x + 32);
Expr i_wide_1 = in_i_wide(x), i_wide_2 = in_i_wide(x + 16);
Expr u_wide_1 = in_u_wide(x), u_wide_2 = in_u_wide(x + 16);
Expr f_1 = in_f(x);
// TODO: reconcile this comment and logic and figure out
// whether we're testing 192 and 256 bits for NEON and which bit
// widths other than the target one for SVE2.
//
// In general neon ops have the 64-bit version, the 128-bit
// version (ending in q), and the widening version that takes
// 64-bit args and produces a 128-bit result (ending in l). We try
// to peephole match any with vector, so we just try 64-bits, 128
// bits, 192 bits, and 256 bits for everything.
std::vector<int> simd_bit_widths;
if (has_neon()) {
simd_bit_widths.push_back(64);
simd_bit_widths.push_back(128);
}
if (has_sve() && ((target.vector_bits > 128) || !has_neon())) {
simd_bit_widths.push_back(target.vector_bits);
}
for (auto &total_bits : simd_bit_widths) {
const int vf = total_bits / bits;
// Due to a workaround for SVE LLVM issues, when the vector is half the length of natural_lanes,
// the number of lanes in the generated SVE instruction is inconsistent,
// so lane verification is skipped for this specific case.
const int instr_lanes = (total_bits == 64 && has_sve()) ?
Instruction::ANY_LANES :
Instruction::get_instr_lanes(bits, vf, target);
const int widen_lanes = Instruction::get_instr_lanes(bits * 2, vf, target);
const int narrow_lanes = Instruction::get_instr_lanes(bits, vf * 2, target);
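// Each AddTestFunctor below registers a test only when its trailing enable
// condition holds: e.g. add_8_16_32 fires only for 8/16/32-bit elements, and
// the *_widen/*_narrow variants are disabled for SVE, where the corresponding
// NEON patterns don't apply.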
AddTestFunctor add_all(*this, bits, instr_lanes, vf);
AddTestFunctor add_all_vec(*this, bits, instr_lanes, vf, vf != 1);
AddTestFunctor add_8_16_32(*this, bits, instr_lanes, vf, bits != 64);
AddTestFunctor add_16_32_64(*this, bits, instr_lanes, vf, bits != 8);
AddTestFunctor add_16_32(*this, bits, instr_lanes, vf, bits == 16 || bits == 32);
AddTestFunctor add_32(*this, bits, instr_lanes, vf, bits == 32);
AddTestFunctor add_8_16_32_widen(*this, bits, widen_lanes, vf, bits != 64 && !has_sve());
AddTestFunctor add_16_32_64_narrow(*this, bits, narrow_lanes, vf * 2, bits != 8 && !has_sve());
AddTestFunctor add_16_32_narrow(*this, bits, narrow_lanes, vf * 2, (bits == 16 || bits == 32) && !has_sve());
AddTestFunctor add_16_narrow(*this, bits, narrow_lanes, vf * 2, bits == 16 && !has_sve());
// VABA I - Absolute Difference and Accumulate
if (!has_sve()) {
// Relying on LLVM to detect accumulation
add_8_16_32(sel_op("vaba.s", "saba"), i_1 + absd(i_2, i_3));
add_8_16_32(sel_op("vaba.u", "uaba"), u_1 + absd(u_2, u_3));
}
// VABAL I - Absolute Difference and Accumulate Long
add_8_16_32_widen(sel_op("vabal.s", "sabal"), i_wide_1 + absd(i_2, i_3));
add_8_16_32_widen(sel_op("vabal.u", "uabal"), u_wide_1 + absd(u_2, u_3));
// VABD I, F - Absolute Difference
add_8_16_32(sel_op("vabd.s", "sabd"), absd(i_2, i_3));
add_8_16_32(sel_op("vabd.u", "uabd"), absd(u_2, u_3));
// Via widening, taking abs, then narrowing
add_8_16_32(sel_op("vabd.s", "sabd"), cast_u(abs(widen_i(i_2) - i_3)));
add_8_16_32(sel_op("vabd.u", "uabd"), cast_u(abs(widen_i(u_2) - u_3)));
// VABDL I - Absolute Difference Long
add_8_16_32_widen(sel_op("vabdl.s", "sabdl"), widen_i(absd(i_2, i_3)));
add_8_16_32_widen(sel_op("vabdl.u", "uabdl"), widen_u(absd(u_2, u_3)));
// Via widening then taking an abs
add_8_16_32_widen(sel_op("vabdl.s", "sabdl"), abs(widen_i(i_2) - widen_i(i_3)));
add_8_16_32_widen(sel_op("vabdl.u", "uabdl"), abs(widen_i(u_2) - widen_i(u_3)));
// VABS I, F F, D Absolute
add_8_16_32(sel_op("vabs.s", "abs"), abs(i_1));
// VADD I, F F, D Add
add_all_vec(sel_op("vadd.i", "add"), i_1 + i_2);
add_all_vec(sel_op("vadd.i", "add"), u_1 + u_2);
// VADDHN I - Add and Narrow Returning High Half
add_16_32_64_narrow(sel_op("vaddhn.i", "addhn"), narrow_i((i_1 + i_2) >> (bits / 2)));
add_16_32_64_narrow(sel_op("vaddhn.i", "addhn"), narrow_u((u_1 + u_2) >> (bits / 2)));
// VADDL I - Add Long
add_8_16_32_widen(sel_op("vaddl.s", "saddl"), widen_i(i_1) + widen_i(i_2));
add_8_16_32_widen(sel_op("vaddl.u", "uaddl"), widen_u(u_1) + widen_u(u_2));
// VADDW I - Add Wide
add_8_16_32_widen(sel_op("vaddw.s", "saddw"), i_1 + i_wide_1);
add_8_16_32_widen(sel_op("vaddw.u", "uaddw"), u_1 + u_wide_1);
// VAND X - Bitwise AND
// Not implemented in front-end yet
// VBIC I - Bitwise Clear
// VBIF X - Bitwise Insert if False
// VBIT X - Bitwise Insert if True
// skip these ones
// VCEQ I, F - Compare Equal
add_8_16_32(sel_op("vceq.i", "cmeq", "cmpeq"), select(i_1 == i_2, cast_i(1), cast_i(2)));
add_8_16_32(sel_op("vceq.i", "cmeq", "cmpeq"), select(u_1 == u_2, cast_u(1), cast_u(2)));
#if 0
// VCGE I, F - Compare Greater Than or Equal
// Halide flips these to less than instead
check("vcge.s8", 16, select(i8_1 >= i8_2, i8(1), i8(2)));
check("vcge.u8", 16, select(u8_1 >= u8_2, u8(1), u8(2)));
check("vcge.s16", 8, select(i16_1 >= i16_2, i16(1), i16(2)));
check("vcge.u16", 8, select(u16_1 >= u16_2, u16(1), u16(2)));
check("vcge.s32", 4, select(i32_1 >= i32_2, i32(1), i32(2)));
check("vcge.u32", 4, select(u32_1 >= u32_2, u32(1), u32(2)));
check("vcge.f32", 4, select(f32_1 >= f32_2, 1.0f, 2.0f));
check("vcge.s8", 8, select(i8_1 >= i8_2, i8(1), i8(2)));
check("vcge.u8", 8, select(u8_1 >= u8_2, u8(1), u8(2)));
check("vcge.s16", 4, select(i16_1 >= i16_2, i16(1), i16(2)));
check("vcge.u16", 4, select(u16_1 >= u16_2, u16(1), u16(2)));
check("vcge.s32", 2, select(i32_1 >= i32_2, i32(1), i32(2)));
check("vcge.u32", 2, select(u32_1 >= u32_2, u32(1), u32(2)));
check("vcge.f32", 2, select(f32_1 >= f32_2, 1.0f, 2.0f));
#endif
// VCGT I, F - Compare Greater Than
add_8_16_32(sel_op("vcgt.s", "cmgt", "cmpgt"), select(i_1 > i_2, cast_i(1), cast_i(2)));
add_8_16_32(sel_op("vcgt.u", "cmhi", "cmphi"), select(u_1 > u_2, cast_u(1), cast_u(2)));
#if 0
// VCLS I - Count Leading Sign Bits
// We don't currently match these, but it wouldn't be hard to do.
check(arm32 ? "vcls.s8" : "cls", 8 * w, max(count_leading_zeros(i8_1), count_leading_zeros(~i8_1)));
check(arm32 ? "vcls.s16" : "cls", 8 * w, max(count_leading_zeros(i16_1), count_leading_zeros(~i16_1)));
check(arm32 ? "vcls.s32" : "cls", 8 * w, max(count_leading_zeros(i32_1), count_leading_zeros(~i32_1)));
#endif
// VCLZ I - Count Leading Zeros
add_8_16_32(sel_op("vclz.i", "clz"), count_leading_zeros(i_1));
add_8_16_32(sel_op("vclz.i", "clz"), count_leading_zeros(u_1));
// VCMP - F, D Compare Setting Flags
// We skip this
// VCNT I - Count Number of Set Bits
if (!has_sve()) {
// In NEON, there is only cnt for bytes, and then horizontal adds.
add_8_16_32({{sel_op("vcnt.", "cnt"), 8, total_bits == 64 ? 8 : 16}}, vf, popcount(i_1));
add_8_16_32({{sel_op("vcnt.", "cnt"), 8, total_bits == 64 ? 8 : 16}}, vf, popcount(u_1));
} else {
add_8_16_32("cnt", popcount(i_1));
add_8_16_32("cnt", popcount(u_1));
}
// VDUP X - Duplicate
add_8_16_32(sel_op("vdup.", "dup", "mov"), cast_i(y));
add_8_16_32(sel_op("vdup.", "dup", "mov"), cast_u(y));
// VEOR X - Bitwise Exclusive OR
// check("veor", 4, bool1 ^ bool2);
// VEXT I - Extract Elements and Concatenate
// unaligned loads with known offsets should use vext
#if 0
// We currently don't do this.
check("vext.8", 16, in_i8(x+1));
check("vext.16", 8, in_i16(x+1));
check("vext.32", 4, in_i32(x+1));
#endif
// VHADD I - Halving Add
add_8_16_32(sel_op("vhadd.s", "shadd"), cast_i((widen_i(i_1) + widen_i(i_2)) / 2));
add_8_16_32(sel_op("vhadd.u", "uhadd"), cast_u((widen_u(u_1) + widen_u(u_2)) / 2));
// Halide doesn't define overflow behavior for i32, so we
// can use the vhadd instruction. We can't use it for u8, i16, u16, or u32.
add_32(sel_op("vhadd.s", "shadd"), (i_1 + i_2) / 2);
// VHSUB I - Halving Subtract
add_8_16_32(sel_op("vhsub.s", "shsub"), cast_i((widen_i(i_1) - widen_i(i_2)) / 2));
add_8_16_32(sel_op("vhsub.u", "uhsub"), cast_u((widen_u(u_1) - widen_u(u_2)) / 2));
add_32(sel_op("vhsub.s", "shsub"), (i_1 - i_2) / 2);
// VMAX I, F - Maximum
add_8_16_32(sel_op("vmax.s", "smax"), max(i_1, i_2));
add_8_16_32(sel_op("vmax.u", "umax"), max(u_1, u_2));
// VMIN I, F - Minimum
add_8_16_32(sel_op("vmin.s", "smin"), min(i_1, i_2));
add_8_16_32(sel_op("vmin.u", "umin"), min(u_1, u_2));
// VMLA I, F F, D Multiply Accumulate
add_8_16_32("mla signed", sel_op("vmla.i", "mla", "(mad|mla)"), i_1 + i_2 * i_3);
add_8_16_32("mla unsigned", sel_op("vmla.i", "mla", "(mad|mla)"), u_1 + u_2 * u_3);
// VMLS I, F F, D Multiply Subtract
add_8_16_32("mls signed", sel_op("vmls.i", "mls", "(mls|msb)"), i_1 - i_2 * i_3);
add_8_16_32("mls unsigned", sel_op("vmls.i", "mls", "(mls|msb)"), u_1 - u_2 * u_3);
// VMLAL I - Multiply Accumulate Long
// Try to trick LLVM into generating a zext instead of a sext by making
// LLVM think the operand never has a leading 1 bit. zext breaks LLVM's
// pattern matching of mlal.
add_8_16_32_widen(sel_op("vmlal.s", "smlal"), i_wide_1 + widen_i(i_2 & 0x3) * i_3);
add_8_16_32_widen(sel_op("vmlal.u", "umlal"), u_wide_1 + widen_u(u_2) * u_3);
// VMLSL I - Multiply Subtract Long
add_8_16_32_widen(sel_op("vmlsl.s", "smlsl"), i_wide_1 - widen_i(i_2 & 0x3) * i_3);
add_8_16_32_widen(sel_op("vmlsl.u", "umlsl"), u_wide_1 - widen_u(u_2) * u_3);
// VMOV X F, D Move Register or Immediate
// This is for loading immediates, which we won't do in the inner loop anyway
// VMOVL I - Move Long
// For aarch64, llvm does a widening shift by 0 instead of using the sxtl instruction.
add_8_16_32_widen(sel_op("vmovl.s", "sshll"), widen_i(i_1));
add_8_16_32_widen(sel_op("vmovl.u", "ushll"), widen_u(u_1));
add_8_16_32_widen(sel_op("vmovl.u", "ushll"), widen_i(u_1));
// VMOVN I - Move and Narrow
if (total_bits >= 128) {
if (is_arm32()) {
add_16_32_64_narrow("vmovn.i", narrow_i(i_1));
add_16_32_64_narrow("vmovn.i", narrow_u(u_1));
} else {
add_16_32_64({{"uzp1", bits / 2, narrow_lanes * 2}}, vf * 2, narrow_i(i_1));
add_16_32_64({{"uzp1", bits / 2, narrow_lanes * 2}}, vf * 2, narrow_u(u_1));
}
} else {
add_16_32_64_narrow(sel_op("vmovn.i", "xtn"), narrow_i(i_1));
add_16_32_64_narrow(sel_op("vmovn.i", "xtn"), narrow_u(u_1));
}
// VMRS X F, D Move Advanced SIMD or VFP Register to ARM compute Engine
// VMSR X F, D Move ARM Core Register to Advanced SIMD or VFP
// trust llvm to use this correctly
// VMUL I, F, P F, D Multiply
add_8_16_32(sel_op("vmul.i", "mul"), i_2 * i_1);
add_8_16_32(sel_op("vmul.i", "mul"), u_2 * u_1);
// VMULL I, F, P - Multiply Long
add_8_16_32_widen(sel_op("vmull.s", "smull"), widen_i(i_1) * i_2);
add_8_16_32_widen(sel_op("vmull.u", "umull"), widen_u(u_1) * u_2);
// integer division by a constant should use fixed point unsigned
// multiplication, which is done by using a widening multiply
// followed by a narrowing
add_8_16_32_widen(sel_op("vmull.u", "umull"), i_1 / 37);
add_8_16_32_widen(sel_op("vmull.u", "umull"), u_1 / 37);
// VMVN X - Bitwise NOT
// check("vmvn", ~bool1);
// VNEG I, F F, D Negate
add_8_16_32(sel_op("vneg.s", "neg"), -i_1);
#if 0
// These are vfp, not neon. They only work on scalars
check("vnmla.f32", 4, -(f32_1 + f32_2*f32_3));
check("vnmla.f64", 2, -(f64_1 + f64_2*f64_3));
check("vnmls.f32", 4, -(f32_1 - f32_2*f32_3));
check("vnmls.f64", 2, -(f64_1 - f64_2*f64_3));
check("vnmul.f32", 4, -(f32_1*f32_2));
check("vnmul.f64", 2, -(f64_1*f64_2));
// Of questionable value. Catching abs calls is annoying, and the
// slow path is only one more op (for the max).
check("vqabs.s8", 16, abs(max(i8_1, -max_i8)));
check("vqabs.s8", 8, abs(max(i8_1, -max_i8)));
check("vqabs.s16", 8, abs(max(i16_1, -max_i16)));
check("vqabs.s16", 4, abs(max(i16_1, -max_i16)));
check("vqabs.s32", 4, abs(max(i32_1, -max_i32)));
check("vqabs.s32", 2, abs(max(i32_1, -max_i32)));
#endif
// VQADD I - Saturating Add
add_8_16_32(sel_op("vqadd.s", "sqadd"), satcast_i(widen_i(i_1) + widen_i(i_2)));
const Expr max_u = UInt(bits).max();
add_8_16_32(sel_op("vqadd.u", "uqadd"), cast_u(min(widen_u(u_1) + widen_u(u_2), max_u)));
// Check the case where we add a constant that could be narrowed
add_8_16_32(sel_op("vqadd.u", "uqadd"), cast_u(min(widen_u(u_1) + 17, max_u)));
// Can't do larger ones because we can't represent the intermediate 128-bit wide ops.
// VQDMLAL I - Saturating Double Multiply Accumulate Long
// VQDMLSL I - Saturating Double Multiply Subtract Long
// We don't do these, but it would be possible.
// VQDMULH I - Saturating Doubling Multiply Returning High Half
// VQDMULL I - Saturating Doubling Multiply Long
add_16_32(sel_op("vqdmulh.s", "sqdmulh"), satcast_i((widen_i(i_1) * widen_i(i_2)) >> (bits - 1)));
// VQMOVN I - Saturating Move and Narrow
// VQMOVUN I - Saturating Move and Unsigned Narrow
add_16_32_64_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(i_1));
add_16_32_64_narrow(sel_op("vqmovun.s", "sqxtun"), satnarrow_u(i_1));
const Expr max_u_narrow = UInt(bits / 2).max();
add_16_32_64_narrow(sel_op("vqmovn.u", "uqxtn"), narrow_u(min(u_1, max_u_narrow)));
// Double saturating narrow
add_16_32_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(i_wide_1));
add_16_32_narrow(sel_op("vqmovn.u", "uqxtn"), narrow_u(min(u_wide_1, max_u_narrow)));
add_16_32_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(i_wide_1));
add_16_32_narrow(sel_op("vqmovun.s", "sqxtun"), satnarrow_u(i_wide_1));
// Triple saturating narrow
Expr i64_1 = in_i64(x), u64_1 = in_u64(x), f32_1 = in_f32(x), f64_1 = in_f64(x);
add_16_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(i64_1));
add_16_narrow(sel_op("vqmovn.u", "uqxtn"), narrow_u(min(u64_1, max_u_narrow)));
add_16_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(f32_1));
add_16_narrow(sel_op("vqmovn.s", "sqxtn"), satnarrow_i(f64_1));
add_16_narrow(sel_op("vqmovun.s", "sqxtun"), satnarrow_u(f32_1));
add_16_narrow(sel_op("vqmovun.s", "sqxtun"), satnarrow_u(f64_1));
// VQNEG I - Saturating Negate
const Expr max_i = Int(bits).max();
add_8_16_32(sel_op("vqneg.s", "sqneg"), -max(i_1, -max_i));
// VQRDMULH I - Saturating Rounding Doubling Multiply Returning High Half
// Note: division in Halide always rounds down (not towards
// zero). Otherwise these patterns would be more complicated.
add_16_32(sel_op("vqrdmulh.s", "sqrdmulh"), satcast_i((widen_i(i_1) * widen_i(i_2) + (1 << (bits - 2))) / (widen_i(1) << (bits - 1))));
// VQRSHRN I - Saturating Rounding Shift Right Narrow
// VQRSHRUN I - Saturating Rounding Shift Right Unsigned Narrow
add_16_32_64_narrow(sel_op("vqrshrn.s", "sqrshrn"), satnarrow_i((widen_i(i_1) + 8) / 16));
add_16_32_64_narrow(sel_op("vqrshrun.s", "sqrshrun"), satnarrow_u((widen_i(i_1) + 8) / 16));
add_16_32_narrow(sel_op("vqrshrn.u", "uqrshrn"), narrow_u(min((widen_u(u_1) + 8) / 16, max_u_narrow)));
// VQSHL I - Saturating Shift Left
add_8_16_32(sel_op("vqshl.s", "sqshl"), satcast_i(widen_i(i_1) * 16));
add_8_16_32(sel_op("vqshl.u", "uqshl"), cast_u(min(widen_u(u_1) * 16, max_u)));
// VQSHLU I - Saturating Shift Left Unsigned
if (!has_sve()) {
add_8_16_32(sel_op("vqshlu.s", "sqshlu"), satcast_u(widen_i(i_1) * 16));
}
// VQSHRN I - Saturating Shift Right Narrow
// VQSHRUN I - Saturating Shift Right Unsigned Narrow
add_16_32_64_narrow(sel_op("vqshrn.s", "sqshrn"), satnarrow_i(i_1 / 16));
add_16_32_64_narrow(sel_op("vqshrun.s", "sqshrun"), satnarrow_u(i_1 / 16));
add_16_32_narrow(sel_op("vqshrn.u", "uqshrn"), narrow_u(min(u_1 / 16, max_u_narrow)));
// VQSUB I - Saturating Subtract
add_8_16_32(sel_op("vqsub.s", "sqsub"), satcast_i(widen_i(i_1) - widen_i(i_2)));
// N.B. Saturating subtracts are expressed by widening to a *signed* type
add_8_16_32(sel_op("vqsub.u", "uqsub"), satcast_u(widen_i(u_1) - widen_i(u_2)));
// VRADDHN I - Rounding Add and Narrow Returning High Half
add_16_32_64_narrow(sel_op("vraddhn.i", "raddhn"), narrow_i((widen_i(i_1 + i_2) + (Expr(cast_i(1)) << (bits / 2 - 1))) >> (bits / 2)));
add_16_32_narrow(sel_op("vraddhn.i", "raddhn"), narrow_u((widen_u(u_1 + u_2) + (Expr(cast_u(1)) << (bits / 2 - 1))) >> (bits / 2)));
// VREV16 X - Reverse in Halfwords
// VREV32 X - Reverse in Words
// VREV64 X - Reverse in Doublewords
// These reverse within each halfword, word, and doubleword
// respectively. Sometimes llvm generates them, and sometimes
// it generates vtbl instructions.
// VRHADD I - Rounding Halving Add
add_8_16_32(sel_op("vrhadd.s", "srhadd"), cast_i((widen_i(i_1) + widen_i(i_2) + 1) / 2));
add_8_16_32(sel_op("vrhadd.u", "urhadd"), cast_u((widen_u(u_1) + widen_u(u_2) + 1) / 2));
// VRSHL I - Rounding Shift Left
Expr shift = (i_2 % bits) - (bits / 2);
Expr round_s = (cast_i(1) >> min(shift, 0)) / 2;
Expr round_u = (cast_u(1) >> min(shift, 0)) / 2;
add_8_16_32(sel_op("vrshl.s", "srshl", "srshlr"), cast_i((widen_i(i_1) + round_s) << shift));
add_8_16_32(sel_op("vrshl.u", "urshl", "urshlr"), cast_u((widen_u(u_1) + round_u) << shift));
round_s = (cast_i(1) << max(shift, 0)) / 2;
round_u = (cast_u(1) << max(shift, 0)) / 2;
add_8_16_32(sel_op("vrshl.s", "srshl", "srshlr"), cast_i((widen_i(i_1) + round_s) >> shift));
add_8_16_32(sel_op("vrshl.u", "urshl", "urshlr"), cast_u((widen_u(u_1) + round_u) >> shift));
// VRSHR I - Rounding Shift Right
add_8_16_32(sel_op("vrshr.s", "srshr", "srshl"), cast_i((widen_i(i_1) + 1) >> 1));
add_8_16_32(sel_op("vrshr.u", "urshr", "urshl"), cast_u((widen_u(u_1) + 1) >> 1));
// VRSHRN I - Rounding Shift Right Narrow
// LLVM14 converts RSHRN/RSHRN2 to RADDHN/RADDHN2 when the shift amount is half the width of the vector element
// See https://reviews.llvm.org/D116166
add_16_32_narrow(sel_op("vrshrn.i", "raddhn"), narrow_i((widen_i(i_1) + (cast_i(1) << (bits / 2 - 1))) >> (bits / 2)));
add_16_32_narrow(sel_op("vrshrn.i", "raddhn"), narrow_u((widen_u(u_1) + (cast_u(1) << (bits / 2 - 1))) >> (bits / 2)));
add_16_32_64_narrow(sel_op("vrshrn.i", "rshrn"), narrow_i((widen_i(i_1) + (1 << (bits / 4))) >> (bits / 4 + 1)));
add_16_32_narrow(sel_op("vrshrn.i", "rshrn"), narrow_u((widen_u(u_1) + (1 << (bits / 4))) >> (bits / 4 + 1)));
// VRSRA I - Rounding Shift Right and Accumulate
if (!has_sve()) {
// Relying on LLVM to detect accumulation
add_8_16_32(sel_op("vrsra.s", "srsra"), i_2 + cast_i((widen_i(i_1) + 1) >> 1));
add_8_16_32(sel_op("vrsra.u", "ursra"), i_2 + cast_u((widen_u(u_1) + 1) >> 1));
}
// VRSUBHN I - Rounding Subtract and Narrow Returning High Half
add_16_32_64_narrow(sel_op("vrsubhn.i", "rsubhn"), narrow_i((widen_i(i_1 - i_2) + (Expr(cast_i(1)) << (bits / 2 - 1))) >> (bits / 2)));
add_16_32_narrow(sel_op("vrsubhn.i", "rsubhn"), narrow_u((widen_u(u_1 - u_2) + (Expr(cast_u(1)) << (bits / 2 - 1))) >> (bits / 2)));
// VSHL I - Shift Left
add_all_vec(sel_op("vshl.i", "shl", "lsl"), i_1 * 16);
add_all_vec(sel_op("vshl.i", "shl", "lsl"), u_1 * 16);
if (!has_sve()) { // No equivalent instruction in SVE.
add_all_vec(sel_op("vshl.s", "sshl"), i_1 << shift);
add_all_vec(sel_op("vshl.s", "sshl"), i_1 >> shift);
add_all_vec(sel_op("vshl.u", "ushl"), u_1 << shift);
add_all_vec(sel_op("vshl.u", "ushl"), u_1 >> shift);
}
// VSHLL I - Shift Left Long
add_8_16_32_widen(sel_op("vshll.s", "sshll"), widen_i(i_1) * 16);
add_8_16_32_widen(sel_op("vshll.u", "ushll"), widen_u(u_1) * 16);
// VSHR I - Shift Right
add_all_vec(sel_op("vshr.s", "sshr", "asr"), i_1 / 16);
add_all_vec(sel_op("vshr.u", "ushr", "lsr"), u_1 / 16);
// VSHRN I - Shift Right Narrow
add_16_32_64_narrow(sel_op("vshrn.i", "shrn"), narrow_i(i_1 >> (bits / 2)));
add_16_32_64_narrow(sel_op("vshrn.i", "shrn"), narrow_u(u_1 >> (bits / 2)));
add_16_32_64_narrow(sel_op("vshrn.i", "shrn"), narrow_i(i_1 / 16));
add_16_32_64_narrow(sel_op("vshrn.i", "shrn"), narrow_u(u_1 / 16));
// VSLI X - Shift Left and Insert
// I guess this could be used for (x*256) | (y & 255)? We don't do bitwise ops on integers, so skip it.
// VSRA I - Shift Right and Accumulate
if (!has_sve()) {
// Relying on LLVM to detect accumulation
add_all_vec(sel_op("vsra.s", "ssra"), i_2 + i_1 / 16);
add_all_vec(sel_op("vsra.u", "usra"), u_2 + u_1 / 16);
}
// VSRI X - Shift Right and Insert
// See VSLI
// VSUB I, F F, D Subtract
add_all_vec(sel_op("vsub.i", "sub"), i_1 - i_2);
add_all_vec(sel_op("vsub.i", "sub"), u_1 - u_2);
// VSUBHN I - Subtract and Narrow
add_16_32_64_narrow(sel_op("vsubhn.i", "subhn"), narrow_i((i_1 - i_2) >> (bits / 2)));
add_16_32_64_narrow(sel_op("vsubhn.i", "subhn"), narrow_u((u_1 - u_2) >> (bits / 2)));
// VSUBL I - Subtract Long
add_8_16_32_widen(sel_op("vsubl.s", "ssubl"), widen_i(i_1) - widen_i(i_2));
add_8_16_32_widen(sel_op("vsubl.u", "usubl"), widen_u(u_1) - widen_u(u_2));
add_8_16_32_widen(sel_op("vsubl.s", "ssubl"), widen_i(i_1) - widen_i(in_i(0)));
add_8_16_32_widen(sel_op("vsubl.u", "usubl"), widen_u(u_1) - widen_u(in_u(0)));
// VSUBW I - Subtract Wide
add_8_16_32_widen(sel_op("vsubw.s", "ssubw"), i_wide_1 - i_1);
add_8_16_32_widen(sel_op("vsubw.u", "usubw"), u_wide_1 - u_1);
}
}
}
void check_arm_float() {
vector<tuple<int, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy>> test_params{
{16, in_f16, in_u16, in_i16, f16},
{32, in_f32, in_u32, in_i32, f32},
{64, in_f64, in_u64, in_i64, f64},
};
for (const auto &[bits, in_f, in_u, in_i, cast_f] : test_params) {
Expr f_1 = in_f(x), f_2 = in_f(x + 16), f_3 = in_f(x + 32);
Expr u_1 = in_u(x);
Expr i_1 = in_i(x);
// Arithmetic that could raise an FP exception could return NaN, which results in an output mismatch.
// To avoid that, we need a positive value within a certain range.
Func in_f_clamped;
in_f_clamped(x) = clamp(in_f(x), cast_f(1e-3f), cast_f(1.0f));
in_f_clamped.compute_root(); // To prevent LLVM optimization which results in a different instruction
Expr f_1_clamped = in_f_clamped(x);
Expr f_2_clamped = in_f_clamped(x + 16);
if (bits == 16 && !is_float16_supported()) {
continue;
}
vector total_bits_params = {256}; // {64, 128, 192, 256};
if (bits != 64) {
// Add scalar case to verify float16 native operation
total_bits_params.push_back(bits);
}
for (auto total_bits : total_bits_params) {
const int vf = total_bits / bits;
const bool is_vector = vf > 1;
const int instr_lanes = Instruction::get_instr_lanes(bits, vf, target);
const int force_vectorized_lanes = Instruction::get_force_vectorized_instr_lanes(bits, vf, target);
AddTestFunctor add(*this, bits, instr_lanes, vf);
AddTestFunctor add_arm32_f32(*this, bits, vf, is_arm32() && bits == 32);
AddTestFunctor add_arm64(*this, bits, instr_lanes, vf, !is_arm32());
add({{sel_op("vabs.f", "fabs"), bits, force_vectorized_lanes}}, vf, abs(f_1));
add(sel_op("vadd.f", "fadd"), f_1 + f_2);
add(sel_op("vsub.f", "fsub"), f_1 - f_2);
add(sel_op("vmul.f", "fmul"), f_1 * f_2);
add("fdiv", sel_op("vdiv.f", "fdiv", "(fdiv|fdivr)"), f_1 / f_2_clamped);
auto fneg_lanes = has_sve() ? force_vectorized_lanes : instr_lanes;
add({{sel_op("vneg.f", "fneg"), bits, fneg_lanes}}, vf, -f_1);
add({{sel_op("vsqrt.f", "fsqrt"), bits, force_vectorized_lanes}}, vf, sqrt(f_1_clamped));
add_arm32_f32(is_vector ? "vceq.f" : "vcmp.f", select(f_1 == f_2, cast_f(1.0f), cast_f(2.0f)));
add_arm32_f32(is_vector ? "vcgt.f" : "vcmp.f", select(f_1 > f_2, cast_f(1.0f), cast_f(2.0f)));
add_arm64(is_vector ? "fcmeq" : "fcmp", select(f_1 == f_2, cast_f(1.0f), cast_f(2.0f)));
add_arm64(is_vector ? "fcmgt" : "fcmp", select(f_1 > f_2, cast_f(1.0f), cast_f(2.0f)));
add_arm32_f32("vcvt.f32.u", cast_f(u_1));
add_arm32_f32("vcvt.f32.s", cast_f(i_1));
add_arm32_f32("vcvt.u32.f", cast(UInt(bits), f_1));
add_arm32_f32("vcvt.s32.f", cast(Int(bits), f_1));
// The max of Float(16) is less than that of UInt(16), which generates "nan" in the emulator
Expr float_max = Float(bits).max();
add_arm64("ucvtf", cast_f(min(float_max, u_1)));
add_arm64("scvtf", cast_f(i_1));
add_arm64({{"fcvtzu", bits, force_vectorized_lanes}}, vf, cast(UInt(bits), f_1));
add_arm64({{"fcvtzs", bits, force_vectorized_lanes}}, vf, cast(Int(bits), f_1));
add_arm64({{"frintn", bits, force_vectorized_lanes}}, vf, round(f_1));
add_arm64({{"frintm", bits, force_vectorized_lanes}}, vf, floor(f_1));
add_arm64({{"frintp", bits, force_vectorized_lanes}}, vf, ceil(f_1));
add_arm64({{"frintz", bits, force_vectorized_lanes}}, vf, trunc(f_1));
add_arm32_f32({{"vmax.f", bits, force_vectorized_lanes}}, vf, max(f_1, f_2));
add_arm32_f32({{"vmin.f", bits, force_vectorized_lanes}}, vf, min(f_1, f_2));
add_arm64({{"fmax", bits, force_vectorized_lanes}}, vf, max(f_1, f_2));
add_arm64({{"fmin", bits, force_vectorized_lanes}}, vf, min(f_1, f_2));
if (bits != 64 && total_bits != 192) {
// Halide relies on LLVM optimization for this pattern, and in some cases it doesn't work
add_arm64("fmla", is_vector ? (has_sve() ? "(fmla|fmad)" : "fmla") : "fmadd", f_1 + f_2 * f_3);
add_arm64("fmls", is_vector ? (has_sve() ? "(fmls|fmsb)" : "fmls") : "fmsub", f_1 - f_2 * f_3);
}
if (bits != 64) {
add_arm64(vector<string>{"frecpe", "frecps"}, fast_inverse(f_1_clamped));
add_arm64(vector<string>{"frsqrte", "frsqrts"}, fast_inverse_sqrt(f_1_clamped));
}
if (bits == 16) {
// Some of the math ops (exp, log, pow) for fp16 are converted into an "xxx_fp32" call
// and then lowered to an Internal::halide_xxx() function.
// If the target has the FP16 feature, native type conversion between fp16 and fp32 should be generated
// instead of equivalent code emulated with other types.
if (is_vector && !has_sve()) {
add_arm64("exp", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, exp(f_1_clamped));
add_arm64("log", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, log(f_1_clamped));
add_arm64("pow", {{"fcvtl", 16, 4}, {"fcvtn", 16, 4}}, vf, pow(f_1_clamped, f_2_clamped));
} else {
add_arm64("exp", "fcvt", exp(f_1_clamped));
add_arm64("log", "fcvt", log(f_1_clamped));
add_arm64("pow", "fcvt", pow(f_1_clamped, f_2_clamped));
}
}
// No corresponding instruction exists for is_nan, is_inf, or is_finite.
// The instructions expected to be generated depend on CodeGen_LLVM::visit(const Call *op)
add_arm64("nan", is_vector ? sel_op("", "fcmge", "fcmuo") : "fcmp", is_nan(f_1));
if (Halide::Internal::get_llvm_version() >= 200) {
add_arm64("inf", is_vector ? sel_op("", "fcmge", "fcmeq") : "", is_inf(f_1));
add_arm64("finite", is_vector ? sel_op("", "fcmge", "fcmeq") : "", is_inf(f_1));
} else {
add_arm64("inf", {{"fabs", bits, force_vectorized_lanes}}, vf, is_inf(f_1));
add_arm64("finite", {{"fabs", bits, force_vectorized_lanes}}, vf, is_inf(f_1));
}
}
if (bits == 16) {
// The following ops are not vectorized because no SIMD instruction is available.
// The purpose of the test is just to confirm there is no error.
// If the target has the FP16 feature, native type conversion between fp16 and fp32 should be generated
// instead of equivalent code emulated with other types.
AddTestFunctor add_f16(*this, 16, 1);
add_f16("sinf", {{"bl", "sinf"}, {"fcvt", 16, 1}}, 1, sin(f_1_clamped));
add_f16("asinf", {{"bl", "asinf"}, {"fcvt", 16, 1}}, 1, asin(f_1_clamped));
add_f16("cosf", {{"bl", "cosf"}, {"fcvt", 16, 1}}, 1, cos(f_1_clamped));
add_f16("acosf", {{"bl", "acosf"}, {"fcvt", 16, 1}}, 1, acos(f_1_clamped));
add_f16("tanf", {{"bl", "tanf"}, {"fcvt", 16, 1}}, 1, tan(f_1_clamped));
add_f16("atanf", {{"bl", "atanf"}, {"fcvt", 16, 1}}, 1, atan(f_1_clamped));
add_f16("atan2f", {{"bl", "atan2f"}, {"fcvt", 16, 1}}, 1, atan2(f_1_clamped, f_2_clamped));
add_f16("sinhf", {{"bl", "sinhf"}, {"fcvt", 16, 1}}, 1, sinh(f_1_clamped));
add_f16("asinhf", {{"bl", "asinhf"}, {"fcvt", 16, 1}}, 1, asinh(f_1_clamped));
add_f16("coshf", {{"bl", "coshf"}, {"fcvt", 16, 1}}, 1, cosh(f_1_clamped));
add_f16("acoshf", {{"bl", "acoshf"}, {"fcvt", 16, 1}}, 1, acosh(max(f_1, cast_f(1.0f))));
add_f16("tanhf", {{"bl", "tanhf"}, {"fcvt", 16, 1}}, 1, tanh(f_1_clamped));
add_f16("atanhf", {{"bl", "atanhf"}, {"fcvt", 16, 1}}, 1, atanh(clamp(f_1, cast_f(-0.5f), cast_f(0.5f))));
}
}
}
void check_arm_load_store() {
vector<tuple<Type, CastFuncTy>> test_params = {
{Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}};
for (const auto &[elt, in_im] : test_params) {
const int bits = elt.bits();
if ((elt == Float(16) && !is_float16_supported()) ||
(is_arm32() && bits == 64)) {
continue;
}
// LD/ST - Load/Store
for (int width = 64; width <= 64 * 4; width *= 2) {
const int total_lanes = width / bits;
const int instr_lanes = min(total_lanes, 128 / bits);
if (instr_lanes < 2) continue;  // skip the scalar case
// On arm32, instruction selection looks inconsistent due to LLVM optimization
AddTestFunctor add(*this, bits, total_lanes, target.bits == 64);
// NOTE: if the expr is too simple, LLVM might generate "bl memcpy"
Expr load_store_1 = in_im(x) * 3;
if (has_sve()) {
// This pattern has changed with LLVM 21, see https://github.com/halide/Halide/issues/8584 for more
// details.
if (Halide::Internal::get_llvm_version() <= 200) {
// in native width, ld1b/st1b is used regardless of data type
const bool allow_byte_ls = (width == target.vector_bits);
add({get_sve_ls_instr("ld1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1);
add({get_sve_ls_instr("st1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1);
}
} else {
// vector register is not used for simple load/store
string reg_prefix = (width <= 64) ? "d" : "q";
add({{"st[rp]", reg_prefix + R"(\d\d?)"}}, total_lanes, load_store_1);
add({{"ld[rp]", reg_prefix + R"(\d\d?)"}}, total_lanes, load_store_1);
}
}
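// For the structured loads/stores below, an interleaving pattern is built by
// vectorizing a select over x % n, which is expected to compile to ld2/st2,
// ld3/st3, or ld4/st4 (or their SVE counterparts where supported).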
// LD2/ST2 - Load/Store two-element structures
int base_vec_bits = has_sve() ? target.vector_bits : 128;
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 2;
const int instr_lanes = min(vector_lanes, base_vec_bits / bits);
if (instr_lanes < 2) continue;  // skip the scalar case
AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);
Func tmp1, tmp2;
tmp1(x) = cast(elt, x);
tmp1.compute_root();
tmp2(x, y) = select(x % 2 == 0, tmp1(x / 2), tmp1(x / 2 + 16));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_2 = in_im(x * 2) + in_im(x * 2 + 1);
Expr store_2 = tmp2(0, 0) + tmp2(0, 127);
if (has_sve()) {
// TODO(issue needed): Add strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2);
#endif
add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2);
} else {
add_ldn(sel_op("vld2.", "ld2"), load_2);
add_stn(sel_op("vst2.", "st2"), store_2);
}
}
// Also check when the two expressions interleaved have a common
// subexpression, which results in a vector var being lifted out.
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 2;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue;  // skip the scalar case
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);
Func tmp1, tmp2;
tmp1(x) = cast(elt, x);
tmp1.compute_root();
Expr e = (tmp1(x / 2) * 2 + 7) / 4;
tmp2(x, y) = select(x % 2 == 0, e * 3, e + 17);
tmp2.compute_root().vectorize(x, total_lanes);
Expr store_2 = tmp2(0, 0) + tmp2(0, 127);
if (has_sve()) {
add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2);
} else {
add_stn(sel_op("vst2.", "st2"), store_2);
}
}
// LD3/ST3 - Load/Store three-element structures
for (int width = 192; width <= 192 * 4; width *= 2) {
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 3;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue;  // skip the scalar case
AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);
Func tmp1, tmp2;
tmp1(x) = cast(elt, x);
tmp1.compute_root();
tmp2(x, y) = select(x % 3 == 0, tmp1(x / 3),
x % 3 == 1, tmp1(x / 3 + 16),
tmp1(x / 3 + 32));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_3 = in_im(x * 3) + in_im(x * 3 + 1) + in_im(x * 3 + 2);
Expr store_3 = tmp2(0, 0) + tmp2(0, 127);
if (has_sve()) {
// TODO(issue needed): Add strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3);
add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
#endif
} else {
add_ldn(sel_op("vld3.", "ld3"), load_3);
add_stn(sel_op("vst3.", "st3"), store_3);
}
}
// LD4/ST4 - Load/Store four-element structures
for (int width = 256; width <= 256 * 4; width *= 2) {
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 4;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue;  // skip the scalar case
AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);
Func tmp1, tmp2;
tmp1(x) = cast(elt, x);
tmp1.compute_root();
tmp2(x, y) = select(x % 4 == 0, tmp1(x / 4),
x % 4 == 1, tmp1(x / 4 + 16),
x % 4 == 2, tmp1(x / 4 + 32),
tmp1(x / 4 + 48));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_4 = in_im(x * 4) + in_im(x * 4 + 1) + in_im(x * 4 + 2) + in_im(x * 4 + 3);
Expr store_4 = tmp2(0, 0) + tmp2(0, 127);
if (has_sve()) {
// TODO(issue needed): Add strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4);
add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
#endif
} else {
add_ldn(sel_op("vld4.", "ld4"), load_4);
add_stn(sel_op("vst4.", "st4"), store_4);
}
}
// SVE Gather/Scatter
if (has_sve()) {
for (int width = 64; width <= 64 * 4; width *= 2) {
const int total_lanes = width / bits;
const int instr_lanes = min(total_lanes, 128 / bits);
if (instr_lanes < 2) continue;  // skip the scalar case
AddTestFunctor add(*this, bits, total_lanes);
Expr index = clamp(cast<int>(in_im(x)), 0, W - 1);
Func tmp;
tmp(x, y) = cast(elt, y);
tmp(x, index) = cast(elt, 1);
tmp.compute_root().update().vectorize(x, total_lanes);
Expr gather = in_im(index);
Expr scatter = tmp(0, 0) + tmp(0, 127);
const int index_bits = std::max(32, bits);
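// The gather/scatter is expected to use an unsigned word-extended (uxtw) index,
// hence the ", uxtw" fragment appended to the expected operand pattern.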
add({get_sve_ls_instr("ld1", bits, index_bits, "uxtw")}, total_lanes, gather);
add({get_sve_ls_instr("st1", bits, index_bits, "uxtw")}, total_lanes, scatter);
}
}
}
}
void check_arm_pairwise() {
// A summation reduction that starts at something
// non-trivial, to avoid llvm simplifying accumulating
// widening summations into just widening summations.
auto sum_ = [&](Expr e) {
Func f;
f(x) = cast(e.type(), 123);
f(x) += e;
return f(x);
};
// Tests for integer type
{
// clang-format off
vector<tuple<int, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy>> test_params{
{8, in_i8, in_u8, i16, i32, u16, u32},
{16, in_i16, in_u16, i32, i64, u32, u64},
{32, in_i32, in_u32, i64, i64, u64, u64},
{64, in_i64, in_u64, i64, i64, u64, u64},
};
// clang-format on
for (const auto &[bits, in_i, in_u, widen_i, widenx4_i, widen_u, widenx4_u] : test_params) {
for (auto &total_bits : {64, 128}) {
const int vf = total_bits / bits;
const int instr_lanes = Instruction::get_force_vectorized_instr_lanes(bits, vf, target);
AddTestFunctor add(*this, bits, instr_lanes, vf, !(is_arm32() && bits == 64)); // 64 bit is unavailable in neon 32 bit
AddTestFunctor add_8_16_32(*this, bits, instr_lanes, vf, bits != 64);
const int widen_lanes = Instruction::get_instr_lanes(bits, vf * 2, target);
AddTestFunctor add_widen(*this, bits, widen_lanes, vf, bits != 64);
if (!has_sve()) {
// VPADD I, F - Pairwise Add
// VPMAX I, F - Pairwise Maximum
// VPMIN I, F - Pairwise Minimum
for (int f : {2, 4}) {
RDom r(0, f);
add(sel_op("vpadd.i", "addp"), sum_(in_i(f * x + r)));
add(sel_op("vpadd.i", "addp"), sum_(in_u(f * x + r)));
add_8_16_32(sel_op("vpmax.s", "smaxp"), maximum(in_i(f * x + r)));
add_8_16_32(sel_op("vpmax.u", "umaxp"), maximum(in_u(f * x + r)));
add_8_16_32(sel_op("vpmin.s", "sminp"), minimum(in_i(f * x + r)));
add_8_16_32(sel_op("vpmin.u", "uminp"), minimum(in_u(f * x + r)));
}
}
// VPADAL I - Pairwise Add and Accumulate Long
// VPADDL I - Pairwise Add Long
{
int f = 2;
RDom r(0, f);
// If we're reducing by a factor of two, we can
// use the forms with an accumulator
add_widen(sel_op("vpadal.s", "sadalp"), sum_(widen_i(in_i(f * x + r))));
add_widen(sel_op("vpadal.u", "uadalp"), sum_(widen_i(in_u(f * x + r))));
add_widen(sel_op("vpadal.u", "uadalp"), sum_(widen_u(in_u(f * x + r))));
}
{
int f = 4;
RDom r(0, f);
// If we're reducing by more than that, that's not
// possible.
// In case of SVE, addlp is unavailable, so adalp is used with accumulator=0 instead.
add_widen(sel_op("vpaddl.s", "saddlp", "sadalp"), sum_(widen_i(in_i(f * x + r))));
add_widen(sel_op("vpaddl.u", "uaddlp", "uadalp"), sum_(widen_i(in_u(f * x + r))));
add_widen(sel_op("vpaddl.u", "uaddlp", "uadalp"), sum_(widen_u(in_u(f * x + r))));
}
const bool is_arm_dot_prod_available = (!is_arm32() && target.has_feature(Target::ARMDotProd) && bits == 8) ||
(has_sve() && (bits == 8 || bits == 16));
if ((bits == 8 || bits == 16) && !is_arm_dot_prod_available) { // udot/sdot is applied if available
int f = 4;
RDom r(0, f);
// If we're widening the type by a factor of four
// as well as reducing by a factor of four, we
// expect vpaddl followed by vpadal
// Note that when going from u8 to i32 like this,
// the vpaddl is unsigned and the vpadal is
// signed, because the intermediate type is u16
const int widenx4_lanes = Instruction::get_instr_lanes(bits * 2, vf, target);
string op_addl, op_adal;
op_addl = sel_op("vpaddl.s", "saddlp");
op_adal = sel_op("vpadal.s", "sadalp");
add({{op_addl, bits, widen_lanes}, {op_adal, bits * 2, widenx4_lanes}}, vf, sum_(widenx4_i(in_i(f * x + r))));
op_addl = sel_op("vpaddl.u", "uaddlp");
op_adal = sel_op("vpadal.u", "uadalp");
add({{op_addl, bits, widen_lanes}, {op_adal, bits * 2, widenx4_lanes}}, vf, sum_(widenx4_i(in_u(f * x + r))));
add({{op_addl, bits, widen_lanes}, {op_adal, bits * 2, widenx4_lanes}}, vf, sum_(widenx4_u(in_u(f * x + r))));
}
// UDOT/SDOT
if (is_arm_dot_prod_available) {
const int factor_32bit = vf / 4;
for (int f : {4, 8}) {
// checks vector register for narrow src data type (i.e. 8 or 16 bit)
const int lanes_src = Instruction::get_instr_lanes(bits, f * factor_32bit, target);
AddTestFunctor add_dot(*this, bits, lanes_src, factor_32bit);
RDom r(0, f);
add_dot("udot", sum(widenx4_u(in_u(f * x + r)) * in_u(f * x + r + 32)));
add_dot("sdot", sum(widenx4_i(in_i(f * x + r)) * in_i(f * x + r + 32)));
if (f == 4) {
// This doesn't generate for higher reduction factors because the
// intermediate is 16-bit instead of 32-bit. It seems like it would
// be slower to fix this (because the intermediate sum would be
// 32-bit instead of 16-bit).
add_dot("udot", sum(widenx4_u(in_u(f * x + r))));
add_dot("sdot", sum(widenx4_i(in_i(f * x + r))));
}
}
}
}
}
}
// Tests for Float type
{
// clang-format off
vector<tuple<int, CastFuncTy>> test_params{
{16, in_f16},
{32, in_f32},
{64, in_f64},
};
// clang-format on
if (!has_sve()) {
for (const auto &[bits, in_f] : test_params) {
for (auto &total_bits : {64, 128}) {
const int vf = total_bits / bits;
if (vf < 2) continue;
AddTestFunctor add(*this, bits, vf);
AddTestFunctor add_16_32(*this, bits, vf, bits != 64);
if (bits == 16 && !is_float16_supported()) {
continue;
}
for (int f : {2, 4}) {
RDom r(0, f);
add(sel_op("vadd.f", "faddp"), sum_(in_f(f * x + r)));
add_16_32(sel_op("vmax.f", "fmaxp"), maximum(in_f(f * x + r)));
add_16_32(sel_op("vmin.f", "fminp"), minimum(in_f(f * x + r)));
}
}
}
}
}
}
struct ArmTask {
vector<string> instrs;
};
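// Describes one expected instruction. The regex used to search the generated
// assembly is built either from an explicit opcode/operand pair or from the
// element bits and lane count (see generate_pattern()).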
struct Instruction {
string opcode;
optional<string> operand;
optional<int> bits;
optional<int> pattern_lanes;
static inline const int ANY_LANES = -1;
// matching pattern for opcode/operand is directly set
Instruction(const string &opcode, const string &operand)
: opcode(opcode), operand(operand), bits(nullopt), pattern_lanes(nullopt) {
}
// matching pattern for opcode/operand is generated from bits/lanes
Instruction(const string &opcode, int bits, int lanes)
: opcode(opcode), operand(nullopt), bits(bits), pattern_lanes(lanes) {
}
string generate_pattern(const Target &target) const {
bool is_arm32 = target.bits == 32;
bool has_sve = target.has_feature(Target::SVE2);
string opcode_pattern;
string operand_pattern;
if (bits && pattern_lanes) {
if (is_arm32) {
opcode_pattern = get_opcode_neon32();
operand_pattern = get_reg_neon32();
} else if (!has_sve) {
opcode_pattern = opcode;
operand_pattern = get_reg_neon64();
} else {
opcode_pattern = opcode;
operand_pattern = get_reg_sve();
}
} else {
opcode_pattern = opcode;
operand_pattern = operand.value_or("");
}
// e.g "add v15.h " -> "\s*add\s.*\bv\d\d?\.h\b.*"
return opcode_pattern + R"(\s.*\b)" + operand_pattern + R"(\b.*)";
}
// TODO Fix this for SVE2
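// Lanes that fill a 128-bit (NEON-width) register; for SVE2 targets with wider
// vector_bits this is still 128-based, hence the TODO above.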
static int natural_lanes(int bits) {
return 128 / bits;
}
static int get_instr_lanes(int bits, int vec_factor, const Target &target) {
return min(natural_lanes(bits), vec_factor);
}
static int get_force_vectorized_instr_lanes(int bits, int vec_factor, const Target &target) {
// For some cases where a scalar operation is forced to be vectorized
if (target.has_feature(Target::SVE2)) {
if (vec_factor == 1) {
return 1;
} else {
return natural_lanes(bits);
}
} else {
int min_lanes = std::max(2, natural_lanes(bits) / 2); // 64 bit wide VL
return max(min_lanes, get_instr_lanes(bits, vec_factor, target));
}
}
string get_opcode_neon32() const {
return opcode + to_string(bits.value());
}
const char *get_bits_designator() const {
static const map<int, const char *> designators{
// NOTE: vector or float only
{8, "b"},
{16, "h"},
{32, "s"},
{64, "d"},
};
auto iter = designators.find(bits.value());
assert(iter != designators.end());
return iter->second;
}
string get_reg_sve() const {
if (pattern_lanes == ANY_LANES) {
return R"((z\d\d?\.[bhsd])|(s\d\d?))";
} else {
const char *bits_designator = get_bits_designator();
// TODO(need issue): This should only match the scalar register, and likely a NEON instruction opcode.
// Generating a full SVE vector instruction for a scalar operation is inefficient. However this is
// happening and fixing it involves changing intrinsic selection. Likely to use NEON intrinsics where
// applicable. For now, accept both a scalar operation and a vector one.
std::string scalar_reg_pattern = (pattern_lanes > 1) ? "" : std::string("|(") + bits_designator + R"(\d\d?))"; // e.g. "h15"
return std::string(R"(((z\d\d?\.)") + bits_designator + ")|(" +
R"(v\d\d?\.)" + to_string(pattern_lanes.value()) + bits_designator + ")" + scalar_reg_pattern + ")";
}
}
string get_reg_neon32() const {
return "";
}
string get_reg_neon64() const {
const char *bits_designator = get_bits_designator();
if (pattern_lanes == 1) {
return std::string(bits_designator) + R"(\d\d?)"; // e.g. "h15"
} else if (pattern_lanes == ANY_LANES) {
return R"(v\d\d?\.[bhsd])";
} else {
return R"(v\d\d?\.)" + to_string(pattern_lanes.value()) + bits_designator; // e.g. "v15.4h"
}
}
};
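// Builds the expected Instruction for an SVE load/store. The size suffixes are
// derived from the opcode/operand bit widths; when optional_type is given they
// become a regex character class, e.g. get_sve_ls_instr("ld1", 32, 32, "", "b")
// yields opcode pattern "ld1[wb]" and operand pattern "z\d\d?\.[sb]".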
Instruction get_sve_ls_instr(const string &base_opcode, int opcode_bits, int operand_bits, const string &additional = "", const string &optional_type = "") {
static const map<int, string> opcode_suffix_map = {{8, "b"}, {16, "h"}, {32, "w"}, {64, "d"}};
static const map<int, string> operand_suffix_map = {{8, "b"}, {16, "h"}, {32, "s"}, {64, "d"}};
string opcode_size_specifier;
string operand_size_specifier;
if (!optional_type.empty()) {
opcode_size_specifier = "[";
operand_size_specifier = "[";
}
opcode_size_specifier += opcode_suffix_map.at(opcode_bits);
operand_size_specifier += operand_suffix_map.at(operand_bits);
if (!optional_type.empty()) {
opcode_size_specifier += optional_type;
opcode_size_specifier += "]";
operand_size_specifier += optional_type;
operand_size_specifier += "]";
}
const string opcode = base_opcode + opcode_size_specifier;
string operand = R"(z\d\d?\.)" + operand_size_specifier;
if (!additional.empty()) {
operand += ", " + additional;
}
return Instruction(opcode, operand);
}
Instruction get_sve_ls_instr(const string &base_opcode, int bits) {
return get_sve_ls_instr(base_opcode, bits, bits, "");
}
// Helper functor to add test case
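// Typical use (see check_arm_integer()):
//   AddTestFunctor add_8_16_32(*this, bits, instr_lanes, vf, bits != 64);
//   add_8_16_32(sel_op("vabd.s", "sabd"), absd(i_2, i_3));
// registers a test expecting the given opcode at the default bit width and lane
// count, but only when the enable condition (here bits != 64) holds.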
class AddTestFunctor {
public:
AddTestFunctor(SimdOpCheckArmSve &p,
int default_bits,
int default_instr_lanes,
int default_vec_factor,
bool is_enabled = true /* false to skip testing */)
: parent(p), default_bits(default_bits), default_instr_lanes(default_instr_lanes),
default_vec_factor(default_vec_factor), is_enabled(is_enabled) {};
AddTestFunctor(SimdOpCheckArmSve &p,
int default_bits,
// default_instr_lanes is inferred from bits and vec_factor
int default_vec_factor,
bool is_enabled = true /* false to skip testing */)
: parent(p), default_bits(default_bits),
default_instr_lanes(Instruction::get_instr_lanes(default_bits, default_vec_factor, p.target)),
default_vec_factor(default_vec_factor), is_enabled(is_enabled) {};
// Constructs a single Instruction with default parameters
void operator()(const string &opcode, Expr e) {
// Use opcode for name
(*this)(opcode, opcode, e);
}
// Constructs a single Instruction with default parameters except for a custom name
void operator()(const string &op_name, const string &opcode, Expr e) {
create_and_register(op_name, {Instruction{opcode, default_bits, default_instr_lanes}}, default_vec_factor, e);
}
// Constructs multiple Instructions with default parameters
void operator()(const vector<string> &opcodes, Expr e) {
assert(!opcodes.empty());
(*this)(opcodes[0], opcodes, e);
}
// Constructs multiple Instructions with default parameters except for a custom name
void operator()(const string &op_name, const vector<string> &opcodes, Expr e) {
vector<Instruction> instrs;
for (const auto &opcode : opcodes) {
instrs.emplace_back(opcode, default_bits, default_instr_lanes);
}
create_and_register(op_name, instrs, default_vec_factor, e);
}
// Sets single or multiple Instructions with custom parameters
void operator()(const vector<Instruction> &instructions, int vec_factor, Expr e) {
// Use the 1st opcode for name
assert(!instructions.empty());
string op_name = instructions[0].opcode;
(*this)(op_name, instructions, vec_factor, e);
}
// Sets single or multiple Instructions with custom parameters and a custom name
void operator()(const string &op_name, const vector<Instruction> &instructions, int vec_factor, Expr e) {
create_and_register(op_name, instructions, vec_factor, e);
}
private:
void create_and_register(const string &op_name, const vector<Instruction> &instructions, int vec_factor, Expr e) {
if (!is_enabled) return;
// Generate regular expression for the instruction we check
vector<string> instr_patterns;
transform(instructions.begin(), instructions.end(), back_inserter(instr_patterns),
[t = parent.target](const Instruction &instr) { return instr.generate_pattern(t); });
std::stringstream type_name_stream;
type_name_stream << e.type();
std::string decorated_op_name = op_name + "_" + type_name_stream.str() + "_x" + std::to_string(vec_factor);
auto unique_name = "op_" + decorated_op_name + "_" + std::to_string(parent.tasks.size());
// Bail out after generating the unique_name, so that names are
// unique across different processes and don't depend on filter
// settings.
if (!parent.wildcard_match(parent.filter, decorated_op_name)) return;
// Create a deep copy of the expr and all Funcs referenced by it, so
// that no IR is shared between tests. This is required by the base
// class, and is why we can parallelize.
{
using namespace Halide::Internal;
class FindOutputs : public IRVisitor {
using IRVisitor::visit;
void visit(const Call *op) override {
if (op->func.defined()) {
outputs.insert(op->func);
}
IRVisitor::visit(op);
}
public:
std::set<FunctionPtr> outputs;
} finder;
e.accept(&finder);
std::vector<Function> outputs(finder.outputs.begin(), finder.outputs.end());
auto env = deep_copy(outputs, build_environment(outputs)).second;
class DeepCopy : public IRMutator {
std::map<FunctionPtr, FunctionPtr> copied;
using IRMutator::visit;
Expr visit(const Call *op) override {
if (op->func.defined()) {
auto it = env.find(op->name);
if (it != env.end()) {
return Func(it->second)(mutate(op->args));
}
}
return IRMutator::visit(op);
}
const std::map<std::string, Function> &env;
public:
DeepCopy(const std::map<std::string, Function> &env)
: env(env) {
}
} copier(env);
e = copier.mutate(e);
}
// Create Task and register
parent.tasks.emplace_back(Task{decorated_op_name, unique_name, vec_factor, e});
parent.arm_tasks.emplace(unique_name, ArmTask{std::move(instr_patterns)});
}
SimdOpCheckArmSve &parent;
int default_bits;
int default_instr_lanes;
int default_vec_factor;
bool is_enabled;
};
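// Compiles the error-checking Func down to assembly and verifies that every
// instruction pattern registered for this op appears somewhere in the .s output.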
void compile_and_check(Func error, const string &op, const string &name, int vector_width, const std::vector<Argument> &arg_types, ostringstream &error_msg) override {
// This is necessary as LLVM validation errors, crashes, etc. don't tell which op crashed.
cout << "Starting op " << op << "\n";
string fn_name = "test_" + name;
string file_name = output_directory + fn_name;
auto ext = Internal::get_output_info(target);
std::map<OutputFileType, std::string> outputs = {
{OutputFileType::llvm_assembly, file_name + ext.at(OutputFileType::llvm_assembly).extension},
{OutputFileType::c_header, file_name + ext.at(OutputFileType::c_header).extension},
{OutputFileType::object, file_name + ext.at(OutputFileType::object).extension},
{OutputFileType::assembly, file_name + ".s"},
};
error.compile_to(outputs, arg_types, fn_name, target);
std::ifstream asm_file;
asm_file.open(file_name + ".s");
auto arm_task = arm_tasks.find(name);
assert(arm_task != arm_tasks.end());
std::ostringstream msg;
msg << op << " did not generate for target=" << target.to_string()
<< " vector_width=" << vector_width << ". Instead we got:\n";
string line;
vector<string> matched_lines;
vector<string> &patterns = arm_task->second.instrs;
while (getline(asm_file, line) && !patterns.empty()) {
msg << line << "\n";
auto pattern = patterns.begin();
while (pattern != patterns.end()) {
smatch match;
if (regex_search(line, match, regex(*pattern))) {
pattern = patterns.erase(pattern);
matched_lines.emplace_back(match[0]);
} else {
++pattern;
}
}
}
if (!patterns.empty()) {
error_msg << "Failed: " << msg.str() << "\n";
error_msg << "The following instruction patterns were not found:\n";
for (auto &p : patterns) {
error_msg << p << "\n";
}
} else if (debug_mode == "1") {
for (auto &l : matched_lines) {
error_msg << " " << setw(20) << name << ", vf=" << setw(2) << vector_width << ", ";
error_msg << l << endl;
}
}
}
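// Select the expected mnemonic for the target: the arm32 NEON mnemonic, the SVE
// mnemonic when SVE/SVE2 is enabled (three-argument overload), or the AArch64
// NEON mnemonic otherwise.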
inline const string &sel_op(const string &neon32, const string &neon64) {
return is_arm32() ? neon32 : neon64;
}
inline const string &sel_op(const string &neon32, const string &neon64, const string &sve) {
return is_arm32() ? neon32 :
target.has_feature(Target::SVE) || target.has_feature(Target::SVE2) ? sve :
neon64;
}
inline bool is_arm32() const {
return target.bits == 32;
};
inline bool has_neon() const {
return !target.has_feature(Target::NoNEON);
};
inline bool has_sve() const {
return target.has_feature(Target::SVE2);
};
bool is_float16_supported() const {
return (target.bits == 64) && target.has_feature(Target::ARMFp16);
}
bool can_run_the_code;
string debug_mode;
std::unordered_map<string, ArmTask> arm_tasks;
const Var x{"x"}, y{"y"};
};
} // namespace
int main(int argc, char **argv) {
if (Halide::Internal::get_llvm_version() < 190) {
std::cout << "[SKIP] simd_op_check_sve2 requires LLVM 19 or later.\n";
return 0;
}
return SimdOpCheckTest::main<SimdOpCheckArmSve>(
argc, argv,
{
// IMPORTANT:
// When adding new targets here, make sure to also update
// can_run_code in simd_op_check.h to include any new features used.
Target("arm-64-linux-sve2-no_neon-vector_bits_128"),
Target("arm-64-linux-sve2-no_neon-vector_bits_256"),
});
}