/*
* Copyright (c) 2020, Alliance for Open Media. All rights reserved.
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
*/
#include "av1/common/pred_common.h"
#include "av1/encoder/compound_type.h"
#include "av1/encoder/encoder_alloc.h"
#include "av1/encoder/model_rd.h"
#include "av1/encoder/motion_search_facade.h"
#include "av1/encoder/rdopt_utils.h"
#include "av1/encoder/reconinter_enc.h"
#include "av1/encoder/tx_search.h"
typedef int64_t (*pick_interinter_mask_type)(
const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
const uint8_t *const p0, const uint8_t *const p1,
const int16_t *const residual1, const int16_t *const diff10,
uint64_t *best_sse);
// Checks if the characteristics of the current compound search match a stored
// record
static inline int is_comp_rd_match(const AV1_COMP *const cpi,
const MACROBLOCK *const x,
const COMP_RD_STATS *st,
const MB_MODE_INFO *const mi,
int32_t *comp_rate, int64_t *comp_dist,
int32_t *comp_model_rate,
int64_t *comp_model_dist, int *comp_rs2) {
  // TODO(ranjit): Ensure that the compound type search always uses the regular
  // filter, and check whether the following check can be removed.
  // Check if the interp filter matches that of the previous case
if (st->filter.as_int != mi->interp_filters.as_int) return 0;
const MACROBLOCKD *const xd = &x->e_mbd;
// Match MV and reference indices
for (int i = 0; i < 2; ++i) {
if ((st->ref_frames[i] != mi->ref_frame[i]) ||
(st->mv[i].as_int != mi->mv[i].as_int)) {
return 0;
}
const WarpedMotionParams *const wm = &xd->global_motion[mi->ref_frame[i]];
if (is_global_mv_block(mi, wm->wmtype) != st->is_global[i]) return 0;
}
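  // reuse_data[] is indexed by COMPOUND_TYPE: once the MVs, reference frames
  // and filters match, stats for COMPOUND_AVERAGE and COMPOUND_DISTWTD can
  // always be reused, while wedge and diffwtd reuse additionally depends on
  // the speed features checked below.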
int reuse_data[COMPOUND_TYPES] = { 1, 1, 0, 0 };
  // For compound wedge, reuse data if the newmv search is disabled when NEWMV
  // is present, or if NEWMV is not present in either direction
if ((!have_newmv_in_inter_mode(mi->mode) &&
!have_newmv_in_inter_mode(st->mode)) ||
(cpi->sf.inter_sf.disable_interinter_wedge_newmv_search))
reuse_data[COMPOUND_WEDGE] = 1;
  // For compound diffwtd, reuse data if the fast search is enabled (no newmv
  // search when NEWMV is present) or if NEWMV is not present in either
  // direction
if (cpi->sf.inter_sf.enable_fast_compound_mode_search ||
(!have_newmv_in_inter_mode(mi->mode) &&
!have_newmv_in_inter_mode(st->mode)))
reuse_data[COMPOUND_DIFFWTD] = 1;
  // Copy the stored stats for the compound types that can be reused
for (int comp_type = COMPOUND_AVERAGE; comp_type < COMPOUND_TYPES;
comp_type++) {
if (reuse_data[comp_type]) {
comp_rate[comp_type] = st->rate[comp_type];
comp_dist[comp_type] = st->dist[comp_type];
comp_model_rate[comp_type] = st->model_rate[comp_type];
comp_model_dist[comp_type] = st->model_dist[comp_type];
comp_rs2[comp_type] = st->comp_rs2[comp_type];
}
}
return 1;
}
// Checks whether a similar compound type search case was evaluated earlier.
// If found, returns the relevant rd data.
static inline int find_comp_rd_in_stats(const AV1_COMP *const cpi,
const MACROBLOCK *x,
const MB_MODE_INFO *const mbmi,
int32_t *comp_rate, int64_t *comp_dist,
int32_t *comp_model_rate,
int64_t *comp_model_dist, int *comp_rs2,
int *match_index) {
for (int j = 0; j < x->comp_rd_stats_idx; ++j) {
if (is_comp_rd_match(cpi, x, &x->comp_rd_stats[j], mbmi, comp_rate,
comp_dist, comp_model_rate, comp_model_dist,
comp_rs2)) {
*match_index = j;
return 1;
}
}
  return 0;  // no matching record found
}
static inline bool enable_wedge_search(
MACROBLOCK *const x, const unsigned int disable_wedge_var_thresh) {
  // Enable wedge search if the source variance is above the threshold.
return x->source_variance > disable_wedge_var_thresh;
}
static inline bool enable_wedge_interinter_search(MACROBLOCK *const x,
const AV1_COMP *const cpi) {
return enable_wedge_search(
x, cpi->sf.inter_sf.disable_interinter_wedge_var_thresh) &&
cpi->oxcf.comp_type_cfg.enable_interinter_wedge;
}
static inline bool enable_wedge_interintra_search(MACROBLOCK *const x,
const AV1_COMP *const cpi) {
return enable_wedge_search(
x, cpi->sf.inter_sf.disable_interintra_wedge_var_thresh) &&
cpi->oxcf.comp_type_cfg.enable_interintra_wedge;
}
static int8_t estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
const BLOCK_SIZE bsize, const uint8_t *pred0,
int stride0, const uint8_t *pred1,
int stride1) {
static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
// 4X4
BLOCK_INVALID,
// 4X8, 8X4, 8X8
BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
// 8X16, 16X8, 16X16
BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
// 16X32, 32X16, 32X32
BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
// 32X64, 64X32, 64X64
BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
// 64x128, 128x64, 128x128
BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
// 4X16, 16X4, 8X32
BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
// 32X8, 16X64, 64X16
BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
};
const struct macroblock_plane *const p = &x->plane[0];
const uint8_t *src = p->src.buf;
int src_stride = p->src.stride;
const int bw = block_size_wide[bsize];
const int bh = block_size_high[bsize];
const int bw_by2 = bw >> 1;
const int bh_by2 = bh >> 1;
uint32_t esq[2][2];
int64_t tl, br;
const BLOCK_SIZE f_index = split_qtr[bsize];
assert(f_index != BLOCK_INVALID);
if (is_cur_buf_hbd(&x->e_mbd)) {
pred0 = CONVERT_TO_BYTEPTR(pred0);
pred1 = CONVERT_TO_BYTEPTR(pred1);
}
  // Residual variance computation over relevant quadrants in order to
// find TL + BR, TL = sum(1st,2nd,3rd) quadrants of (pred0 - pred1),
// BR = sum(2nd,3rd,4th) quadrants of (pred1 - pred0)
// The 2nd and 3rd quadrants cancel out in TL + BR
// Hence TL + BR = 1st quadrant of (pred0-pred1) + 4th of (pred1-pred0)
// TODO(nithya): Sign estimation assumes 45 degrees (1st and 4th quadrants)
// for all codebooks; experiment with other quadrant combinations for
// 0, 90 and 135 degrees also.
cpi->ppi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
cpi->ppi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
pred0 + bh_by2 * stride0 + bw_by2, stride0,
&esq[0][1]);
cpi->ppi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
  cpi->ppi->fn_ptr[f_index].vf(src + bh_by2 * src_stride + bw_by2, src_stride,
                               pred1 + bh_by2 * stride1 + bw_by2, stride1,
                               &esq[1][1]);
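  // esq[p][q] holds the SSE between the source and pred<p> over quadrant q
  // (q == 0: top-left, q == 1: bottom-right). tl > 0 means pred1 fits the
  // top-left quadrant better than pred0; br > 0 means pred0 fits the
  // bottom-right quadrant better. That combination is mapped to wedge sign 1.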
tl = ((int64_t)esq[0][0]) - ((int64_t)esq[1][0]);
br = ((int64_t)esq[1][1]) - ((int64_t)esq[0][1]);
return (tl + br > 0);
}
// Choose the best wedge index and sign
static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
const BLOCK_SIZE bsize, const uint8_t *const p0,
const int16_t *const residual1,
const int16_t *const diff10,
int8_t *const best_wedge_sign,
int8_t *const best_wedge_index, uint64_t *best_sse) {
const MACROBLOCKD *const xd = &x->e_mbd;
const struct buf_2d *const src = &x->plane[0].src;
const int bw = block_size_wide[bsize];
const int bh = block_size_high[bsize];
const int N = bw * bh;
assert(N >= 64);
int rate;
int64_t dist;
int64_t rd, best_rd = INT64_MAX;
int8_t wedge_index;
int8_t wedge_sign;
const int8_t wedge_types = get_wedge_types_lookup(bsize);
const uint8_t *mask;
uint64_t sse;
const int hbd = is_cur_buf_hbd(xd);
const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]); // src - pred0
#if CONFIG_AV1_HIGHBITDEPTH
if (hbd) {
aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
CONVERT_TO_BYTEPTR(p0), bw);
} else {
aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
}
#else
(void)hbd;
aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
#endif
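  // sign_limit is the difference between the residual energies of the two
  // single-reference predictions, scaled by (1 << WEDGE_WEIGHT_BITS) / 2.
  // av1_wedge_sign_from_residuals() uses it as the decision threshold when
  // picking the wedge sign for each candidate mask below.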
int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
(int64_t)aom_sum_squares_i16(residual1, N)) *
(1 << WEDGE_WEIGHT_BITS) / 2;
int16_t *ds = residual0;
av1_wedge_compute_delta_squares(ds, residual0, residual1, N);
for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);
wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);
mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
sse = ROUND_POWER_OF_TWO(sse, bd_round);
model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
&rate, &dist);
rate += x->mode_costs.wedge_idx_cost[bsize][wedge_index];
rd = RDCOST(x->rdmult, rate, dist);
if (rd < best_rd) {
*best_wedge_index = wedge_index;
*best_wedge_sign = wedge_sign;
best_rd = rd;
*best_sse = sse;
}
}
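  // Remove the signaling cost of the chosen wedge index from the returned rd;
  // callers re-add it as part of the overall compound mode rate (see
  // get_interinter_compound_mask_rate()).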
return best_rd -
RDCOST(x->rdmult,
x->mode_costs.wedge_idx_cost[bsize][*best_wedge_index], 0);
}
// Choose the best wedge index for the specified sign
static int64_t pick_wedge_fixed_sign(
const AV1_COMP *const cpi, const MACROBLOCK *const x,
const BLOCK_SIZE bsize, const int16_t *const residual1,
const int16_t *const diff10, const int8_t wedge_sign,
int8_t *const best_wedge_index, uint64_t *best_sse) {
const MACROBLOCKD *const xd = &x->e_mbd;
const int bw = block_size_wide[bsize];
const int bh = block_size_high[bsize];
const int N = bw * bh;
assert(N >= 64);
int rate;
int64_t dist;
int64_t rd, best_rd = INT64_MAX;
int8_t wedge_index;
const int8_t wedge_types = get_wedge_types_lookup(bsize);
const uint8_t *mask;
uint64_t sse;
const int hbd = is_cur_buf_hbd(xd);
const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
sse = ROUND_POWER_OF_TWO(sse, bd_round);
model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
&rate, &dist);
rate += x->mode_costs.wedge_idx_cost[bsize][wedge_index];
rd = RDCOST(x->rdmult, rate, dist);
if (rd < best_rd) {
*best_wedge_index = wedge_index;
best_rd = rd;
*best_sse = sse;
}
}
return best_rd -
RDCOST(x->rdmult,
x->mode_costs.wedge_idx_cost[bsize][*best_wedge_index], 0);
}
static int64_t pick_interinter_wedge(
const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
const uint8_t *const p0, const uint8_t *const p1,
const int16_t *const residual1, const int16_t *const diff10,
uint64_t *best_sse) {
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = xd->mi[0];
const int bw = block_size_wide[bsize];
int64_t rd;
int8_t wedge_index = -1;
int8_t wedge_sign = 0;
assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
assert(cpi->common.seq_params->enable_masked_compound);
if (cpi->sf.inter_sf.fast_wedge_sign_estimate) {
wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign,
&wedge_index, best_sse);
} else {
rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign,
&wedge_index, best_sse);
}
mbmi->interinter_comp.wedge_sign = wedge_sign;
mbmi->interinter_comp.wedge_index = wedge_index;
return rd;
}
static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
MACROBLOCK *const x, const BLOCK_SIZE bsize,
const uint8_t *const p0,
const uint8_t *const p1,
const int16_t *const residual1,
const int16_t *const diff10,
uint64_t *best_sse) {
MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = xd->mi[0];
const int bw = block_size_wide[bsize];
const int bh = block_size_high[bsize];
const int N = 1 << num_pels_log2_lookup[bsize];
int rate;
int64_t dist;
DIFFWTD_MASK_TYPE cur_mask_type;
int64_t best_rd = INT64_MAX;
DIFFWTD_MASK_TYPE best_mask_type = 0;
const int hbd = is_cur_buf_hbd(xd);
const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
uint8_t *tmp_mask[2] = { xd->seg_mask, seg_mask };
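  // The two mask types are built into separate buffers: tmp_mask[0] points at
  // xd->seg_mask and tmp_mask[1] at the local seg_mask array, so xd->seg_mask
  // only needs to be overwritten at the end if the inverse mask type wins.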
// try each mask type and its inverse
for (cur_mask_type = 0; cur_mask_type < DIFFWTD_MASK_TYPES; cur_mask_type++) {
// build mask and inverse
#if CONFIG_AV1_HIGHBITDEPTH
if (hbd)
av1_build_compound_diffwtd_mask_highbd(
tmp_mask[cur_mask_type], cur_mask_type, CONVERT_TO_BYTEPTR(p0), bw,
CONVERT_TO_BYTEPTR(p1), bw, bh, bw, xd->bd);
else
av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type,
p0, bw, p1, bw, bh, bw);
#else
(void)hbd;
av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type, p0,
bw, p1, bw, bh, bw);
#endif // CONFIG_AV1_HIGHBITDEPTH
// compute rd for mask
uint64_t sse = av1_wedge_sse_from_residuals(residual1, diff10,
tmp_mask[cur_mask_type], N);
sse = ROUND_POWER_OF_TWO(sse, bd_round);
model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
&rate, &dist);
const int64_t rd0 = RDCOST(x->rdmult, rate, dist);
if (rd0 < best_rd) {
best_mask_type = cur_mask_type;
best_rd = rd0;
*best_sse = sse;
}
}
mbmi->interinter_comp.mask_type = best_mask_type;
if (best_mask_type == DIFFWTD_38_INV) {
memcpy(xd->seg_mask, seg_mask, N * 2);
}
return best_rd;
}
static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
const MACROBLOCK *const x,
const BLOCK_SIZE bsize,
const uint8_t *const p0,
const uint8_t *const p1) {
const MACROBLOCKD *const xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = xd->mi[0];
assert(av1_is_wedge_used(bsize));
assert(cpi->common.seq_params->enable_interintra_compound);
const struct buf_2d *const src = &x->plane[0].src;
const int bw = block_size_wide[bsize];
const int bh = block_size_high[bsize];
DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]); // src - pred1
DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]); // pred1 - pred0
#if CONFIG_AV1_HIGHBITDEPTH
if (is_cur_buf_hbd(xd)) {
aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
CONVERT_TO_BYTEPTR(p1), bw);
aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
CONVERT_TO_BYTEPTR(p0), bw);
} else {
aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
}
#else
aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
#endif
int8_t wedge_index = -1;
uint64_t sse;
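  // The inter-intra wedge sign is fixed to 0 here; only the wedge index is
  // searched, and only interintra_wedge_index is stored in mbmi.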
int64_t rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0,
&wedge_index, &sse);
mbmi->interintra_wedge_index = wedge_index;
return rd;
}
static inline void get_inter_predictors_masked_compound(
MACROBLOCK *x, const BLOCK_SIZE bsize, uint8_t **preds0, uint8_t **preds1,
int16_t *residual1, int16_t *diff10, int *strides) {
MACROBLOCKD *xd = &x->e_mbd;
const int bw = block_size_wide[bsize];
const int bh = block_size_high[bsize];
// get inter predictors to use for masked compound modes
av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0, preds0,
strides);
av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1, preds1,
strides);
const struct buf_2d *const src = &x->plane[0].src;
#if CONFIG_AV1_HIGHBITDEPTH
if (is_cur_buf_hbd(xd)) {
aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
CONVERT_TO_BYTEPTR(*preds1), bw);
aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(*preds1),
bw, CONVERT_TO_BYTEPTR(*preds0), bw);
} else {
aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1,
bw);
aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
}
#else
aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1, bw);
aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
#endif
}
// Computes the rd cost for the given interintra mode and updates the best
// mode and rd so far if this mode is better
static inline void compute_best_interintra_mode(
const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
MACROBLOCK *const x, const int *const interintra_mode_cost,
const BUFFER_SET *orig_dst, uint8_t *intrapred, const uint8_t *tmp_buf,
INTERINTRA_MODE *best_interintra_mode, int64_t *best_interintra_rd,
INTERINTRA_MODE interintra_mode, BLOCK_SIZE bsize) {
const AV1_COMMON *const cm = &cpi->common;
int rate;
uint8_t skip_txfm_sb;
int64_t dist, skip_sse_sb;
const int bw = block_size_wide[bsize];
mbmi->interintra_mode = interintra_mode;
int rmode = interintra_mode_cost[interintra_mode];
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](cpi, bsize, x, xd, 0, 0, &rate, &dist,
&skip_txfm_sb, &skip_sse_sb, NULL,
NULL, NULL);
int64_t rd = RDCOST(x->rdmult, rate + rmode, dist);
if (rd < *best_interintra_rd) {
*best_interintra_rd = rd;
*best_interintra_mode = mbmi->interintra_mode;
}
}
static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
MACROBLOCK *x, int64_t ref_best_rd,
RD_STATS *rd_stats) {
MACROBLOCKD *const xd = &x->e_mbd;
if (ref_best_rd < 0) return INT64_MAX;
av1_subtract_plane(x, bs, 0);
const int64_t rd = av1_estimate_txfm_yrd(cpi, x, rd_stats, ref_best_rd, bs,
max_txsize_rect_lookup[bs]);
if (rd != INT64_MAX) {
const int skip_ctx = av1_get_skip_txfm_context(xd);
if (rd_stats->skip_txfm) {
const int s1 = x->mode_costs.skip_txfm_cost[skip_ctx][1];
rd_stats->rate = s1;
} else {
const int s0 = x->mode_costs.skip_txfm_cost[skip_ctx][0];
rd_stats->rate += s0;
}
}
return rd;
}
// Computes the rd_threshold for smooth interintra rd search.
static inline int64_t compute_rd_thresh(MACROBLOCK *const x,
int total_mode_rate,
int64_t ref_best_rd) {
const int64_t rd_thresh = get_rd_thresh_from_best_rd(
ref_best_rd, (1 << INTER_INTRA_RD_THRESH_SHIFT),
INTER_INTRA_RD_THRESH_SCALE);
const int64_t mode_rd = RDCOST(x->rdmult, total_mode_rate, 0);
return (rd_thresh - mode_rd);
}
// Computes the best wedge interintra mode
static inline int64_t compute_best_wedge_interintra(
const AV1_COMP *const cpi, MB_MODE_INFO *mbmi, MACROBLOCKD *xd,
MACROBLOCK *const x, const int *const interintra_mode_cost,
const BUFFER_SET *orig_dst, uint8_t *intrapred_, uint8_t *tmp_buf_,
int *best_mode, int *best_wedge_index, BLOCK_SIZE bsize) {
const AV1_COMMON *const cm = &cpi->common;
const int bw = block_size_wide[bsize];
int64_t best_interintra_rd_wedge = INT64_MAX;
int64_t best_total_rd = INT64_MAX;
uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
for (INTERINTRA_MODE mode = 0; mode < INTERINTRA_MODES; ++mode) {
mbmi->interintra_mode = mode;
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
int64_t rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
const int rate_overhead =
interintra_mode_cost[mode] +
x->mode_costs.wedge_idx_cost[bsize][mbmi->interintra_wedge_index];
const int64_t total_rd = rd + RDCOST(x->rdmult, rate_overhead, 0);
if (total_rd < best_total_rd) {
best_total_rd = total_rd;
best_interintra_rd_wedge = rd;
*best_mode = mbmi->interintra_mode;
*best_wedge_index = mbmi->interintra_wedge_index;
}
}
return best_interintra_rd_wedge;
}
static int handle_smooth_inter_intra_mode(
const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
MB_MODE_INFO *mbmi, int64_t ref_best_rd, int *rate_mv,
INTERINTRA_MODE *best_interintra_mode, int64_t *best_rd,
int *best_mode_rate, const BUFFER_SET *orig_dst, uint8_t *tmp_buf,
uint8_t *intrapred, HandleInterModeArgs *args) {
MACROBLOCKD *xd = &x->e_mbd;
const ModeCosts *mode_costs = &x->mode_costs;
const int *const interintra_mode_cost =
mode_costs->interintra_mode_cost[size_group_lookup[bsize]];
const AV1_COMMON *const cm = &cpi->common;
const int bw = block_size_wide[bsize];
mbmi->use_wedge_interintra = 0;
if (cpi->sf.inter_sf.reuse_inter_intra_mode == 0 ||
*best_interintra_mode == INTERINTRA_MODES) {
int64_t best_interintra_rd = INT64_MAX;
for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
++cur_mode) {
if ((!cpi->oxcf.intra_mode_cfg.enable_smooth_intra ||
cpi->sf.intra_sf.disable_smooth_intra) &&
cur_mode == II_SMOOTH_PRED)
continue;
compute_best_interintra_mode(
cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred, tmp_buf,
best_interintra_mode, &best_interintra_rd, cur_mode, bsize);
}
args->inter_intra_mode[mbmi->ref_frame[0]] = *best_interintra_mode;
}
assert(IMPLIES(!cpi->oxcf.comp_type_cfg.enable_smooth_interintra,
*best_interintra_mode != II_SMOOTH_PRED));
// Recompute prediction if required
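  // Rebuild the combined prediction for the selected mode so that the buffers
  // used by the rd estimate below correspond to *best_interintra_mode rather
  // than the last mode evaluated in the loop above.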
bool interintra_mode_reuse = cpi->sf.inter_sf.reuse_inter_intra_mode ||
*best_interintra_mode != INTERINTRA_MODES;
if (interintra_mode_reuse || *best_interintra_mode != INTERINTRA_MODES - 1) {
mbmi->interintra_mode = *best_interintra_mode;
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
}
// Compute rd cost for best smooth_interintra
RD_STATS rd_stats;
const int is_wedge_used = av1_is_wedge_used(bsize);
const int rmode =
interintra_mode_cost[*best_interintra_mode] +
(is_wedge_used ? mode_costs->wedge_interintra_cost[bsize][0] : 0);
const int total_mode_rate = rmode + *rate_mv;
const int64_t rd_thresh = compute_rd_thresh(x, total_mode_rate, ref_best_rd);
int64_t rd = estimate_yrd_for_sb(cpi, bsize, x, rd_thresh, &rd_stats);
if (rd != INT64_MAX) {
rd = RDCOST(x->rdmult, total_mode_rate + rd_stats.rate, rd_stats.dist);
} else {
return IGNORE_MODE;
}
*best_rd = rd;
*best_mode_rate = rmode;
// Return early if best rd not good enough
if (ref_best_rd < INT64_MAX &&
(*best_rd >> INTER_INTRA_RD_THRESH_SHIFT) * INTER_INTRA_RD_THRESH_SCALE >
ref_best_rd) {
return IGNORE_MODE;
}
return 0;
}
static int handle_wedge_inter_intra_mode(
const AV1_COMP *const cpi, MACROBLOCK *const x, BLOCK_SIZE bsize,
MB_MODE_INFO *mbmi, int *rate_mv, INTERINTRA_MODE *best_interintra_mode,
int64_t *best_rd, const BUFFER_SET *orig_dst, uint8_t *tmp_buf_,
uint8_t *tmp_buf, uint8_t *intrapred_, uint8_t *intrapred,
HandleInterModeArgs *args, int *tmp_rate_mv, int *rate_overhead,
int_mv *tmp_mv, int64_t best_rd_no_wedge) {
MACROBLOCKD *xd = &x->e_mbd;
const ModeCosts *mode_costs = &x->mode_costs;
const int *const interintra_mode_cost =
mode_costs->interintra_mode_cost[size_group_lookup[bsize]];
const AV1_COMMON *const cm = &cpi->common;
const int bw = block_size_wide[bsize];
const int try_smooth_interintra =
cpi->oxcf.comp_type_cfg.enable_smooth_interintra;
mbmi->use_wedge_interintra = 1;
if (!cpi->sf.inter_sf.fast_interintra_wedge_search) {
// Exhaustive search of all wedge and mode combinations.
int best_mode = 0;
int best_wedge_index = 0;
*best_rd = compute_best_wedge_interintra(
cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred_, tmp_buf_,
&best_mode, &best_wedge_index, bsize);
mbmi->interintra_mode = best_mode;
mbmi->interintra_wedge_index = best_wedge_index;
if (best_mode != INTERINTRA_MODES - 1) {
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
}
} else if (!try_smooth_interintra) {
if (*best_interintra_mode == INTERINTRA_MODES) {
mbmi->interintra_mode = INTERINTRA_MODES - 1;
*best_interintra_mode = INTERINTRA_MODES - 1;
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
// Pick wedge mask based on INTERINTRA_MODES - 1
*best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
// Find the best interintra mode for the chosen wedge mask
for (INTERINTRA_MODE cur_mode = 0; cur_mode < INTERINTRA_MODES;
++cur_mode) {
compute_best_interintra_mode(
cpi, mbmi, xd, x, interintra_mode_cost, orig_dst, intrapred,
tmp_buf, best_interintra_mode, best_rd, cur_mode, bsize);
}
args->inter_intra_mode[mbmi->ref_frame[0]] = *best_interintra_mode;
mbmi->interintra_mode = *best_interintra_mode;
// Recompute prediction if required
if (*best_interintra_mode != INTERINTRA_MODES - 1) {
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
}
} else {
// Pick wedge mask for the best interintra mode (reused)
mbmi->interintra_mode = *best_interintra_mode;
av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
intrapred, bw);
*best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
}
} else {
// Pick wedge mask for the best interintra mode from smooth_interintra
*best_rd = pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);
}
*rate_overhead =
interintra_mode_cost[mbmi->interintra_mode] +
mode_costs->wedge_idx_cost[bsize][mbmi->interintra_wedge_index] +
mode_costs->wedge_interintra_cost[bsize][1];
*best_rd += RDCOST(x->rdmult, *rate_overhead + *rate_mv, 0);
int64_t rd = INT64_MAX;
const int_mv mv0 = mbmi->mv[0];
// Refine motion vector for NEWMV case.
if (have_newmv_in_inter_mode(mbmi->mode)) {
int rate_sum;
uint8_t skip_txfm_sb;
int64_t dist_sum, skip_sse_sb;
    // Get the inverse (sign-flipped) wedge mask
const uint8_t *mask =
av1_get_contiguous_soft_mask(mbmi->interintra_wedge_index, 1, bsize);
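    // The masked motion search below refines the inter MV using this inverse
    // wedge mask, so the refinement is weighted toward the region of the block
    // that the inter predictor covers in the final blend.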
av1_compound_single_motion_search(cpi, x, bsize, &tmp_mv->as_mv, intrapred,
mask, bw, tmp_rate_mv, 0);
if (mbmi->mv[0].as_int != tmp_mv->as_int) {
mbmi->mv[0].as_int = tmp_mv->as_int;
// Set ref_frame[1] to NONE_FRAME temporarily so that the intra
// predictor is not calculated again in av1_enc_build_inter_predictor().
mbmi->ref_frame[1] = NONE_FRAME;
const int mi_row = xd->mi_row;
const int mi_col = xd->mi_col;
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
mbmi->ref_frame[1] = INTRA_FRAME;
av1_combine_interintra(xd, bsize, 0, xd->plane[AOM_PLANE_Y].dst.buf,
xd->plane[AOM_PLANE_Y].dst.stride, intrapred, bw);
model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &skip_txfm_sb,
&skip_sse_sb, NULL, NULL, NULL);
rd =
RDCOST(x->rdmult, *tmp_rate_mv + *rate_overhead + rate_sum, dist_sum);
}
}
if (rd >= *best_rd) {
tmp_mv->as_int = mv0.as_int;
*tmp_rate_mv = *rate_mv;
av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
}
// Evaluate closer to true rd
RD_STATS rd_stats;
const int64_t mode_rd = RDCOST(x->rdmult, *rate_overhead + *tmp_rate_mv, 0);
const int64_t tmp_rd_thresh = best_rd_no_wedge - mode_rd;
rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
if (rd != INT64_MAX) {
rd = RDCOST(x->rdmult, *rate_overhead + *tmp_rate_mv + rd_stats.rate,
rd_stats.dist);
} else {
if (*best_rd == INT64_MAX) return IGNORE_MODE;
}
*best_rd = rd;
return 0;
}
int av1_handle_inter_intra_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
BLOCK_SIZE bsize, MB_MODE_INFO *mbmi,
HandleInterModeArgs *args, int64_t ref_best_rd,
int *rate_mv, int *tmp_rate2,
const BUFFER_SET *orig_dst) {
const int try_smooth_interintra =
cpi->oxcf.comp_type_cfg.enable_smooth_interintra;
const int is_wedge_used = av1_is_wedge_used(bsize);
const int try_wedge_interintra =
is_wedge_used && enable_wedge_interintra_search(x, cpi);
const AV1_COMMON *const cm = &cpi->common;
MACROBLOCKD *xd = &x->e_mbd;
const int bw = block_size_wide[bsize];
DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
const int mi_row = xd->mi_row;
const int mi_col = xd->mi_col;
// Single reference inter prediction
mbmi->ref_frame[1] = NONE_FRAME;
xd->plane[0].dst.buf = tmp_buf;
xd->plane[0].dst.stride = bw;
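  // The luma inter prediction is rendered into tmp_buf (dst was redirected
  // above) so that it can later be blended with the intra predictor by
  // av1_combine_interintra().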
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, NULL, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
const int num_planes = av1_num_planes(cm);
// Restore the buffers for intra prediction
restore_dst_buf(xd, *orig_dst, num_planes);
mbmi->ref_frame[1] = INTRA_FRAME;
INTERINTRA_MODE best_interintra_mode =
args->inter_intra_mode[mbmi->ref_frame[0]];
// Compute smooth_interintra
int64_t best_interintra_rd_nowedge = INT64_MAX;
int best_mode_rate = INT_MAX;
if (try_smooth_interintra) {
int ret = handle_smooth_inter_intra_mode(
cpi, x, bsize, mbmi, ref_best_rd, rate_mv, &best_interintra_mode,
&best_interintra_rd_nowedge, &best_mode_rate, orig_dst, tmp_buf,
intrapred, args);
if (ret == IGNORE_MODE) {
return IGNORE_MODE;
}
}
// Compute wedge interintra
int64_t best_interintra_rd_wedge = INT64_MAX;
const int_mv mv0 = mbmi->mv[0];
int_mv tmp_mv = mv0;
int tmp_rate_mv = 0;
int rate_overhead = 0;
if (try_wedge_interintra) {
int ret = handle_wedge_inter_intra_mode(
cpi, x, bsize, mbmi, rate_mv, &best_interintra_mode,
&best_interintra_rd_wedge, orig_dst, tmp_buf_, tmp_buf, intrapred_,
intrapred, args, &tmp_rate_mv, &rate_overhead, &tmp_mv,
best_interintra_rd_nowedge);
if (ret == IGNORE_MODE) {
return IGNORE_MODE;
}
}
if (best_interintra_rd_nowedge == INT64_MAX &&
best_interintra_rd_wedge == INT64_MAX) {
return IGNORE_MODE;
}
if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
mbmi->mv[0].as_int = tmp_mv.as_int;
*tmp_rate2 += tmp_rate_mv - *rate_mv;
*rate_mv = tmp_rate_mv;
best_mode_rate = rate_overhead;
} else if (try_smooth_interintra && try_wedge_interintra) {
    // If smooth was best but its state was overwritten while evaluating the
    // wedge mode, recompute the smooth-interintra values.
mbmi->use_wedge_interintra = 0;
mbmi->interintra_mode = best_interintra_mode;
mbmi->mv[0].as_int = mv0.as_int;
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
}
*tmp_rate2 += best_mode_rate;
if (num_planes > 1) {
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_U, num_planes - 1);
}
return 0;
}
// Computes the valid compound_types to be evaluated
static inline int compute_valid_comp_types(MACROBLOCK *x,
const AV1_COMP *const cpi,
BLOCK_SIZE bsize,
int masked_compound_used,
int mode_search_mask,
COMPOUND_TYPE *valid_comp_types) {
const AV1_COMMON *cm = &cpi->common;
int valid_type_count = 0;
int comp_type, valid_check;
int8_t enable_masked_type[MASKED_COMPOUND_TYPES] = { 0, 0 };
const int try_average_comp = (mode_search_mask & (1 << COMPOUND_AVERAGE));
const int try_distwtd_comp =
((mode_search_mask & (1 << COMPOUND_DISTWTD)) &&
cm->seq_params->order_hint_info.enable_dist_wtd_comp == 1 &&
cpi->sf.inter_sf.use_dist_wtd_comp_flag != DIST_WTD_COMP_DISABLED);
// Check if COMPOUND_AVERAGE and COMPOUND_DISTWTD are valid cases
for (comp_type = COMPOUND_AVERAGE; comp_type <= COMPOUND_DISTWTD;
comp_type++) {
valid_check =
(comp_type == COMPOUND_AVERAGE) ? try_average_comp : try_distwtd_comp;
if (valid_check && is_interinter_compound_used(comp_type, bsize))
valid_comp_types[valid_type_count++] = comp_type;
}
// Check if COMPOUND_WEDGE and COMPOUND_DIFFWTD are valid cases
if (masked_compound_used) {
// enable_masked_type[0] corresponds to COMPOUND_WEDGE
// enable_masked_type[1] corresponds to COMPOUND_DIFFWTD
enable_masked_type[0] = enable_wedge_interinter_search(x, cpi);
enable_masked_type[1] = cpi->oxcf.comp_type_cfg.enable_diff_wtd_comp;
for (comp_type = COMPOUND_WEDGE; comp_type <= COMPOUND_DIFFWTD;
comp_type++) {
if ((mode_search_mask & (1 << comp_type)) &&
is_interinter_compound_used(comp_type, bsize) &&
enable_masked_type[comp_type - COMPOUND_WEDGE])
valid_comp_types[valid_type_count++] = comp_type;
}
}
return valid_type_count;
}
// Calculates the mode signaling cost for each compound type
static inline void calc_masked_type_cost(
const ModeCosts *mode_costs, BLOCK_SIZE bsize, int comp_group_idx_ctx,
int comp_index_ctx, int masked_compound_used, int *masked_type_cost) {
av1_zero_array(masked_type_cost, COMPOUND_TYPES);
// Account for group index cost when wedge and/or diffwtd prediction are
// enabled
if (masked_compound_used) {
// Compound group index of average and distwtd is 0
// Compound group index of wedge and diffwtd is 1
masked_type_cost[COMPOUND_AVERAGE] +=
mode_costs->comp_group_idx_cost[comp_group_idx_ctx][0];
masked_type_cost[COMPOUND_DISTWTD] += masked_type_cost[COMPOUND_AVERAGE];
masked_type_cost[COMPOUND_WEDGE] +=
mode_costs->comp_group_idx_cost[comp_group_idx_ctx][1];
masked_type_cost[COMPOUND_DIFFWTD] += masked_type_cost[COMPOUND_WEDGE];
}
// Compute the cost to signal compound index/type
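  // The comp_idx bit distinguishes COMPOUND_AVERAGE (compound_idx == 1) from
  // COMPOUND_DISTWTD (compound_idx == 0), while the compound_type bit
  // distinguishes COMPOUND_WEDGE from COMPOUND_DIFFWTD; see
  // update_mbmi_for_compound_type().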
masked_type_cost[COMPOUND_AVERAGE] +=
mode_costs->comp_idx_cost[comp_index_ctx][1];
masked_type_cost[COMPOUND_DISTWTD] +=
mode_costs->comp_idx_cost[comp_index_ctx][0];
masked_type_cost[COMPOUND_WEDGE] += mode_costs->compound_type_cost[bsize][0];
masked_type_cost[COMPOUND_DIFFWTD] +=
mode_costs->compound_type_cost[bsize][1];
}
// Updates mbmi structure with the relevant compound type info
static inline void update_mbmi_for_compound_type(MB_MODE_INFO *mbmi,
COMPOUND_TYPE cur_type) {
mbmi->interinter_comp.type = cur_type;
mbmi->comp_group_idx = (cur_type >= COMPOUND_WEDGE);
mbmi->compound_idx = (cur_type != COMPOUND_DISTWTD);
}
// When a match is found, populate the compound type data, calculate the rd
// cost using the stored stats, and update the mbmi appropriately.
static inline int populate_reuse_comp_type_data(
const MACROBLOCK *x, MB_MODE_INFO *mbmi,
BEST_COMP_TYPE_STATS *best_type_stats, int_mv *cur_mv, int32_t *comp_rate,
int64_t *comp_dist, int *comp_rs2, int *rate_mv, int64_t *rd,
int match_index) {
const int winner_comp_type =
x->comp_rd_stats[match_index].interinter_comp.type;
if (comp_rate[winner_comp_type] == INT_MAX)
return best_type_stats->best_compmode_interinter_cost;
update_mbmi_for_compound_type(mbmi, winner_comp_type);
mbmi->interinter_comp = x->comp_rd_stats[match_index].interinter_comp;
*rd = RDCOST(
x->rdmult,
comp_rs2[winner_comp_type] + *rate_mv + comp_rate[winner_comp_type],
comp_dist[winner_comp_type]);
mbmi->mv[0].as_int = cur_mv[0].as_int;
mbmi->mv[1].as_int = cur_mv[1].as_int;
return comp_rs2[winner_comp_type];
}
// Updates rd cost and relevant compound type data for the best compound type
static inline void update_best_info(const MB_MODE_INFO *const mbmi, int64_t *rd,
BEST_COMP_TYPE_STATS *best_type_stats,
int64_t best_rd_cur,
int64_t comp_model_rd_cur, int rs2) {
*rd = best_rd_cur;
best_type_stats->comp_best_model_rd = comp_model_rd_cur;
best_type_stats->best_compound_data = mbmi->interinter_comp;
best_type_stats->best_compmode_interinter_cost = rs2;
}
// Updates best_mv for masked compound types
static inline void update_mask_best_mv(const MB_MODE_INFO *const mbmi,
int_mv *best_mv, int *best_tmp_rate_mv,
int tmp_rate_mv) {
*best_tmp_rate_mv = tmp_rate_mv;
best_mv[0].as_int = mbmi->mv[0].as_int;
best_mv[1].as_int = mbmi->mv[1].as_int;
}
static inline void save_comp_rd_search_stat(
MACROBLOCK *x, const MB_MODE_INFO *const mbmi, const int32_t *comp_rate,
const int64_t *comp_dist, const int32_t *comp_model_rate,
const int64_t *comp_model_dist, const int_mv *cur_mv, const int *comp_rs2) {
const int offset = x->comp_rd_stats_idx;
if (offset < MAX_COMP_RD_STATS) {
COMP_RD_STATS *const rd_stats = x->comp_rd_stats + offset;
memcpy(rd_stats->rate, comp_rate, sizeof(rd_stats->rate));
memcpy(rd_stats->dist, comp_dist, sizeof(rd_stats->dist));
memcpy(rd_stats->model_rate, comp_model_rate, sizeof(rd_stats->model_rate));
memcpy(rd_stats->model_dist, comp_model_dist, sizeof(rd_stats->model_dist));
memcpy(rd_stats->comp_rs2, comp_rs2, sizeof(rd_stats->comp_rs2));
memcpy(rd_stats->mv, cur_mv, sizeof(rd_stats->mv));
memcpy(rd_stats->ref_frames, mbmi->ref_frame, sizeof(rd_stats->ref_frames));
rd_stats->mode = mbmi->mode;
rd_stats->filter = mbmi->interp_filters;
rd_stats->ref_mv_idx = mbmi->ref_mv_idx;
const MACROBLOCKD *const xd = &x->e_mbd;
for (int i = 0; i < 2; ++i) {
const WarpedMotionParams *const wm =
&xd->global_motion[mbmi->ref_frame[i]];
rd_stats->is_global[i] = is_global_mv_block(mbmi, wm->wmtype);
}
rd_stats->interinter_comp = mbmi->interinter_comp;
++x->comp_rd_stats_idx;
}
}
static inline int get_interinter_compound_mask_rate(
const ModeCosts *const mode_costs, const MB_MODE_INFO *const mbmi) {
const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
// This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
if (compound_type == COMPOUND_WEDGE) {
return av1_is_wedge_used(mbmi->bsize)
? av1_cost_literal(1) +
mode_costs
->wedge_idx_cost[mbmi->bsize]
[mbmi->interinter_comp.wedge_index]
: 0;
} else {
assert(compound_type == COMPOUND_DIFFWTD);
return av1_cost_literal(1);
}
}
// Takes a backup of rate, distortion and model_rd for future reuse
static inline void backup_stats(COMPOUND_TYPE cur_type, int32_t *comp_rate,
int64_t *comp_dist, int32_t *comp_model_rate,
int64_t *comp_model_dist, int rate_sum,
int64_t dist_sum, RD_STATS *rd_stats,
int *comp_rs2, int rs2) {
comp_rate[cur_type] = rd_stats->rate;
comp_dist[cur_type] = rd_stats->dist;
comp_model_rate[cur_type] = rate_sum;
comp_model_dist[cur_type] = dist_sum;
comp_rs2[cur_type] = rs2;
}
static inline int save_mask_search_results(const PREDICTION_MODE this_mode,
const int reuse_level) {
  return reuse_level || (this_mode == NEW_NEWMV);
}
static inline int prune_mode_by_skip_rd(const AV1_COMP *const cpi,
MACROBLOCK *x, MACROBLOCKD *xd,
const BLOCK_SIZE bsize,
int64_t ref_skip_rd, int mode_rate) {
int eval_txfm = 1;
const int txfm_rd_gate_level =
get_txfm_rd_gate_level(cpi->common.seq_params->enable_masked_compound,
cpi->sf.inter_sf.txfm_rd_gate_level, bsize,
TX_SEARCH_COMP_TYPE_MODE, /*eval_motion_mode=*/0);
// Check if the mode is good enough based on skip rd
if (txfm_rd_gate_level) {
int64_t sse_y = compute_sse_plane(x, xd, PLANE_TYPE_Y, bsize);
int64_t skip_rd = RDCOST(x->rdmult, mode_rate, (sse_y << 4));
eval_txfm =
check_txfm_eval(x, bsize, ref_skip_rd, skip_rd, txfm_rd_gate_level, 1);
}
return eval_txfm;
}
static int64_t masked_compound_type_rd(
const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
int rate_mv, const BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
int mode_rate, int64_t rd_thresh, int *calc_pred_masked_compound,
int32_t *comp_rate, int64_t *comp_dist, int32_t *comp_model_rate,
int64_t *comp_model_dist, const int64_t comp_best_model_rd,
int64_t *const comp_model_rd_cur, int *comp_rs2, int64_t ref_skip_rd) {
const AV1_COMMON *const cm = &cpi->common;
MACROBLOCKD *xd = &x->e_mbd;
MB_MODE_INFO *const mbmi = xd->mi[0];
int64_t best_rd_cur = INT64_MAX;
int64_t rd = INT64_MAX;
const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;
// This function will be called only for COMPOUND_WEDGE and COMPOUND_DIFFWTD
assert(compound_type == COMPOUND_WEDGE || compound_type == COMPOUND_DIFFWTD);
int rate_sum;
uint8_t tmp_skip_txfm_sb;
int64_t dist_sum, tmp_skip_sse_sb;
pick_interinter_mask_type pick_interinter_mask[2] = { pick_interinter_wedge,
pick_interinter_seg };
// TODO(any): Save pred and mask calculation as well into records. However
// this may increase memory requirements as compound segment mask needs to be
// stored in each record.
if (*calc_pred_masked_compound) {
get_inter_predictors_masked_compound(x, bsize, preds0, preds1, residual1,
diff10, strides);
*calc_pred_masked_compound = 0;
}
if (compound_type == COMPOUND_WEDGE) {
unsigned int sse;
if (is_cur_buf_hbd(xd))
(void)cpi->ppi->fn_ptr[bsize].vf(CONVERT_TO_BYTEPTR(*preds0), *strides,
CONVERT_TO_BYTEPTR(*preds1), *strides,
&sse);
else
(void)cpi->ppi->fn_ptr[bsize].vf(*preds0, *strides, *preds1, *strides,
&sse);
const unsigned int mse =
ROUND_POWER_OF_TWO(sse, num_pels_log2_lookup[bsize]);
// If two predictors are very similar, skip wedge compound mode search
if (mse < 8 || (!have_newmv_in_inter_mode(this_mode) && mse < 64)) {
*comp_model_rd_cur = INT64_MAX;
return INT64_MAX;
}
}
// Function pointer to pick the appropriate mask
// compound_type == COMPOUND_WEDGE, calls pick_interinter_wedge()
// compound_type == COMPOUND_DIFFWTD, calls pick_interinter_seg()
uint64_t cur_sse = UINT64_MAX;
best_rd_cur = pick_interinter_mask[compound_type - COMPOUND_WEDGE](
cpi, x, bsize, *preds0, *preds1, residual1, diff10, &cur_sse);
*rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);
assert(cur_sse != UINT64_MAX);
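  // The SSE is scaled by 16 (<< 4) so that this skip-rd estimate is in the
  // same distortion units as the rd values it is compared against.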
int64_t skip_rd_cur = RDCOST(x->rdmult, *rs2 + rate_mv, (cur_sse << 4));
  // The true rate_mv may differ after the motion search, but this mode is
  // unlikely to be the best one once the transform rd cost and other mode
  // overhead costs are considered
int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
if (mode_rd > rd_thresh) {
*comp_model_rd_cur = INT64_MAX;
return INT64_MAX;
}
// Check if the mode is good enough based on skip rd
// TODO(nithya): Handle wedge_newmv_search if extending for lower speed
// setting
const int txfm_rd_gate_level =
get_txfm_rd_gate_level(cm->seq_params->enable_masked_compound,
cpi->sf.inter_sf.txfm_rd_gate_level, bsize,
TX_SEARCH_COMP_TYPE_MODE, /*eval_motion_mode=*/0);
if (txfm_rd_gate_level) {
int eval_txfm = check_txfm_eval(x, bsize, ref_skip_rd, skip_rd_cur,
txfm_rd_gate_level, 1);
if (!eval_txfm) {
*comp_model_rd_cur = INT64_MAX;
return INT64_MAX;
}
}
  // Compute the cost if no matching record was found; otherwise reuse the
  // stored data
if (comp_rate[compound_type] == INT_MAX) {
// Check whether new MV search for wedge is to be done
int wedge_newmv_search =
have_newmv_in_inter_mode(this_mode) &&
(compound_type == COMPOUND_WEDGE) &&
(!cpi->sf.inter_sf.disable_interinter_wedge_newmv_search);
// Search for new MV if needed and build predictor
if (wedge_newmv_search) {
*out_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
bsize, this_mode);
const int mi_row = xd->mi_row;
const int mi_col = xd->mi_col;
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, ctx, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
} else {
*out_rate_mv = rate_mv;
av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
preds1, strides);
}
// Get the RD cost from model RD
model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, &tmp_skip_txfm_sb,
&tmp_skip_sse_sb, NULL, NULL, NULL);
rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
*comp_model_rd_cur = rd;
    // If the refined MV is not better, revert to the original MV and predictor
if (wedge_newmv_search) {
if (rd >= best_rd_cur) {
mbmi->mv[0].as_int = cur_mv[0].as_int;
mbmi->mv[1].as_int = cur_mv[1].as_int;
*out_rate_mv = rate_mv;
av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
strides, preds1, strides);
*comp_model_rd_cur = best_rd_cur;
}
}
if (cpi->sf.inter_sf.prune_comp_type_by_model_rd &&
(*comp_model_rd_cur > comp_best_model_rd) &&
comp_best_model_rd != INT64_MAX) {
*comp_model_rd_cur = INT64_MAX;
return INT64_MAX;
}
// Compute RD cost for the current type
RD_STATS rd_stats;
const int64_t tmp_mode_rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv, 0);
const int64_t tmp_rd_thresh = rd_thresh - tmp_mode_rd;
rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh, &rd_stats);
if (rd != INT64_MAX) {
rd =
RDCOST(x->rdmult, *rs2 + *out_rate_mv + rd_stats.rate, rd_stats.dist);
// Backup rate and distortion for future reuse
backup_stats(compound_type, comp_rate, comp_dist, comp_model_rate,
comp_model_dist, rate_sum, dist_sum, &rd_stats, comp_rs2,
*rs2);
}
} else {
// Reuse data as matching record is found
assert(comp_dist[compound_type] != INT64_MAX);
// When disable_interinter_wedge_newmv_search is set, motion refinement is
// disabled. Hence rate and distortion can be reused in this case as well
assert(IMPLIES((have_newmv_in_inter_mode(this_mode) &&
(compound_type == COMPOUND_WEDGE)),
cpi->sf.inter_sf.disable_interinter_wedge_newmv_search));
assert(mbmi->mv[0].as_int == cur_mv[0].as_int);
assert(mbmi->mv[1].as_int == cur_mv[1].as_int);
*out_rate_mv = rate_mv;
// Calculate RD cost based on stored stats
rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_rate[compound_type],
comp_dist[compound_type]);
// Recalculate model rdcost with the updated rate
*comp_model_rd_cur =
RDCOST(x->rdmult, *rs2 + *out_rate_mv + comp_model_rate[compound_type],
comp_model_dist[compound_type]);
}
return rd;
}
// Scaling values used for gating the wedge/compound segment search based on
// the best approximate rd
static const int comp_type_rd_threshold_mul[3] = { 1, 11, 12 };
static const int comp_type_rd_threshold_div[3] = { 3, 16, 16 };
int av1_compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x,
HandleInterModeArgs *args, BLOCK_SIZE bsize,
int_mv *cur_mv, int mode_search_mask,
int masked_compound_used, const BUFFER_SET *orig_dst,
const BUFFER_SET *tmp_dst,
const CompoundTypeRdBuffers *buffers, int *rate_mv,
int64_t *rd, RD_STATS *rd_stats, int64_t ref_best_rd,
int64_t ref_skip_rd, int *is_luma_interp_done,
int64_t rd_thresh) {
const AV1_COMMON *cm = &cpi->common;
MACROBLOCKD *xd = &x->e_mbd;
MB_MODE_INFO *mbmi = xd->mi[0];
const PREDICTION_MODE this_mode = mbmi->mode;
int ref_frame = av1_ref_frame_type(mbmi->ref_frame);
const int bw = block_size_wide[bsize];
int rs2;
int_mv best_mv[2];
int best_tmp_rate_mv = *rate_mv;
BEST_COMP_TYPE_STATS best_type_stats;
// Initializing BEST_COMP_TYPE_STATS
best_type_stats.best_compound_data.type = COMPOUND_AVERAGE;
best_type_stats.best_compmode_interinter_cost = 0;
best_type_stats.comp_best_model_rd = INT64_MAX;
uint8_t *preds0[1] = { buffers->pred0 };
uint8_t *preds1[1] = { buffers->pred1 };
int strides[1] = { bw };
int tmp_rate_mv;
COMPOUND_TYPE cur_type;
// Local array to store the mask cost for different compound types
int masked_type_cost[COMPOUND_TYPES];
int calc_pred_masked_compound = 1;
int64_t comp_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
INT64_MAX };
int32_t comp_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
int comp_rs2[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX, INT_MAX };
int32_t comp_model_rate[COMPOUND_TYPES] = { INT_MAX, INT_MAX, INT_MAX,
INT_MAX };
int64_t comp_model_dist[COMPOUND_TYPES] = { INT64_MAX, INT64_MAX, INT64_MAX,
INT64_MAX };
int match_index = 0;
const int match_found =
find_comp_rd_in_stats(cpi, x, mbmi, comp_rate, comp_dist, comp_model_rate,
comp_model_dist, comp_rs2, &match_index);
best_mv[0].as_int = cur_mv[0].as_int;
best_mv[1].as_int = cur_mv[1].as_int;
*rd = INT64_MAX;
// Local array to store the valid compound types to be evaluated in the core
// loop
COMPOUND_TYPE valid_comp_types[COMPOUND_TYPES] = {
COMPOUND_AVERAGE, COMPOUND_DISTWTD, COMPOUND_WEDGE, COMPOUND_DIFFWTD
};
int valid_type_count = 0;
  // compute_valid_comp_types() returns the number of valid compound types to
  // be evaluated and populates them in the local array valid_comp_types[].
valid_type_count = compute_valid_comp_types(
x, cpi, bsize, masked_compound_used, mode_search_mask, valid_comp_types);
// The following context indices are independent of compound type
const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
const int comp_index_ctx = get_comp_index_context(cm, xd);
// Populates masked_type_cost local array for the 4 compound types
calc_masked_type_cost(&x->mode_costs, bsize, comp_group_idx_ctx,
comp_index_ctx, masked_compound_used, masked_type_cost);
int64_t comp_model_rd_cur = INT64_MAX;
int64_t best_rd_cur = ref_best_rd;
const int mi_row = xd->mi_row;
const int mi_col = xd->mi_col;
  // If a match is found, calculate the rd cost using the stored stats and
  // update the mbmi appropriately.
if (match_found && cpi->sf.inter_sf.reuse_compound_type_decision) {
return populate_reuse_comp_type_data(x, mbmi, &best_type_stats, cur_mv,
comp_rate, comp_dist, comp_rs2,
rate_mv, rd, match_index);
}
// If COMPOUND_AVERAGE is not valid, use the spare buffer
if (valid_comp_types[0] != COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
// Loop over valid compound types
for (int i = 0; i < valid_type_count; i++) {
cur_type = valid_comp_types[i];
if (args->cmp_mode[ref_frame] == COMPOUND_AVERAGE) {
if (cur_type == COMPOUND_WEDGE) continue;
}
comp_model_rd_cur = INT64_MAX;
tmp_rate_mv = *rate_mv;
best_rd_cur = INT64_MAX;
ref_best_rd = AOMMIN(ref_best_rd, *rd);
update_mbmi_for_compound_type(mbmi, cur_type);
rs2 = masked_type_cost[cur_type];
int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
if (mode_rd >= ref_best_rd) continue;
    // Derive the flags that indicate whether the MV refinement process can be
    // skipped.
const int enable_fast_compound_mode_search =
cpi->sf.inter_sf.enable_fast_compound_mode_search;
const bool skip_mv_refinement_for_avg_distwtd =
enable_fast_compound_mode_search == 3 ||
(enable_fast_compound_mode_search == 2 && (this_mode != NEW_NEWMV));
const bool skip_mv_refinement_for_diffwtd =
(!enable_fast_compound_mode_search && cur_type == COMPOUND_DIFFWTD);
// Case COMPOUND_AVERAGE and COMPOUND_DISTWTD
if (cur_type < COMPOUND_WEDGE) {
if (skip_mv_refinement_for_avg_distwtd) {
int rate_sum;
uint8_t tmp_skip_txfm_sb;
int64_t dist_sum, tmp_skip_sse_sb;
        // Compute the stats if no matching record was found; otherwise reuse
        // them
if (comp_rate[cur_type] == INT_MAX) {
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
// Compute RD cost for the current type
RD_STATS est_rd_stats;
const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh) - mode_rd;
int64_t est_rd = INT64_MAX;
int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
rs2 + *rate_mv);
// Evaluate further if skip rd is low enough
if (eval_txfm) {
est_rd = estimate_yrd_for_sb(cpi, bsize, x, tmp_rd_thresh,
&est_rd_stats);
}
if (est_rd != INT64_MAX) {
best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + est_rd_stats.rate,
est_rd_stats.dist);
model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum,
&tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
comp_model_rd_cur =
RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
// Backup rate and distortion for future reuse
backup_stats(cur_type, comp_rate, comp_dist, comp_model_rate,
comp_model_dist, rate_sum, dist_sum, &est_rd_stats,
comp_rs2, rs2);
}
} else {
// Calculate RD cost based on stored stats
assert(comp_dist[cur_type] != INT64_MAX);
best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + comp_rate[cur_type],
comp_dist[cur_type]);
// Recalculate model rdcost with the updated rate
comp_model_rd_cur =
RDCOST(x->rdmult, rs2 + *rate_mv + comp_model_rate[cur_type],
comp_model_dist[cur_type]);
}
} else {
tmp_rate_mv = *rate_mv;
if (have_newmv_in_inter_mode(this_mode)) {
InterPredParams inter_pred_params;
av1_dist_wtd_comp_weight_assign(
&cpi->common, mbmi, &inter_pred_params.conv_params.fwd_offset,
&inter_pred_params.conv_params.bck_offset,
&inter_pred_params.conv_params.use_dist_wtd_comp_avg, 1);
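// The forward/backward distance-weighted offsets sum to 16; scaling
// fwd_offset by 4 maps it onto the 0..64 blend-weight range, which likely
// lets the joint motion search treat seg_mask as an approximate DISTWTD
// blending mask.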
int mask_value = inter_pred_params.conv_params.fwd_offset * 4;
memset(xd->seg_mask, mask_value,
sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
bsize, this_mode);
}
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
if (cur_type == COMPOUND_AVERAGE) *is_luma_interp_done = 1;
int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
rs2 + *rate_mv);
if (eval_txfm) {
RD_STATS est_rd_stats;
estimate_yrd_for_sb(cpi, bsize, x, INT64_MAX, &est_rd_stats);
best_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
est_rd_stats.dist);
}
}
// Switch to the spare buffer for the remaining compound type evaluations
if (cur_type == COMPOUND_AVERAGE) restore_dst_buf(xd, *tmp_dst, 1);
} else if (cur_type == COMPOUND_WEDGE) {
int best_mask_index = 0;
int best_wedge_sign = 0;
int_mv tmp_mv[2] = { mbmi->mv[0], mbmi->mv[1] };
int best_rs2 = 0;
int best_rate_mv = *rate_mv;
int wedge_mask_size = get_wedge_types_lookup(bsize);
int need_mask_search = args->wedge_index == -1;
int wedge_newmv_search =
have_newmv_in_inter_mode(this_mode) &&
!cpi->sf.inter_sf.disable_interinter_wedge_newmv_search;
if (need_mask_search && !wedge_newmv_search) {
// Shortcut: prebuild the single-reference predictions once instead of
// rebuilding them for every wedge mask
av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 0,
preds0, strides);
av1_build_inter_predictors_for_planes_single_buf(xd, bsize, 0, 0, 1,
preds1, strides);
}
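// Exhaustive wedge search: evaluate every mask/sign pair only when no
// cached wedge decision is available for reuse.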
for (int wedge_mask = 0; wedge_mask < wedge_mask_size && need_mask_search;
++wedge_mask) {
for (int wedge_sign = 0; wedge_sign < 2; ++wedge_sign) {
tmp_rate_mv = *rate_mv;
mbmi->interinter_comp.wedge_index = wedge_mask;
mbmi->interinter_comp.wedge_sign = wedge_sign;
rs2 = masked_type_cost[cur_type];
rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
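// Prune this mask/sign candidate early if its rate cost alone (with zero
// distortion) already exceeds half of the best RD found so far.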
mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
if (mode_rd >= ref_best_rd / 2) continue;
if (wedge_newmv_search) {
tmp_rate_mv = av1_interinter_compound_motion_search(
cpi, x, cur_mv, bsize, this_mode);
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst,
bsize, AOM_PLANE_Y, AOM_PLANE_Y);
} else {
av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0,
strides, preds1, strides);
}
RD_STATS est_rd_stats;
int64_t this_rd_cur = INT64_MAX;
int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
rs2 + *rate_mv);
if (eval_txfm) {
this_rd_cur = estimate_yrd_for_sb(
cpi, bsize, x, AOMMIN(best_rd_cur, ref_best_rd), &est_rd_stats);
}
if (this_rd_cur < INT64_MAX) {
this_rd_cur =
RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
est_rd_stats.dist);
}
if (this_rd_cur < best_rd_cur) {
best_mask_index = wedge_mask;
best_wedge_sign = wedge_sign;
best_rd_cur = this_rd_cur;
tmp_mv[0] = mbmi->mv[0];
tmp_mv[1] = mbmi->mv[1];
best_rate_mv = tmp_rate_mv;
best_rs2 = rs2;
}
}
// Consider the asymmetric partitions for oblique angle only if the
// corresponding symmetric partition is the best so far.
// Note: For horizontal and vertical types, both symmetric and
// asymmetric partitions are always considered.
if (cpi->sf.inter_sf.enable_fast_wedge_mask_search) {
// The first 4 entries in wedge_codebook_16_heqw/hltw/hgtw[16]
// correspond to symmetric partitions of the 4 oblique angles, the
// next 4 entries correspond to the vertical/horizontal
// symmetric/asymmetric partitions and the last 8 entries correspond
// to the asymmetric partitions of oblique types.
const int idx_before_asym_oblique = 7;
const int last_oblique_sym_idx = 3;
if (wedge_mask == idx_before_asym_oblique) {
if (best_mask_index > last_oblique_sym_idx) {
break;
} else {
// Asymmetric (Index-1) map for the corresponding oblique masks.
// WEDGE_OBLIQUE27: sym - 0, asym - 8, 9
// WEDGE_OBLIQUE63: sym - 1, asym - 12, 13
// WEDGE_OBLIQUE117: sym - 2, asym - 14, 15
// WEDGE_OBLIQUE153: sym - 3, asym - 10, 11
const int asym_mask_idx[4] = { 7, 11, 13, 9 };
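// Jump the loop to the asymmetric variants of the best oblique angle;
// with wedge_mask_size set to wedge_mask + 3, exactly the next two
// masks (wedge_mask + 1 and wedge_mask + 2) are evaluated before the
// loop terminates.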
wedge_mask = asym_mask_idx[best_mask_index];
wedge_mask_size = wedge_mask + 3;
}
}
}
}
if (need_mask_search) {
if (save_mask_search_results(
this_mode, cpi->sf.inter_sf.reuse_mask_search_results)) {
args->wedge_index = best_mask_index;
args->wedge_sign = best_wedge_sign;
}
} else {
mbmi->interinter_comp.wedge_index = args->wedge_index;
mbmi->interinter_comp.wedge_sign = args->wedge_sign;
rs2 = masked_type_cost[cur_type];
rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
if (wedge_newmv_search) {
tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
bsize, this_mode);
}
best_mask_index = args->wedge_index;
best_wedge_sign = args->wedge_sign;
tmp_mv[0] = mbmi->mv[0];
tmp_mv[1] = mbmi->mv[1];
best_rate_mv = tmp_rate_mv;
best_rs2 = masked_type_cost[cur_type];
best_rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
best_rs2 + *rate_mv);
if (eval_txfm) {
RD_STATS est_rd_stats;
estimate_yrd_for_sb(cpi, bsize, x, INT64_MAX, &est_rd_stats);
best_rd_cur =
RDCOST(x->rdmult, best_rs2 + tmp_rate_mv + est_rd_stats.rate,
est_rd_stats.dist);
}
}
mbmi->interinter_comp.wedge_index = best_mask_index;
mbmi->interinter_comp.wedge_sign = best_wedge_sign;
mbmi->mv[0] = tmp_mv[0];
mbmi->mv[1] = tmp_mv[1];
tmp_rate_mv = best_rate_mv;
rs2 = best_rs2;
} else if (skip_mv_refinement_for_diffwtd) {
int_mv tmp_mv[2];
int best_mask_index = 0;
rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
int need_mask_search = args->diffwtd_index == -1;
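// Try both DIFFWTD mask types (base and inverted) only when no cached
// decision is available for reuse.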
for (int mask_index = 0; mask_index < 2 && need_mask_search;
++mask_index) {
tmp_rate_mv = *rate_mv;
mbmi->interinter_comp.mask_type = mask_index;
if (have_newmv_in_inter_mode(this_mode)) {
// Seed seg_mask with the hard-coded base values of the two DIFFWTD mask
// types: 38 for the base mask and 64 - 38 = 26 for the inverted mask
int mask_value = mask_index == 0 ? 38 : 26;
memset(xd->seg_mask, mask_value,
sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
bsize, this_mode);
}
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
RD_STATS est_rd_stats;
int64_t this_rd_cur = INT64_MAX;
int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
rs2 + *rate_mv);
if (eval_txfm) {
this_rd_cur =
estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
}
if (this_rd_cur < INT64_MAX) {
this_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
est_rd_stats.dist);
}
if (this_rd_cur < best_rd_cur) {
best_rd_cur = this_rd_cur;
best_mask_index = mbmi->interinter_comp.mask_type;
tmp_mv[0] = mbmi->mv[0];
tmp_mv[1] = mbmi->mv[1];
}
}
if (need_mask_search) {
if (save_mask_search_results(this_mode, 0))
args->diffwtd_index = best_mask_index;
} else {
mbmi->interinter_comp.mask_type = args->diffwtd_index;
rs2 = masked_type_cost[cur_type];
rs2 += get_interinter_compound_mask_rate(&x->mode_costs, mbmi);
int mask_value = mbmi->interinter_comp.mask_type == 0 ? 38 : 26;
memset(xd->seg_mask, mask_value,
sizeof(xd->seg_mask[0]) * 2 * MAX_SB_SQUARE);
if (have_newmv_in_inter_mode(this_mode)) {
tmp_rate_mv = av1_interinter_compound_motion_search(cpi, x, cur_mv,
bsize, this_mode);
}
best_mask_index = mbmi->interinter_comp.mask_type;
tmp_mv[0] = mbmi->mv[0];
tmp_mv[1] = mbmi->mv[1];
av1_enc_build_inter_predictor(cm, xd, mi_row, mi_col, orig_dst, bsize,
AOM_PLANE_Y, AOM_PLANE_Y);
RD_STATS est_rd_stats;
int64_t this_rd_cur = INT64_MAX;
int eval_txfm = prune_mode_by_skip_rd(cpi, x, xd, bsize, ref_skip_rd,
rs2 + *rate_mv);
if (eval_txfm) {
this_rd_cur =
estimate_yrd_for_sb(cpi, bsize, x, ref_best_rd, &est_rd_stats);
}
if (this_rd_cur < INT64_MAX) {
best_rd_cur = RDCOST(x->rdmult, rs2 + tmp_rate_mv + est_rd_stats.rate,
est_rd_stats.dist);
}
}
mbmi->interinter_comp.mask_type = best_mask_index;
mbmi->mv[0] = tmp_mv[0];
mbmi->mv[1] = tmp_mv[1];
} else {
// Handle masked compound types
bool eval_masked_comp_type = true;
if (*rd != INT64_MAX) {
// Thresholds that gate the evaluation of masked compound types based on
// the best approximate RD found so far
const int max_comp_type_rd_threshold_mul =
comp_type_rd_threshold_mul[cpi->sf.inter_sf
.prune_comp_type_by_comp_avg];
const int max_comp_type_rd_threshold_div =
comp_type_rd_threshold_div[cpi->sf.inter_sf
.prune_comp_type_by_comp_avg];
// Evaluate COMPOUND_WEDGE / COMPOUND_DIFFWTD if approximated cost is
// within threshold
const int64_t approx_rd = ((*rd / max_comp_type_rd_threshold_div) *
max_comp_type_rd_threshold_mul);
if (approx_rd >= ref_best_rd) eval_masked_comp_type = false;
}
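// If not pruned, fall back to the generic masked compound RD search,
// which appears to handle the wedge/DIFFWTD mask selection internally.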
if (eval_masked_comp_type) {
const int64_t tmp_rd_thresh = AOMMIN(*rd, rd_thresh);
best_rd_cur = masked_compound_type_rd(
cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
&tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
strides, rd_stats->rate, tmp_rd_thresh, &calc_pred_masked_compound,
comp_rate, comp_dist, comp_model_rate, comp_model_dist,
best_type_stats.comp_best_model_rd, &comp_model_rd_cur, comp_rs2,
ref_skip_rd);
}
}
// Update stats for best compound type
if (best_rd_cur < *rd) {
update_best_info(mbmi, rd, &best_type_stats, best_rd_cur,
comp_model_rd_cur, rs2);
if (have_newmv_in_inter_mode(this_mode))
update_mask_best_mv(mbmi, best_mv, &best_tmp_rate_mv, tmp_rate_mv);
}
// Reset to the original MVs for the next iteration
mbmi->mv[0].as_int = cur_mv[0].as_int;
mbmi->mv[1].as_int = cur_mv[1].as_int;
}
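// Translate the winning compound type into its syntax elements:
// comp_group_idx distinguishes masked compounds (WEDGE/DIFFWTD) from
// AVERAGE/DISTWTD, and compound_idx separates AVERAGE (1) from DISTWTD (0).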
mbmi->comp_group_idx =
(best_type_stats.best_compound_data.type < COMPOUND_WEDGE) ? 0 : 1;
mbmi->compound_idx =
!(best_type_stats.best_compound_data.type == COMPOUND_DISTWTD);
mbmi->interinter_comp = best_type_stats.best_compound_data;
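// If a joint motion search ran for the winning type, adopt its MVs and
// account for the change in MV rate.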
if (have_newmv_in_inter_mode(this_mode)) {
mbmi->mv[0].as_int = best_mv[0].as_int;
mbmi->mv[1].as_int = best_mv[1].as_int;
rd_stats->rate += best_tmp_rate_mv - *rate_mv;
*rate_mv = best_tmp_rate_mv;
}
if (this_mode == NEW_NEWMV)
args->cmp_mode[ref_frame] = mbmi->interinter_comp.type;
restore_dst_buf(xd, *orig_dst, 1);
if (!match_found)
save_comp_rd_search_stat(x, mbmi, comp_rate, comp_dist, comp_model_rate,
comp_model_dist, cur_mv, comp_rs2);
return best_type_stats.best_compmode_interinter_cost;
}