#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Gauge.h>
#include <c10/util/ScopeExit.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/env.h>
#include <c10/util/error.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/hash.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/static_tracepoint.h>
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <c10/util/Exception.h>
#include <cuda_runtime_api.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <memory>
#include <mutex>
#include <regex>
#include <set>
#include <utility>
#include <vector>
TORCH_SDT_DEFINE_SEMAPHORE(malloc)
TORCH_SDT_DEFINE_SEMAPHORE(free)
namespace c10 {
C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
namespace cuda::CUDACachingAllocator {
using namespace c10::CachingDeviceAllocator;
// Included here as this is externally used in CUDAAllocatorConfig
const size_t kLargeBuffer =
20971520; // "large" allocations may be packed in 20 MiB blocks
namespace Native {
//
// Yet another caching allocator for CUDA device allocations.
//
// - Allocations are associated with a stream. Once freed, blocks can be
// re-allocated on the same stream, but not on any other stream.
// - The allocator attempts to find the smallest cached block that will fit the
// requested size. If the block is larger than the requested size, it may be
// split. If no block is found, the allocator will delegate to cudaMalloc.
// - If the cudaMalloc fails, the allocator will attempt to free one cached
// block of sufficient size that is not split and retry the allocation.
// If this also fails, the allocator will attempt to free all cached blocks
// that are not split and retry the allocation.
// - Large (>1MB) and small allocations are stored in separate pools.
// Small requests are packed into 2MB buffers. Large requests will use the
// smallest available free block or allocate a new block using cudaMalloc.
// - To reduce fragmentation, requests between 1MB and 10MB will allocate and
// split a 20MB block, if no free block of sufficient size is available.
// - To further reduce fragmentation, blocks >= max_split_size are not allowed
// to be split. These oversize cached blocks will still satisfy requests
// within 1MB of the oversize cached block size.
//
// With this allocator, allocations and frees should logically be considered
// "usages" of the memory segment associated with streams, just like kernel
// launches. The programmer must insert the proper synchronization if memory
// segments are used from multiple streams.
//
// The library provides a recordStream() function to help insert the correct
// synchronization when allocations are used on multiple streams. This will
// ensure that the block is not reused before each recorded stream completes
// work.
//
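/*
Sketch [recordStream and deferred reuse]: illustrative only, not compiled into
the allocator. This is a minimal model of the contract described above. It
assumes streams are plain integer ids, and that one event per extra recorded
stream is counted at free time rather than recorded on a real GPU.
SketchBlock, record_stream, free_block and on_event_complete are hypothetical
names, not allocator APIs.

```cpp
#include <cassert>
#include <set>

// Hypothetical model: streams are integer ids, events are only counted.
struct SketchBlock {
  int alloc_stream;           // stream the block was allocated on
  std::set<int> stream_uses;  // extra streams recorded via record_stream()
  int event_count = 0;        // outstanding (not yet completed) events
  bool allocated = true;
  bool reusable = false;      // true once the block may be handed out again
};

// Analogue of recordStream(): remember that `stream` also used the block.
void record_stream(SketchBlock& b, int stream) {
  if (stream != b.alloc_stream) {
    b.stream_uses.insert(stream);
  }
}

// On free, one event is (conceptually) recorded per extra stream; the block
// becomes reusable only after all of those events have completed.
void free_block(SketchBlock& b) {
  b.allocated = false;
  b.event_count = static_cast<int>(b.stream_uses.size());
  b.stream_uses.clear();
  b.reusable = (b.event_count == 0);
}

// Called when one of the recorded events completes.
void on_event_complete(SketchBlock& b) {
  assert(b.event_count > 0);
  if (--b.event_count == 0) {
    b.reusable = true;
  }
}
```

Usage: allocate on stream 0, record streams 1 and 2, then free. The block
stays unavailable until both outstanding events have completed.
*/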
/**
* Note [Interaction with CUDA graph capture]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Graph capture performs a dry run of a region of execution, freezing all CUDA
* work (and virtual addresses used during that work) into a "graph." The graph
* may be "replayed" like a single giant kernel, with greatly reduced CPU
* overhead as well as modestly improved GPU performance.
*
* Because capture bakes in memory addresses, the memory used during capture
* must be available for the graph to use during replay. DeviceCachingAllocator
* assigns and frees memory eagerly and dynamically, so if we're not careful
* about managing graphs' memory, at replay time those memory addresses could be
* used by other tensors.
*
* To guarantee a graph's baked in addresses are safe to reuse in replay,
* DeviceAllocator satisfies allocations from a graph-private memory pool during
* capture, and doesn't begin cudaFreeing those addresses until the graph is
* destroyed.
*
* Within the private pool, allocations are freed and reassigned as usual during
* capture. Memory regions will be used in a consistent order during replay. So
* a private pool doesn't use memory more wastefully than the default pools
* during capture, but it does reserve its high-water mark of used memory away
* from the default pools as long as the capture(s) it served survive
* (regardless of whether those captures are idle or replaying).
*
* CUDAGraph's requests for private pools are mediated by
* DeviceAllocator::notifyCaptureBegin,
* notifyCaptureAboutToEnd,
* notifyCaptureEnded,
* notifyCaptureDestroy.
*/
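/*
Sketch [graph-private pools]: illustrative only, not compiled into the
allocator. This is a toy model of the private-pool lifetime described in the
note above. It assumes that a pool is identified by a pair of ids, that it is
reference counted by the captures/graphs that use it, and that it only tracks
reserved bytes. SketchAllocator, SketchPrivatePool, begin_capture, reserve and
destroy_graph are hypothetical stand-ins, not the real signatures of
notifyCaptureBegin and related functions.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>

// Hypothetical pool id: a pair of 64-bit ids.
using MempoolId = std::pair<unsigned long long, unsigned long long>;

struct SketchPrivatePool {
  int use_count = 1;         // captures/graphs that reference this pool
  std::size_t reserved = 0;  // bytes held back from the default pools
};

struct SketchAllocator {
  std::map<MempoolId, SketchPrivatePool> graph_pools;
  bool capturing = false;
  MempoolId active{0, 0};
  std::size_t default_reserved = 0;

  void begin_capture(MempoolId id) {
    auto it = graph_pools.find(id);
    if (it == graph_pools.end()) {
      graph_pools.emplace(id, SketchPrivatePool{});
    } else {
      ++it->second.use_count;  // another graph shares this pool
    }
    capturing = true;
    active = id;
  }

  void end_capture() { capturing = false; }

  // During capture, memory comes from the private pool, never the default one.
  void reserve(std::size_t bytes) {
    if (capturing) {
      graph_pools.at(active).reserved += bytes;
    } else {
      default_reserved += bytes;
    }
  }

  // The pool (and its reserved memory) survives until no graph references it.
  void destroy_graph(MempoolId id) {
    auto it = graph_pools.find(id);
    if (it == graph_pools.end()) {
      return;
    }
    if (--it->second.use_count == 0) {
      graph_pools.erase(it);
    }
  }
};
```

Usage: two graphs sharing one pool keep it alive until both are destroyed,
and bytes reserved during capture never show up in the default pool.
*/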
constexpr size_t kMinBlockSize =
512; // all sizes are rounded to at least 512 bytes
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
constexpr size_t kSmallBuffer =
2097152; // "small" allocations are packed in 2 MiB blocks
constexpr size_t kMinLargeAlloc =
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
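/*
Sketch [size rounding and segment sizes]: illustrative only, not compiled into
the allocator. This sketch assumes that requests are first rounded up to a
multiple of kMinBlockSize (ignoring any configurable rounding) and that the
segment requested from cudaMalloc follows the rules in the overview comment.
The example below shows how the constants above combine under those
assumptions. sketch_round_size, sketch_segment_size and sketch_is_small are
hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

// Constants copied from the definitions above.
constexpr std::size_t kMinBlockSize = 512;
constexpr std::size_t kSmallSize = 1048576;
constexpr std::size_t kSmallBuffer = 2097152;
constexpr std::size_t kLargeBuffer = 20971520;
constexpr std::size_t kMinLargeAlloc = 10485760;
constexpr std::size_t kRoundLarge = 2097152;

// Round a request up to a multiple of kMinBlockSize (at least 512 bytes).
std::size_t sketch_round_size(std::size_t size) {
  if (size < kMinBlockSize) {
    return kMinBlockSize;
  }
  return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}

// Small requests go to the small pool, everything else to the large pool.
bool sketch_is_small(std::size_t size) { return size <= kSmallSize; }

// Size of the segment requested from cudaMalloc when no cached block fits.
std::size_t sketch_segment_size(std::size_t size) {
  if (size <= kSmallSize) {
    return kSmallBuffer;  // small pool: packed into 2 MiB segments
  } else if (size < kMinLargeAlloc) {
    return kLargeBuffer;  // (1 MiB, 10 MiB): carved from a 20 MiB segment
  }
  // >= 10 MiB: rounded up to a multiple of 2 MiB
  return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
}
```

Usage: a 5 MiB request is served from a 20 MiB segment. A request of
10 MiB + 1 byte rounds to a 12 MiB segment (six 2 MiB units).
*/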
static char SHAREABLE_HANDLE_VERSION = 1;
enum ShareableHandleType : char {
SHAREABLE_CUDA_MALLOC = 'c',
SHAREABLE_CUDA_EXPANDABLE_SEGMENT = 'e'
};
namespace {
using stream_set = ska::flat_hash_set<cuda::CUDAStream>;
void decrease_stat_array(
StatArray& stat_array,
size_t amount,
const StatTypes& stat_types) {
for_each_selected_stat_type(
stat_types, [&stat_array, amount](size_t stat_type) {
stat_array[stat_type].decrease(amount);
});
}
struct Block;
struct PrivatePool;
typedef bool (*Comparison)(const Block*, const Block*);
static bool BlockComparatorSize(const Block* a, const Block* b);
static bool BlockComparatorAddress(const Block* a, const Block* b);
struct BlockPool {
BlockPool(bool small, PrivatePool* private_pool = nullptr)
: blocks(BlockComparatorSize),
unmapped(BlockComparatorAddress),
is_small(small),
owner_PrivatePool(private_pool) {}
// Do not insert a Block to blocks directly; use insert_into_blocks(),
// instead.
std::set<Block*, Comparison> blocks;
std::set<Block*, Comparison> unmapped;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const bool is_small;
PrivatePool* owner_PrivatePool;
int64_t get_free_blocks_call_count{0};
// Add a Block into blocks set with updating gc counter.
std::pair<std::set<Block*, Comparison>::iterator, bool> insert_into_blocks(
Block* block);
};
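/*
Sketch [best-fit lookup]: illustrative only, not compiled into the allocator.
Because `blocks` is a std::set, a best-fit search can be done with a single
lower_bound call on a search key, much like the key-only Block constructor
below. This sketch assumes that BlockComparatorSize orders by stream, then
size, then address. SketchBlock, sketch_compare_size and sketch_best_fit are
hypothetical, and streams are modeled as integers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>

// Hypothetical stand-in for Block with integer stream and address.
struct SketchBlock {
  std::uintptr_t stream;
  std::size_t size;
  std::uintptr_t ptr;
};

// Order by (stream, size, address): free blocks of one stream are contiguous
// in the set and sorted by size within that stream.
bool sketch_compare_size(const SketchBlock* a, const SketchBlock* b) {
  if (a->stream != b->stream) return a->stream < b->stream;
  if (a->size != b->size) return a->size < b->size;
  return a->ptr < b->ptr;
}

using SketchPool =
    std::set<const SketchBlock*, bool (*)(const SketchBlock*, const SketchBlock*)>;

// Best fit: the smallest cached block on `stream` whose size is >= `size`.
const SketchBlock* sketch_best_fit(const SketchPool& pool,
                                   std::uintptr_t stream,
                                   std::size_t size) {
  SketchBlock key{stream, size, 0};  // search key with the lowest address
  auto it = pool.lower_bound(&key);
  if (it == pool.end() || (*it)->stream != stream) {
    return nullptr;  // nothing large enough on this stream
  }
  return *it;
}
```

Design note: putting the stream first in the ordering means that a block
cached for one stream is never returned for another stream.
*/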
struct ExpandableSegment;
struct Block {
c10::DeviceIndex device; // gpu
cudaStream_t stream; // allocation stream
stream_set stream_uses; // streams on which the block was used
size_t size; // block size in bytes
size_t requested_size; // memory originally requested
BlockPool* pool{nullptr}; // owning memory pool
void* ptr{nullptr}; // memory address
bool allocated{false}; // in-use flag
bool mapped{true}; // is the virtual address range this Block references
// backed by physical pages. Always true when
// expandable_segment_ is null. When false,
// this Block will be aligned to the segment size
// of its expandable_segment_.
Block* prev{nullptr}; // prev block if split from a larger allocation
Block* next{nullptr}; // next block if split from a larger allocation
int event_count{0}; // number of outstanding CUDA events
int64_t gc_count_base{0}; // get_free_blocks_call_count when Block is inserted
std::shared_ptr<GatheredContext> context_when_allocated;
// only set for the first block in the segment (when prev == null)
// this records the frame information when cudaMalloc was called
// whereas context_when_allocated records the last time we handed this
// memory out from our cache.
std::shared_ptr<GatheredContext> context_when_segment_allocated;
ExpandableSegment* expandable_segment_{nullptr};
Block(
c10::DeviceIndex device,
cudaStream_t stream,
size_t size,
BlockPool* pool,
void* ptr)
: device(device),
stream(stream),
stream_uses(),
size(size),
requested_size(0),
pool(pool),
ptr(ptr) {}
// constructor for search key
Block(c10::DeviceIndex device, cudaStream_t stream, size_t size)
: device(device),
stream(stream),
stream_uses(),
size(size),
requested_size(0) {}
size_t gc_count() {
TORCH_INTERNAL_ASSERT(pool);
return static_cast<int>(pool->get_free_blocks_call_count - gc_count_base);
}
bool is_split() const {
return (prev != nullptr) || (next != nullptr);
}
void splice(Block* before, Block* after) {
if (before) {
TORCH_INTERNAL_ASSERT(before->next == after);
before->next = this;
}
prev = before;
if (after) {
TORCH_INTERNAL_ASSERT(after->prev == before);
after->prev = this;
}
next = after;
}
};
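/*
Sketch [split and merge]: illustrative only, not compiled into the allocator.
This is a trimmed-down copy of Block (an offset instead of ptr, with no stream
or pool). It shows how splice() keeps the prev/next chain consistent when a
block is split, and how a free neighbour can later be folded back in.
sketch_split and sketch_merge_next are hypothetical, simplified helpers; they
are not the allocator's own bookkeeping.

```cpp
#include <cassert>
#include <cstddef>

struct SketchBlock {
  std::size_t offset;  // stands in for ptr
  std::size_t size;
  bool allocated = false;
  SketchBlock* prev = nullptr;
  SketchBlock* next = nullptr;

  // Same linking logic as Block::splice(): insert this between before/after.
  void splice(SketchBlock* before, SketchBlock* after) {
    if (before) {
      assert(before->next == after);
      before->next = this;
    }
    prev = before;
    if (after) {
      assert(after->prev == before);
      after->prev = this;
    }
    next = after;
  }
};

// Keep `size` bytes in `block`; the tail becomes `remaining`, linked after it.
void sketch_split(SketchBlock* block, SketchBlock* remaining, std::size_t size) {
  remaining->offset = block->offset + size;
  remaining->size = block->size - size;
  remaining->splice(block, block->next);
  block->size = size;
}

// Fold a free neighbour `src`, which must directly follow `dst`, into `dst`.
void sketch_merge_next(SketchBlock* dst, SketchBlock* src) {
  assert(dst->next == src && !src->allocated);
  dst->size += src->size;
  dst->next = src->next;
  if (src->next) src->next->prev = dst;
  src->prev = nullptr;
  src->next = nullptr;
}
```

Usage: splitting a 2048-byte block at 512 yields a 1536-byte tail at offset
512. Merging it back restores the single 2048-byte block.
*/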
std::pair<std::set<Block*, Comparison>::iterator, bool> BlockPool::
insert_into_blocks(Block* block) {
block->gc_count_base = get_free_blocks_call_count;
return blocks.insert(block);
}
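/*
Sketch [block age for garbage collection]: illustrative only, not compiled
into the allocator. insert_into_blocks() stamps a block with the pool's
current get_free_blocks_call_count, so gc_count() reports how many counter
ticks the block has spent in the pool. This sketch assumes the counter is
bumped once per free-block lookup, as its name suggests. SketchPool,
SketchBlock, sketch_insert and sketch_lookup are hypothetical, and the age
uses int64_t instead of the cast above.

```cpp
#include <cassert>
#include <cstdint>

struct SketchPool {
  std::int64_t get_free_blocks_call_count = 0;
};

struct SketchBlock {
  SketchPool* pool;
  std::int64_t gc_count_base = 0;
  // Lookups since this block was (re)inserted into its pool.
  std::int64_t gc_count() const {
    return pool->get_free_blocks_call_count - gc_count_base;
  }
};

// Analogue of insert_into_blocks(): record the counter at insertion time.
void sketch_insert(SketchPool& pool, SketchBlock& b) {
  b.pool = &pool;
  b.gc_count_base = pool.get_free_blocks_call_count;
}

// Assumed to happen once per free-block lookup.
void sketch_lookup(SketchPool& pool) { ++pool.get_free_blocks_call_count; }
```

Usage: a block inserted before three lookups reports an age of 3. A block
inserted after the second lookup reports an age of 1.
*/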
struct SegmentRange {
char* ptr;
size_t size;
SegmentRange(void* p, size_t s) : ptr(static_cast<char*>(p)), size(s) {}
};
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
/*
Note [Expandable Segments]
Rationale
For large (>2MB) allocations, the allocator calls cudaMalloc to get allocations
that are the same size as what the user requests. Later, parts of these
allocations can be reused for other requests if they are free. This works well
when the program makes many requests of exactly the same size or of sizes that
are even multiples of that size. Many deep learning models follow this behavior.
However, one common exception is when the batch size changes slightly from one
iteration to the next, e.g. in batched inference. When the program runs
initially with batch size N, it will make allocations appropriate for that size.
If in the future, it runs at size N - 1, the existing allocations will still be
big enough. However, if it runs at size N + 1, then it will have to make new
allocations that are slightly larger. Not all the tensors are the same size.
Some might be (N + 1)*A and others (N + 1)*A*B where A and B are some non-batch
dimensions in the model. Because the allocator reuses existing allocations when
they are big enough, some number of (N + 1)*A allocations will actually fit in
the already existing N*B*A segments, though not perfectly. As the model runs it
will partially fill up all of these segments leaving unusable free slices of
memory at the end of these segments. The allocator at some point will need to
cudaMalloc a new (N + 1)*A*B segment. If there is not enough memory, there is
now no way to recover the slices of memory that are free at the end of existing
segments. With models 50+ layers deep, this pattern might repeat 50+ times
creating many slivers.
Approach
Expandable segments allow the allocator to create a segment initially and then
expand its size later when more memory is needed. Instead of making one segment
per allocation, it tries to make one segment (per stream) that grows as
necessary. Now when the N + 1 case runs, the allocations will tile nicely into
the one large segment until it fills up. Then more memory is requested and
appended to the end of the segment. This process does not create as many slivers
of unusable memory, so it is more likely to succeed at finding this memory.
Implementation
The expandable_segments:True option is used to enable/disable this behavior. We
use CUDA's low-level memory APIs, which are similar to mmap, to extend the
memory segments. These APIs separate the allocation of physical memory
(cuMemCreate) from the allocation of virtual address space (cuMemAddressReserve)
and from the association between them (cuMemMap/cuMemSetAccess).
When we allocate a new segment, we allocate enough address space to map
basically the entire physical memory of the GPU (there is 256TiB of address
space), but we only map enough physical memory to handle the current amount of
memory needed by the program. As more is requested, we add more physical memory
to the segment. This can work at the granularity of GPU pages which are 2MiB
currently.
If we end up out of memory, we can unmap all the memory in our segment
corresponding to empty physical pages, and return it to CUDA for use at another
address in the segment or in a segment for a different stream.
A current limitation of CUDA's API is that physical memory
(CUmemGenericAllocationHandle) cannot be split up after it is mapped even if the
handle holds multiple GPU pages. The cost to map/unmap memory is proportional to
the number of physical memory chunks that were allocated (mapping 10 separately
allocated 2MiB pages takes 10x time compared to mapping one 20MiB physical
allocation of 10 pages). Changing memory mappings also appears to involve at
least some synchronous actions with the GPU and so should be considered an
expensive operation. To limit overhead, we use 2MiB pages for our small pool and
20MiB pages for our large pool. Initial allocation using expandable segments
will be slower than cudaMalloc, though still in the milliseconds range even
when mapping the entire memory.
When mapping new memory to expand the segment, we look for the lowest address at
which we can fit a new allocation by adding new pages. Normally this will be at
the end of the block. But if we have previously unmapped blocks earlier in the
segment during an OOM, it will first try to fill in those gaps to keep the
segment as a single block. By allocating at the lowest address we encourage
the split-up parts of the block to merge into a single block again, reducing
fragmentation potential.
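The lowest-address placement can be modeled as a first-fit scan over page
occupancy. This is a simplified sketch under stated assumptions: it only tracks
which pages are mapped (a hypothetical std::vector<bool>, not the real
handles_ vector) and ignores already-mapped free bytes next to a gap, which the
real allocator can also use. lowestFit is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Returns the lowest page index i such that pages [i, i + n) are unmapped.
// Pages past the end of `mapped` count as unmapped, so when no interior gap
// is large enough the result starts at the trailing run of unmapped pages,
// which means the segment grows at its end.
std::size_t lowestFit(const std::vector<bool>& mapped, std::size_t n) {
  if (n == 0) {
    return 0;
  }
  std::size_t run = 0;  // length of the current run of unmapped pages
  for (std::size_t i = 0; i < mapped.size(); ++i) {
    run = mapped[i] ? 0 : run + 1;
    if (run == n) {
      return i + 1 - n;  // start of the first gap that is large enough
    }
  }
  return mapped.size() - run;
}
```

For example, with pages mapped as [yes, yes, no, no, yes], a two-page request
lands at page 2 and fills the gap, whereas a three-page request finds no
interior gap and starts at page 5, growing the segment.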
Allocation of blocks in the segment uses the same best-fit heuristics as the
rest of the allocator.
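Best fit here means: among the free blocks on the requesting stream, take the
smallest one that is at least as large as the request. Because
BlockComparatorSize orders blocks by stream, then size, then address, a single
lower_bound call with a search key (AllocParams carries such a search_key) can
find it. This sketch uses a hypothetical FreeBlock struct instead of the real
Block and a std::set in place of BlockPool; it does not model splitting,
max_split_size, or expandable-segment growth.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <tuple>

// Simplified stand-in for Block: only the fields the comparator reads.
struct FreeBlock {
  std::uintptr_t stream;
  std::size_t size;
  std::uintptr_t ptr;
};

// Same ordering as BlockComparatorSize: stream, then size, then address.
struct BySize {
  bool operator()(const FreeBlock& a, const FreeBlock& b) const {
    return std::tie(a.stream, a.size, a.ptr) <
        std::tie(b.stream, b.size, b.ptr);
  }
};

// Smallest free block on `stream` whose size is >= `size`, or nullptr if no
// block on that stream is large enough.
const FreeBlock* bestFit(const std::set<FreeBlock, BySize>& pool,
                         std::uintptr_t stream,
                         std::size_t size) {
  // Address 0 sorts before every real block of the same stream and size.
  auto it = pool.lower_bound(FreeBlock{stream, size, 0});
  if (it == pool.end() || it->stream != stream) {
    return nullptr;
  }
  return &*it;
}
```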
Expandable segments can be enabled or disabled at any point while a program
runs. When disabled, the allocator will not put new allocations in an
expandable segment.
Limitations
* Slightly slower initial memory allocation speed.
* IPC of CUDA tensors (e.g. for multiprocess dataloaders) relies on the
pidfd_open and pidfd_getfd syscalls, which older kernels may not support.
In that case, it is possible to temporarily disable the behavior
(expandable_segments:False) when allocating tensors that need to be used
cross-process.
* CUDA runtime APIs related to sharing memory across devices
(cudaDeviceEnablePeerAccess) do not work for memory allocated with cuMemMap.
Instead, these mappings have to be done manually. The allocator now has an
`enablePeerAccess` method to do this.
*/
struct ExpandableSegment {
ExpandableSegment(
c10::DeviceIndex device,
std::optional<cudaStream_t> stream,
size_t address_space_size,
size_t segment_size,
std::vector<c10::DeviceIndex> peers)
: device_(device),
stream_(stream),
// 2MB for small pool, 20MB for large pool
segment_size_(segment_size),
max_handles_(numSegments(address_space_size)),
peers_(std::move(peers)) {
cudaDeviceProp prop{};
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
// we reserve enough address space for 1 1/8 times the total memory on the GPU.
// This allows for some cases where we have to unmap pages earlier in the
// segment to put them at the end.
max_handles_ = numSegments(prop.totalGlobalMem + prop.totalGlobalMem / 8);
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemAddressReserve_(
&ptr_, segment_size_ * max_handles_, 0ULL, 0, 0ULL));
}
ExpandableSegment(const ExpandableSegment&) = delete;
ExpandableSegment(ExpandableSegment&&) = delete;
ExpandableSegment& operator=(const ExpandableSegment&) = delete;
ExpandableSegment& operator=(ExpandableSegment&&) = delete;
// begin must be aligned to segment_size_.
// returns the actual range mapped, which may be
// greater than requested if size is not aligned to segment_size_.
// return size of 0 indicates OOM
SegmentRange map(SegmentRange range) {
auto begin = segmentLeft(range.ptr);
auto end = segmentRight(range.ptr + range.size);
TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
if (begin == end) {
return rangeFromHandles(begin, end);
}
while (end > handles_.size()) {
handles_.emplace_back(std::nullopt);
}
for (auto i : c10::irange(begin, end)) {
TORCH_INTERNAL_ASSERT(!handles_.at(i));
CUmemGenericAllocationHandle handle = 0;
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
#ifndef FBCODE_CAFFE2
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
#endif
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
prop.location.id = static_cast<int>(device_);
auto status =
DriverAPI::get()->cuMemCreate_(&handle, segment_size_, &prop, 0);
if (status == CUDA_ERROR_OUT_OF_MEMORY) {
for (auto j : c10::irange(begin, i)) {
auto h = handles_.at(j).value();
handles_.at(j) = std::nullopt;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle));
}
trimHandles();
return rangeFromHandles(begin, begin);
}
C10_CUDA_DRIVER_CHECK(status);
handles_.at(i) = Handle{handle, std::nullopt};
}
mapAndSetAccess(begin, end);
return rangeFromHandles(begin, end);
}
// unmaps all the completely empty segment_size_ segments in
// [range.ptr, range.ptr + range.size); returns the offset where the unmapped
// range begins and the actual size unmapped (a multiple of segment_size_)
SegmentRange unmap(SegmentRange range) {
auto begin = segmentRight(range.ptr);
auto end = segmentLeft(range.ptr + range.size);
if (begin >= end) {
return SegmentRange{range.ptr, 0};
}
unmapHandles(begin, end);
return rangeFromHandles(begin, end);
}
// Setup IPC sharing for range.
// Returns the (larger) range that was actually shared.
// Serializes data to std::ostream that can be passed to the
// other process, and then restored as an expandable segment
// via ExpandableSegment::fromShared(istream);
SegmentRange share(SegmentRange range, std::ostream& buf) {
auto begin = segmentLeft(range.ptr);
auto end = segmentRight(range.ptr + range.size);
ShareHeader header{getpid(), segment_size_, end - begin};
buf.write((const char*)&header, sizeof(ShareHeader));
for (auto i : c10::irange(begin, end)) {
auto& handle = handles_.at(i).value();
if (!handle.fd) {
int fd = 0;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemExportToShareableHandle_(
&fd, handle.handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));
handle.fd = fd;
}
int fd = *handle.fd;
buf.write((const char*)&fd, sizeof(int));
}
return rangeFromHandles(begin, end);
}
static std::unique_ptr<ExpandableSegment> fromShared(
c10::DeviceIndex device,
std::vector<c10::DeviceIndex> peers,
std::istream& buf) {
ShareHeader header{};
buf.read((char*)&header, sizeof(ShareHeader));
auto segment = std::make_unique<ExpandableSegment>(
device,
std::nullopt,
header.num_handles * header.segment_size,
header.segment_size,
std::move(peers));
// older build setups (e.g. multiwheels) may not define these syscall numbers
// (pidfd_getfd was added in 2020), but the kernel on the system might still
// support them.
#ifndef SYS_pidfd_open
#define SYS_pidfd_open 434
#endif
#ifndef SYS_pidfd_getfd
#define SYS_pidfd_getfd 438
#endif
auto pidfd = syscall(SYS_pidfd_open, header.pid, 0);
TORCH_CHECK(
pidfd != -1 || errno != ENOSYS,
"The kernel on this machine does not support the pidfd_open syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. "
"Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation.");
TORCH_CHECK(pidfd != -1, "pidfd_open:", c10::utils::str_error(errno));
for (auto i : c10::irange(header.num_handles)) {
(void)i;
int fd = 0;
buf.read((char*)&fd, sizeof(int));
auto myfd = syscall(SYS_pidfd_getfd, pidfd, fd, 0);
if (myfd == -1) {
auto err = errno;
close((int)pidfd);
for (auto& h : segment->handles_) {
C10_CUDA_DRIVER_CHECK(
DriverAPI::get()->cuMemRelease_(h.value().handle));
h = std::nullopt;
}
TORCH_CHECK(
err != ENOSYS,
"The kernel on this machine does not support the pidfd_getfd syscall needed to use IPC for CUDA tensors when expandable_segments:True is set. "
"Consider using expandable_segments:False via torch.cuda.memory._set_allocator_settings('expandable_segments:False') for this allocation.");
TORCH_CHECK(false, "pidfd_getfd: ", c10::utils::str_error(err));
}
CUmemGenericAllocationHandle handle = 0;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_(
&handle,
// NOLINTNEXTLINE(performance-no-int-to-ptr)
(void*)(uintptr_t)myfd,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close((int)myfd);
segment->handles_.emplace_back(Handle{handle, std::nullopt});
}
close((int)pidfd);
segment->mapAndSetAccess(0, header.num_handles);
return segment;
}
char* ptr() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<char*>(ptr_);
}
size_t size() const {
return max_handles_ * segment_size_;
}
void addPeer(c10::DeviceIndex device) {
peers_.push_back(device);
forEachAllocatedRange(
[&](size_t begin, size_t end) { setAccess(device, begin, end); });
}
~ExpandableSegment() {
forEachAllocatedRange(
[&](size_t begin, size_t end) { unmapHandles(begin, end); });
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemAddressFree_(
ptr_, segment_size_ * max_handles_));
}
private:
void setAccess(c10::DeviceIndex device, size_t begin, size_t end) {
CUmemAccessDesc desc;
desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
desc.location.id = static_cast<int>(device);
desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemSetAccess_(
ptr_ + begin * segment_size_, (end - begin) * segment_size_, &desc, 1));
}
void mapAndSetAccess(size_t begin, size_t end) {
for (auto i : c10::irange(begin, end)) {
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemMap_(
ptr_ + i * segment_size_,
segment_size_,
0,
handles_.at(i).value().handle,
0ULL));
}
setAccess(device_, begin, end);
for (auto p : peers_) {
setAccess(p, begin, end);
}
}
void unmapHandles(size_t begin, size_t end) {
// note: unlike cudaFree, MemUnmap and MemRelease do
// not appear to synchronize in all cases, so we have to wait for the
// stream to finish before this memory is truly free.
// cannot call c10::cuda::stream_synchronize because
// it might grab the GIL which can lead to a deadlock
// Locking order must be GIL -> Allocator Lock
if (stream_) {
C10_CUDA_CHECK(cudaStreamSynchronize(*stream_));
} else {
cuda::CUDAGuard device_guard(device_);
C10_CUDA_CHECK(cudaDeviceSynchronize());
}
for (auto i : c10::irange(begin, end)) {
Handle h = handles_.at(i).value();
handles_.at(i) = std::nullopt;
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemUnmap_(
ptr_ + segment_size_ * i, segment_size_));
if (h.fd) {
close(*h.fd);
}
C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemRelease_(h.handle));
}
trimHandles();
}
void trimHandles() {
while (!handles_.empty() && !handles_.back()) {
handles_.pop_back();
}
}
void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
size_t start = 0;
for (auto i : c10::irange(handles_.size())) {
if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
start = i;
}
if (handles_.at(i) && (i + 1 == handles_.size() || !handles_.at(i + 1))) {
fn(start, i + 1);
}
}
}
size_t numSegments(size_t size) {
return (size + segment_size_ - 1) / segment_size_;
}
size_t segmentLeft(char* p) {
auto size = p - ptr();
return size / segment_size_;
}
size_t segmentRight(char* p) {
auto size = p - ptr();
return numSegments(size);
}
SegmentRange rangeFromHandles(size_t begin, size_t end) {
return SegmentRange(
ptr() + segment_size_ * begin, segment_size_ * (end - begin));
}
c10::DeviceIndex device_;
std::optional<cudaStream_t> stream_;
CUdeviceptr ptr_{};
size_t segment_size_;
size_t max_handles_;
struct Handle {
CUmemGenericAllocationHandle handle;
std::optional<int> fd;
};
struct ShareHeader {
pid_t pid;
size_t segment_size;
size_t num_handles;
};
std::vector<std::optional<Handle>> handles_;
// devices on which this memory should be mapped in addition
// to the device where the physical memory lives (device_).
std::vector<c10::DeviceIndex> peers_;
};
#else
struct ExpandableSegment {
ExpandableSegment(
c10::DeviceIndex device,
std::optional<cudaStream_t> stream,
size_t address_space_size,
size_t segment_size,
std::vector<c10::DeviceIndex> peers) {
TORCH_INTERNAL_ASSERT(false, "expandable segment not supported");
}
SegmentRange map(SegmentRange range) {
return SegmentRange(nullptr, 0);
}
SegmentRange unmap(SegmentRange range) {
return SegmentRange(nullptr, 0);
}
SegmentRange share(SegmentRange range, std::ostream& ss) {
return SegmentRange(nullptr, 0);
}
static std::unique_ptr<ExpandableSegment> fromShared(
c10::DeviceIndex device,
std::vector<c10::DeviceIndex> peers,
std::istream& buf) {
return {};
}
char* ptr() const {
return nullptr;
}
size_t size() const {
return 0;
}
void addPeer(c10::DeviceIndex device) {}
};
#endif
// BlockState, BlockPoolState, and PrivatePoolState contain the information
// needed to reconstruct a private pool to a previous state. See note
// [Checkpointing PrivatePoolState]
struct BlockState {
c10::DeviceIndex device = 0;
cudaStream_t stream = nullptr;
stream_set stream_uses = {};
size_t size = 0;
void* ptr = nullptr;
bool allocated = false;
int64_t gc_count_base = 0;
// maintain invariant that event_count == 0;
// history will be left alone in checkpoint
BlockState(Block* block);
};
struct SegmentState {
std::vector<BlockState> blocks;
bool is_small = false;
SegmentState(Block* head);
};
struct PrivatePoolState : AllocatorState {
// omitting use_count, and cudaMalloc_count as they remain the same
MempoolId_t owner_id = {0, 0};
std::vector<SegmentState> segments;
PrivatePoolState(
MempoolId_t pool_id,
const std::vector<Block*>& private_pool_head_blocks);
};
struct RestoreResult {
std::vector<void*> allocations_freed;
std::vector<Block*> allocations_created;
};
static bool BlockComparatorSize(const Block* a, const Block* b) {
if (a->stream != b->stream) {
return (uintptr_t)a->stream < (uintptr_t)b->stream;
}
if (a->size != b->size) {
return a->size < b->size;
}
return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
}
static bool BlockComparatorAddress(const Block* a, const Block* b) {
if (a->stream != b->stream) {
return (uintptr_t)a->stream < (uintptr_t)b->stream;
}
return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
}
struct AllocParams {
AllocParams(
c10::DeviceIndex device,
size_t size,
cudaStream_t stream,
BlockPool* pool,
size_t alloc_size,
DeviceStats& stats)
: search_key(device, stream, size), pool(pool), alloc_size(alloc_size) {}
c10::DeviceIndex device() const {
return search_key.device;
}
cudaStream_t stream() const {
return search_key.stream;
}
size_t size() const {
return search_key.size;
}
Block search_key;
BlockPool* pool;
size_t alloc_size;
Block* block{nullptr};
StatTypes stat_types = {false};
cudaError_t err{cudaSuccess};
};
// Note: cudaEventCreate when concurrently invoked from multiple threads can be
// very expensive (at least on certain device/driver combinations). Thus, we a)
// serialize event creation at a per-device level, and b) pool the events to
// avoid constantly calling cudaEventCreate/cudaEventDestroy. This results in
// significant improvements in multithreaded workloads with high allocation
// rates.
class EventPool {
public:
using Event = std::unique_ptr<cudaEvent_t, std::function<void(cudaEvent_t*)>>;
// TODO: Explicit device count
EventPool() : pools_(at::cuda::device_count()) {}
Event get(c10::DeviceIndex device) {
TORCH_INTERNAL_ASSERT(0 <= device);
TORCH_INTERNAL_ASSERT(device < static_cast<int>(pools_.size()));
auto& pool = pools_[device];
auto destructor = [&pool](cudaEvent_t* event) {
std::lock_guard<std::mutex> g(pool.mutex_);
pool.event_pool_.push_back(std::unique_ptr<cudaEvent_t>(event));
};
// Try to acquire an event from the per-device pool.
{
std::lock_guard<std::mutex> g(pool.mutex_);
if (!pool.event_pool_.empty()) {
auto* event = pool.event_pool_.back().release();
pool.event_pool_.pop_back();
return Event(event, destructor);
}
}
// otherwise, allocate a new event that will be returned to the pool on
// destruction.
auto new_ptr = std::make_unique<cudaEvent_t>();
C10_CUDA_CHECK(
cudaEventCreateWithFlags(new_ptr.get(), cudaEventDisableTiming));
return Event(new_ptr.release(), destructor);
}
void empty_cache() {
for (auto& pool : pools_) {
std::lock_guard<std::mutex> g(pool.mutex_);
pool.event_pool_.clear();
}
}
private:
struct PerDevicePool {
alignas(64) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;
};
// CUDA graphs helper
struct PrivatePool {
PrivatePool()
: large_blocks(/*small=*/false, this),
small_blocks(/*small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
PrivatePool(PrivatePool&&) = delete;
PrivatePool& operator=(const PrivatePool&) = delete;
PrivatePool& operator=(PrivatePool&&) = delete;
~PrivatePool() = default;
// Number of live graphs using this pool
int use_count{1};
// Number of unfreed cudaMallocs made for this pool. When use_count and
// cudaMalloc_count drop to zero, we can delete this PrivatePool from
// graph_pools.
int cudaMalloc_count{0};
// Instead of maintaining private BlockPools here, I could stuff all blocks
// (private or no) into the top-level large_blocks and small_blocks, and
// distinguish private blocks by adding a "pool id" check above the stream
// check in BlockComparator. BlockComparator is performance-critical, though,
// so I'd rather not add more logic to it.
BlockPool large_blocks;
BlockPool small_blocks;
};
BlockState::BlockState(Block* block)
: device(block->device),
stream(block->stream),
stream_uses(block->stream_uses),
size(block->size),
ptr(block->ptr),
allocated(block->allocated),
gc_count_base(block->gc_count_base) {
TORCH_CHECK(
block->event_count == 0,
"Events should have synchronized when checkpointing block");
};
SegmentState::SegmentState(Block* head) {
TORCH_INTERNAL_ASSERT(head->prev == nullptr && head->pool != nullptr);
is_small = head->pool->is_small;
for (Block* curr = head; curr != nullptr; curr = curr->next) {
blocks.emplace_back(curr);
}
}
PrivatePoolState::PrivatePoolState(
MempoolId_t pool_id,
const std::vector<Block*>& private_pool_head_blocks)
: owner_id(std::move(pool_id)) {
for (Block* head : private_pool_head_blocks) {
segments.emplace_back(head);
}
}
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
cudaError_t allocPrimitive(void** ptr, size_t size, AllocParams& p) {
auto active_pool = MemPoolContext::getActiveMemPool();
if (active_pool && active_pool->allocator() && p.pool->owner_PrivatePool) {
*ptr = active_pool->allocator()->raw_alloc(size);
return *ptr ? cudaSuccess : cudaErrorMemoryAllocation;
} else {
return C10_CUDA_ERROR_HANDLED(cudaMalloc(ptr, size));
}
}
cudaError_t cudaMallocMaybeCapturing(void** ptr, size_t size, AllocParams& p) {
if (at::cuda::currentStreamCaptureStatusMayInitCtx() ==
at::cuda::CaptureStatus::None) {
return allocPrimitive(ptr, size, p);
} else {
// It's ok to capture cudaMallocs, as long as we never cudaFree those
// addresses before replay.
// Capturing cudaMalloc behaves nicely: it gives the graph new VA,
// but is ignored (won't leakily allocate new memory) in replays.
at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed};
return allocPrimitive(ptr, size, p);
}
}
template <class T>
class RingBuffer {
public:
RingBuffer() {
// alloc_trace is a pointer because we need to intentionally
// leak it on deallocation: it can hold references to Python
// state which will already be destroyed when we are in exit handlers
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
alloc_trace = new std::vector<T>();
}
void setMaxEntries(size_t size) {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
alloc_trace_max_entries_ = std::max(size_t(1), size);
}
void insertEntries(const T& entry) {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
if (alloc_trace->size() < alloc_trace_max_entries_) {
alloc_trace->emplace_back(entry);
} else {
(*alloc_trace)[alloc_trace_next++] = entry;
if (alloc_trace_next == alloc_trace_max_entries_) {
alloc_trace_next = 0;
}
}
}
void getEntries(std::vector<T>& result) {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
result.reserve(alloc_trace->size());
result.insert(
result.end(),
alloc_trace->begin() +
static_cast<typename std::vector<T>::difference_type>(
alloc_trace_next),
alloc_trace->end());
result.insert(
result.end(),
alloc_trace->begin(),
alloc_trace->begin() +
static_cast<typename std::vector<T>::difference_type>(
alloc_trace_next));
}
void clear() {
std::lock_guard<std::mutex> lk(alloc_trace_lock);
alloc_trace_next = 0;
alloc_trace->clear();
}
private:
size_t alloc_trace_max_entries_ = 1;
// Both alloc_trace and alloc_trace_next need to be used
// under alloc_trace_lock.
std::mutex alloc_trace_lock;
size_t alloc_trace_next = 0;
std::vector<T>*
alloc_trace; // pointer because we need to intentionally leak this on
// deallocation: it can hold references to Python state which
// will already be destroyed when we are in exit handlers
};
} // anonymous namespace
} // namespace Native
static std::string reportProcessMemoryInfo(c10::DeviceIndex device) {
#ifdef PYTORCH_C10_DRIVER_API_SUPPORTED
void* nvml_handle = DriverAPI::get_nvml_handle();
if (!nvml_handle) {
return "";
}
static c10::once_flag nvml_init;
c10::call_once(nvml_init, [] {
TORCH_INTERNAL_ASSERT(NVML_SUCCESS == DriverAPI::get()->nvmlInit_v2_());
});
cudaDeviceProp prop{};
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
// NOLINTNEXTLINE(*-c-arrays)
char pci_id[80];
snprintf(
pci_id,
sizeof(pci_id),
NVML_DEVICE_PCI_BUS_ID_FMT,
prop.pciDomainID,
prop.pciBusID,
prop.pciDeviceID);
nvmlDevice_t nvml_device = nullptr;
TORCH_INTERNAL_ASSERT(
NVML_SUCCESS ==
DriverAPI::get()->nvmlDeviceGetHandleByPciBusId_v2_(
pci_id, &nvml_device));
std::vector<nvmlProcessInfo_v1_t> procs(8);
unsigned int size = procs.size();
nvmlReturn_t r{};
while ((r = DriverAPI::get()->nvmlDeviceGetComputeRunningProcesses_(
nvml_device, &size, procs.data())) ==
NVML_ERROR_INSUFFICIENT_SIZE) {
procs.resize(size);
}
unsigned int self_pid = getpid();
std::stringstream ss;
TORCH_INTERNAL_ASSERT(NVML_SUCCESS == r);
ss << "";
for (auto i : c10::irange(size)) {
auto& proc = procs[i];
if (self_pid == proc.pid) {
ss << "Including non-PyTorch memory, this process";
} else {
ss << "Process " << proc.pid;
}
ss << " has " << format_size(proc.usedGpuMemory) << " memory in use. ";
}
return ss.str();
#else
return "";
#endif
}
namespace Native {
class DeviceCachingAllocator {
private:
// lock around all operations
mutable std::recursive_mutex mutex;
// device statistics
DeviceStats stats;
// unallocated cached blocks larger than 1 MB
BlockPool large_blocks;
// unallocated cached blocks 1 MB or smaller
BlockPool small_blocks;
// allocated or in use by a stream. Holds all active allocations,
// whether they came from graph_pools or one of the BlockPools above.
ska::flat_hash_set<Block*> active_blocks;
// captures_underway tracks if we are diverting some
// allocations to a specific pool.
// Most of the time it's empty, in which case malloc can avoid calling
// cudaStreamGetCaptureInfo in the hot path.
std::vector<std::pair<MempoolId_t, std::function<bool(cudaStream_t)>>>
captures_underway;
// See free() for this thing's purpose
std::vector<Block*> needs_events_deferred_until_no_capture;
// outstanding cuda events
ska::flat_hash_map<
cuda::CUDAStream,
std::deque<std::pair<EventPool::Event, Block*>>>
cuda_events;
// record used memory.
size_t total_allocated_memory = 0;
size_t allowed_memory_maximum = 0;
// all live expandable segments
std::vector<ExpandableSegment*> expandable_segments_;
std::vector<c10::DeviceIndex> devices_with_peer_access_;
bool set_fraction = false;
bool record_history = false;
std::atomic<CreateContextFn> context_recorder_;
RecordContext record_context_ = RecordContext::NEVER;
// Ring buffer for memory snapshot TraceEntry's
RingBuffer<TraceEntry> alloc_buffer;
// Members specific to CUDA graphs
// Private pools for CUDA graphs
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
graph_pools;
// Pools no longer referenced by any graph. Their BlockPools are eligible for
// free_blocks. Can't be a vector or deque because we might erase entries in
// any order. Could be an std::list, but we don't care much, access and
// insert/erase are rare.
ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
graph_pools_freeable;
// XXX - maybe we should generalize and have multiple events
std::vector<OutOfMemoryObserver> oom_observers_;
std::vector<AllocatorTraceTracker> trace_trackers_;
// mapping from block to a stream_set, containing streams on which the block
// was used while cudagraph capturing
std::unordered_map<Block*, stream_set> block_to_cudagraph_stream_uses;
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
DeviceCachingAllocator()
: large_blocks(/*small=*/false), small_blocks(/*small=*/true) {
stats.max_split_size =
static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
context_recorder_.store(nullptr);
}
void recordHistory(
bool enabled,
CreateContextFn context_recorder,
size_t alloc_buffer_max_entries,
RecordContext when) {
std::unique_lock<std::recursive_mutex> lock(mutex);
TORCH_CHECK(when == RecordContext::NEVER || context_recorder);
record_history = enabled;
context_recorder_.store(record_history ? context_recorder : nullptr);
alloc_buffer.setMaxEntries(alloc_buffer_max_entries);
record_context_ = enabled ? when : RecordContext::NEVER;
if (!enabled) {
alloc_buffer.clear();
}
}
bool isHistoryEnabled() {
return record_history;
}
bool checkPoolLiveAllocations(
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
std::unique_lock<std::recursive_mutex> lock(mutex);
PrivatePool* pool = nullptr;
auto pool_it = graph_pools.find(mempool_id);
TORCH_CHECK(pool_it != graph_pools.end(), "Could not find pool of id");
pool = pool_it->second.get();
TORCH_INTERNAL_ASSERT(pool != nullptr);
size_t allocated_pool_blocks = 0;
for (Block* b : active_blocks) {
TORCH_INTERNAL_ASSERT(b != nullptr);
TORCH_INTERNAL_ASSERT(b->pool != nullptr);
if (b->allocated && b->pool->owner_PrivatePool == pool) {
if (!expected_live_allocations.count(b->ptr)) {
return false;
}
allocated_pool_blocks += 1;
}
}
return allocated_pool_blocks == expected_live_allocations.size();
}
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
oom_observers_.emplace_back(std::move(observer));
}
void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) {
std::unique_lock<std::recursive_mutex> lock(mutex);
trace_trackers_.emplace_back(std::move(tracker));
}
// Must be called outside of `mutex` or deadlocks are possible with Python
std::shared_ptr<GatheredContext> maybeGatherContext(RecordContext level) {
if (record_context_ < level) {
return nullptr;
}
return context_recorder_.load()();
}
// All public methods (except the above) acquire the allocator mutex.
// Thus, do not call a public method from another public method.
Block* malloc(
c10::DeviceIndex device,
size_t orig_size,
cudaStream_t stream) {
// done outside the lock because we don't know what locks the recorder needs
// to have...
auto context = maybeGatherContext(RecordContext::STATE);
std::unique_lock<std::recursive_mutex> lock(mutex);
if (C10_LIKELY(captures_underway.empty())) {
// Processes end-of-life events for outstanding allocations used on
// multiple streams (checks if their GPU-side uses are complete and
// recycles their memory if so)
//
// Q. Why skip process_events if a capture might be underway?
// A. process_events involves cudaEventQueries, illegal during CUDA graph
// capture.
// Dumb simple solution: defer reclaiming these allocations until after
// capture. Cross-stream memory use is uncommon, so the deferral's
// effect on memory use during capture should be small.
process_events(context);
}
size_t size = round_size(orig_size);
auto& pool = get_pool(size, stream);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, stream, &pool, alloc_size, stats);
params.stat_types = get_stat_types_for_pool(pool);
// First, try to get a block from the existing pool.
bool block_found =
// Search pool
get_free_block(params)
// Trigger callbacks and retry search
|| (trigger_free_memory_callbacks(params) && get_free_block(params));
// Can't reuse an existing block; try to get a new one.
if (!block_found) {
// Do garbage collection if the flag is set.
if (C10_UNLIKELY(
set_fraction &&
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
garbage_collect_cached_blocks(context);
}
// Attempt allocate
// WARNING: alloc_block may release the allocator lock when calling
// cudaMalloc. So far this function has not modified allocator state, but
// keep in mind that any observed allocator state may change across calls
// to alloc_block since it may release the lock.
block_found = alloc_block(params, false, context, lock)
// Free enough available cached blocks to satisfy alloc and retry
// alloc.
|| (release_available_cached_blocks(params, context) &&
alloc_block(params, false, context, lock))
// Free all non-split cached blocks and retry alloc.
|| (C10_LIKELY(captures_underway.empty()) &&
release_cached_blocks(context) &&
alloc_block(params, true, context, lock));
}
if (!block_found) {
// For any error code other than cudaErrorMemoryAllocation,
// alloc_block should have thrown an exception already.
TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
std::string allowed_info;
if (set_fraction) {
allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
}
std::string proc_info = reportProcessMemoryInfo(device);
record_trace(
TraceEntry::OOM,
device_free,
params.size(),
params.stream(),
params.device(),
std::move(context));
stats.num_ooms += 1;
c10::reportOutOfMemoryToProfiler(
static_cast<int64_t>(size),
stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current,
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current,
c10::Device(c10::DeviceType::CUDA, device));
auto allocated_bytes =
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto reserved_bytes =
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current;
auto observers_local = oom_observers_;
size_t allocated_in_private_pools = 0;
auto get_size_block = [](const BlockPool& pool) {
size_t res = 0;
for (const auto& block : pool.blocks) {
res += block->size;
}
return res;
};
for (const auto& p : graph_pools) {
allocated_in_private_pools += get_size_block(p.second->large_blocks);
allocated_in_private_pools += get_size_block(p.second->small_blocks);
}
std::string private_pool_msg;
if (allocated_in_private_pools > 0) {
private_pool_msg = "with " + format_size(allocated_in_private_pools) +
" allocated in private pools (e.g., CUDA Graphs), ";
}
// Make sure we do not hold the device lock before calling our
// observers, which might need to hold the GIL.
// It is safe to release it at this point because we will no longer
// be reading any allocator state.
lock.unlock();
for (const auto& obs : observers_local) {
obs(device,
alloc_size,
set_fraction ? allowed_memory_maximum : device_total,
device_free);
}
// "total capacity": total global memory on GPU
// "allowed": memory the allocator is allowed to use, as set by the memory fraction.
// "already allocated": memory allocated by the program using the
// caching allocator
// "free": free memory as reported by the CUDA API
// "cached": memory held by the allocator but not used by the program
//
// The "allocated" amount does not include memory allocated outside
// of the caching allocator, such as memory allocated by other programs
// or memory held by the driver.
//
// The sum of "allocated" + "free" + "cached" may be less than the
// total capacity due to memory held by the driver and usage by other
// programs.
//
// Note that at this point release_cached_blocks has already returned all
// possible "cached" memory to the driver. The only remaining "cached"
// memory is split from a larger block that is partially in-use.
TORCH_CHECK_WITH(
OutOfMemoryError,
false,
"CUDA out of memory. Tried to allocate ",
format_size(alloc_size),
". GPU ",
static_cast<int>(device),
" has a total capacity of ",
format_size(device_total),
" of which ",
format_size(device_free),
" is free. ",
proc_info,
allowed_info,
"Of the allocated memory ",
format_size(allocated_bytes + allocated_in_private_pools),
" is allocated by PyTorch, ",
private_pool_msg,
"and ",
format_size(
reserved_bytes - allocated_bytes - allocated_in_private_pools),
" is reserved by PyTorch but unallocated.",
" If reserved but unallocated memory is large, try setting",
" PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid"
" fragmentation. See documentation for Memory Management"
" (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)");
}
bool split_remainder = should_split(params.block, params.size());
return alloc_found_block(
params, orig_size, std::move(context), split_remainder);
}
Block* alloc_found_block(
const AllocParams& params,
size_t orig_size,
std::shared_ptr<GatheredContext> context,
bool split_remainder) {
auto size = params.size();
auto device = params.device();
auto pool = params.pool;
auto stream = params.stream();
TORCH_INTERNAL_ASSERT(
params.err == cudaSuccess && params.block != nullptr &&
params.block->ptr != nullptr);
Block* block = params.block;
Block* remaining = nullptr;
const bool already_split = block->is_split();
if (split_remainder) {
remaining = block;
block = new Block(device, stream, size, pool, block->ptr);
block->expandable_segment_ = remaining->expandable_segment_;
block->prev = remaining->prev;
if (block->prev) {
block->prev->next = block;
}
block->next = remaining;
remaining->prev = block;
remaining->ptr = static_cast<char*>(remaining->ptr) + size;
remaining->size -= size;
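// Sketch of the split performed above (illustrative; S is the size of the
// block before splitting):
//   before: prev <-> [block: ptr, S bytes] <-> next
//   after:  prev <-> [new block: ptr, size bytes]
//                <-> [remaining: ptr + size, S - size bytes] <-> next
// `remaining` is the original Block object, so its `next` link is unchanged.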
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
bool inserted = pool->insert_into_blocks(remaining).second;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
if (already_split && !block->expandable_segment_) {
// An already-split inactive block is being shrunk by size bytes.
decrease_stat_array(
stats.inactive_split_bytes, block->size, params.stat_types);
} else if (!block->expandable_segment_) {
// A new split inactive block is being created from a previously unsplit
// block, of size remaining->size bytes.
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
stats.inactive_split_bytes[stat_type].increase(remaining->size);
stats.inactive_split[stat_type].increase(1);
});
}
} else if (already_split && !block->expandable_segment_) {
// An already-split block is becoming active
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
stats.inactive_split_bytes[stat_type].decrease(block->size);
stats.inactive_split[stat_type].decrease(1);
});
}
block->allocated = true;
block->requested_size = orig_size;
block->context_when_allocated = std::move(context);
record_trace(
TraceEntry::ALLOC,
int64_t(block->ptr),
orig_size,
block->stream,
block->device,
block->context_when_allocated);
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
bool inserted = active_blocks.insert(block).second;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
stats.allocation[stat_type].increase(1);
stats.allocated_bytes[stat_type].increase(block->size);
stats.active[stat_type].increase(1);
stats.active_bytes[stat_type].increase(block->size);
stats.requested_bytes[stat_type].increase(block->requested_size);
});
if (block->size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_allocations.increase(1);
auto allocated_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes);
allocated_bytes_gauge.record(
stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);
c10::reportMemoryUsageToProfiler(
block->ptr,
static_cast<int64_t>(block->size),
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
c10::Device(c10::DeviceType::CUDA, device));
return block;
}
void free(Block* block) {
std::shared_ptr<GatheredContext> context =
maybeGatherContext(RecordContext::ALL);
std::lock_guard<std::recursive_mutex> lock(mutex);
block->allocated = false;
// The following logic might modify the underlying Block and change its
// size, so store the original pointer and size now for reporting.
auto orig_block_ptr = block->ptr;
auto orig_block_size = block->size;
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.allocation[stat_type].decrease(1);
stats.allocated_bytes[stat_type].decrease(block->size);
});
auto allocated_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.allocated_bytes);
allocated_bytes_gauge.record(
stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);
record_trace(
TraceEntry::FREE_REQUESTED,
int64_t(block->ptr),
block->requested_size,
block->stream,
block->device,
context ? context : block->context_when_allocated);
if (block->size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_allocations.decrease(1);
if (!block->stream_uses.empty()) {
if (C10_UNLIKELY(!captures_underway.empty())) {
// It's forbidden to cudaEventQuery an event recorded during CUDA graph
// capture. We conservatively defer recording end-of-life events until
// the next call to process_events() (which won't happen until no
// captures are underway)
needs_events_deferred_until_no_capture.push_back(block);
} else {
insert_events(block);
}
} else {
free_block(block, context);
}
c10::reportMemoryUsageToProfiler(
orig_block_ptr,
-static_cast<int64_t>(orig_block_size),
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
c10::Device(c10::DeviceType::CUDA, block->device));
}
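// Illustrative example (hypothetical sizes, non-expandable segments only):
// if a cudaMalloc'd segment has been split into the blocks
// [A: 2 MiB] <-> [B: 4 MiB] <-> [C: 2 MiB], then getBaseAllocation(B, &size)
// walks back to A, returns A->ptr, and sets size to 2 + 4 + 2 = 8 MiB, i.e.
// the whole segment. shareIpcHandle(B), defined below, computes the matching
// offset B->ptr - A->ptr = 2 MiB.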
void* getBaseAllocation(Block* block, size_t* outSize) {
std::lock_guard<std::recursive_mutex> lock(mutex);
TORCH_CHECK(
!block->expandable_segment_,
"Tensors allocated with expandable_segments:True cannot be shared between processes. Consider using expandable_segments:False in data loading workers via torch.cuda.memory._set_allocator_settings('expandable_segments:False')");
while (block->prev) {
block = block->prev;
}
void* basePtr = block->ptr;
if (outSize) {
size_t size = 0;
while (block) {
size += block->size;
block = block->next;
}
*outSize = size;
}
return basePtr;
}
ShareableHandle shareIpcHandle(Block* block) {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::ostringstream ss;
ss.put(SHAREABLE_HANDLE_VERSION);
ptrdiff_t offset = 0;
if (!block->expandable_segment_) {
ss.put(SHAREABLE_CUDA_MALLOC);
Block* base_block = block;
while (base_block->prev) {
base_block = base_block->prev;
}
offset = (char*)block->ptr - (char*)base_block->ptr;
cudaIpcMemHandle_t handle;
C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_block->ptr));
ss.write((char*)&handle, CUDA_IPC_HANDLE_SIZE);
} else {
ss.put(SHAREABLE_CUDA_EXPANDABLE_SEGMENT);
auto full_range = block->expandable_segment_->share(
SegmentRange(block->ptr, block->size), ss);
offset = (char*)block->ptr - (char*)full_range.ptr;
}
return ShareableHandle{offset, ss.str()};
}
void recordStream(Block* block, cuda::CUDAStream stream) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (stream.stream() == block->stream) {
// ignore uses on the allocation stream, since those don't require any
// special synchronization
return;
}
block->stream_uses.insert(stream);
if (C10_UNLIKELY(!captures_underway.empty())) {
block_to_cudagraph_stream_uses[block].insert(stream);
}
}
/** get memory fraction limiting maximum allocated memory **/
double getMemoryFraction() {
if (!set_fraction) {
return 1.0;
}
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
return static_cast<double>(allowed_memory_maximum) /
static_cast<double>(device_total);
}
/** set memory fraction to limit maximum allocated memory **/
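// Illustrative arithmetic (hypothetical device size): on a device whose
// cudaMemGetInfo total is 24 GiB, setMemoryFraction(0.5) stores
// allowed_memory_maximum = 12 GiB, and getMemoryFraction() above then
// returns 12 GiB / 24 GiB = 0.5. Because the product is truncated to
// size_t, other fractions may not round-trip exactly.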
void setMemoryFraction(double fraction) {
size_t device_free = 0;
size_t device_total = 0;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
allowed_memory_maximum =
static_cast<size_t>(fraction * static_cast<double>(device_total));
set_fraction = true;
}
/** returns cached blocks to the system allocator **/
void emptyCache() {
auto context = maybeGatherContext(RecordContext::ALL);
std::lock_guard<std::recursive_mutex> lock(mutex);
release_cached_blocks(context);
}
/** Retrieves size of largest unused block held by the memory cache **/
void cacheInfo(size_t* largest) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (*largest ==
0) { // make an initial guess if a zero *largest is passed in
size_t tmp_bytes = 0;
C10_CUDA_CHECK(cudaMemGetInfo(
largest, // Use free memory as an optimistic initial guess of *largest
&tmp_bytes));
}
cache_info_aux(large_blocks, largest);
cache_info_aux(small_blocks, largest);
for (const auto& gp : graph_pools) {
cache_info_aux(gp.second->large_blocks, largest);
cache_info_aux(gp.second->small_blocks, largest);
}
}
/** Returns a copy of the memory allocator stats **/
DeviceStats getStats() {
std::lock_guard<std::recursive_mutex> lock(mutex);
return stats;
}
/** Resets the historical accumulation stats for the device **/
void resetAccumulatedStats() {
std::lock_guard<std::recursive_mutex> lock(mutex);
for (const auto statType :
c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
stats.allocation[statType].reset_accumulated();
stats.segment[statType].reset_accumulated();
stats.active[statType].reset_accumulated();
stats.inactive_split[statType].reset_accumulated();
stats.allocated_bytes[statType].reset_accumulated();
stats.reserved_bytes[statType].reset_accumulated();
stats.active_bytes[statType].reset_accumulated();
stats.inactive_split_bytes[statType].reset_accumulated();
stats.requested_bytes[statType].reset_accumulated();
}
stats.num_alloc_retries = 0;
stats.num_ooms = 0;
stats.num_sync_all_streams = 0;
stats.num_device_alloc = 0;
stats.num_device_free = 0;
stats.oversize_allocations.reset_accumulated();
stats.oversize_segments.reset_accumulated();
}
/** Resets the historical peak stats for the device **/
void resetPeakStats() {
std::lock_guard<std::recursive_mutex> lock(mutex);
for (const auto statType :
c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
stats.allocation[statType].reset_peak();
stats.segment[statType].reset_peak();
stats.active[statType].reset_peak();
stats.inactive_split[statType].reset_peak();
stats.allocated_bytes[statType].reset_peak();
stats.reserved_bytes[statType].reset_peak();
stats.active_bytes[statType].reset_peak();
stats.inactive_split_bytes[statType].reset_peak();
stats.requested_bytes[statType].reset_peak();
}
stats.oversize_allocations.reset_peak();
stats.oversize_segments.reset_peak();
}
/* Checkpoint the state of a private pool necessary to return it to its
* current state */
std::unique_ptr<PrivatePoolState> getCheckpointState(MempoolId_t id) {
auto context = maybeGatherContext(RecordContext::ALL);
std::lock_guard<std::recursive_mutex> lock(mutex);
insert_events_deferred_until_no_capture(context);
auto pool = graph_pools.find(id);
if (pool != graph_pools.end()) {
auto private_pool_head_blocks =
get_private_pool_head_blocks(pool->second.get());
return std::make_unique<PrivatePoolState>(id, private_pool_head_blocks);
} else if (graph_pools_freeable.count(id)) {
TORCH_CHECK(false, "Not expected to checkpoint freeable graph");
} else {
TORCH_CHECK(false, "Could not find pool of id");
}
}
void freeBlocksAllocatedToPool(PrivatePool* private_pool, RestoreResult& rr) {
auto pool_blocks = get_private_pool_head_blocks(private_pool);
std::vector<Block*> head_blocks;
for (Block* block : pool_blocks) {
if (block->prev == nullptr) {
head_blocks.push_back(block);
}
}
for (Block* block : head_blocks) {
Block* curr = block;
while (curr) {
// When we free a block, the Block* itself stays valid; only its adjacent
// blocks may be merged into it. So free it first, then follow its next
// pointer.
if (curr->allocated) {
TORCH_CHECK(
curr->event_count == 0,
"Events should have synchronized when setting checkpointed block");
rr.allocations_freed.push_back(curr->ptr);
free(curr);
TORCH_CHECK(!curr->allocated);
}
curr = curr->next;
}
}
for (Block* b : get_private_pool_head_blocks(private_pool)) {
Block* curr = b;
while (curr) {
TORCH_CHECK(!curr->allocated);
curr = curr->next;
}
}
}
// checkpoint the state of an allocation that may have been
// split into multiple blocks
void setSegmentStateToCheckpoint(
Block* block,
SegmentState& segment,
const std::shared_ptr<GatheredContext>& context,
RestoreResult& rr) {
Block* curr_block = block;
Block* last_block = block;
TORCH_INTERNAL_ASSERT(block->pool);
BlockPool& pool = *block->pool;
const auto segment_len = segment.blocks.size();
// allocate all blocks in the segment
for (size_t i = 0; i < segment_len; ++i) {
// The last block in every expandable segment is the remaining amount of
// available unmapped virtual address space. We shouldn't change it but
// instead check it is correctly formed then skip over allocating it.
if (i == segment_len - 1 && curr_block->expandable_segment_) {
TORCH_CHECK(curr_block->next == nullptr);
TORCH_CHECK(!curr_block->mapped);
TORCH_CHECK(curr_block->allocated == false);
continue;
}
auto& block_state = segment.blocks.at(i);
AllocParams params(
block_state.device,
block_state.size,
block_state.stream,
&pool,
block_state.size,
stats);
pool.blocks.erase(curr_block);
params.block = curr_block;
params.stat_types = get_stat_types_for_pool(pool);
// Splitting a block depends on `max_split_size`, which may have changed
// between when the checkpoint was taken and now, so we make sure to recreate
// the behavior from the checkpoint. Keep splitting as long as there is
// space left in the block because the block is already the size of how it
// appears in the segment, so any leftover space belongs to the next
// block.
bool split = curr_block->size > block_state.size;
// curr_block will become next pointer if it is split, so reassign with
// the returned value
curr_block = alloc_found_block(params, block_state.size, context, split);
TORCH_CHECK(curr_block->ptr == block_state.ptr);
TORCH_CHECK(curr_block->size == block_state.size);
last_block = curr_block;
curr_block = curr_block->next;
TORCH_CHECK((curr_block != nullptr) == ((i + 1) < (segment_len)));
}
while (last_block->prev) {
last_block = last_block->prev;
}
// free blocks that are not allocated in the checkpoint
curr_block = last_block;
for (size_t i = 0; i < segment_len; ++i, curr_block = curr_block->next) {
if (i == segment_len - 1 && curr_block->expandable_segment_) {
TORCH_CHECK(curr_block->next == nullptr);
TORCH_CHECK(!curr_block->mapped);
TORCH_CHECK(curr_block->allocated == false);
continue;
}
auto& block_state = segment.blocks.at(i);
TORCH_INTERNAL_ASSERT(curr_block != nullptr);
if (block_state.allocated) {
rr.allocations_created.push_back(curr_block);
continue;
}
free(curr_block);
TORCH_CHECK(curr_block->ptr == block_state.ptr);
TORCH_CHECK(curr_block->allocated == block_state.allocated);
TORCH_CHECK(curr_block->size == block_state.size);
}
}
/**
* Note [Checkpointing PrivatePoolState]
*
* Refer above to Note [Interaction with CUDA graph capture]. Allocations made
* during graph capture are made from a separate private pool. During graph
* capture allocations behave as usual. During graph replay the allocator
* state does not change even as new tensors are created. The private pool
* will not free its blocks to the main caching allocator until cuda graph use
* is finished, to prevent an eager-mode allocation from clobbering the memory
* of a live but unaccounted-for tensor that was created during replay.
*
* With `make_graphed_callables`, a series of separate callables chained in
* successive cuda graphs can share a memory pool, because after a cuda graph
* recording the allocations in the shared private pool exactly reflect the
* tensors that are allocated.
*
* We would like to extend callable chaining to support a graphed callable
* tree. In this scenario, we have a tree of callable chains which will be
* captured with cuda graphs. In the diagram below, we have a tree with four
* callables, A, B, C, and D. Suppose we have captured, and subsequently
* replayed, A, B, and C. Then on a new invocation, we replay A and B, but
* would now like to record D. At this point the private pool will not reflect
* any of the live tensors created during graph replay. Allocations made
* during a new recording with the pool could overwrite those live tensors.
*
* In order to record a new graph capture after replaying prior callables in
* the tree, we need the allocator to reflect the state of the live tensors.
* We checkpoint the state of the private pool after each recording, and then
* reapply it when we are starting a new recording chain. Additionally, we
* must free the allocations for any tensors that died between the end of our
* previous graph replaying and our new recording. All of the allocated
* segments that existed in the checkpointed state must still exist in the
* pool. There may also exist new allocated blocks.
* (TODO : link note [live tensors between iterations] when it exists). For
* every block that is currently allocated but not allocated in the snapshot,
* we will return a pointer to its block.
*
*
*
* ---------------> A ---------------> B ---------------> C
* |
* |
* |
* |
* ╰ ---------------> D
*/
RestoreResult setCheckpointPoolState(PrivatePoolState& pps) {
// To reset the caching allocator state we will
// - Free all the blocks currently allocated to the pool (see [live tensors
// between iterations])
// - Allocate all the blocks in a checkpointed segment, whether they are
// live or not
// - Free the blocks in a checkpointed segment which are not live
// This could be optimized, but it nicely reuses existing APIs, and this
// is not on the hot path.
// Following the convention used elsewhere, the context is gathered outside
// the lock because we don't know what locks the recorder needs to have.
std::shared_ptr<GatheredContext> context =
maybeGatherContext(RecordContext::STATE);
std::lock_guard<std::recursive_mutex> lock(mutex);
RestoreResult rr;
TORCH_CHECK(
!graph_pools_freeable.count(pps.owner_id),
"Not expected to checkpoint freeable graph");
auto pool = graph_pools.find(pps.owner_id);
TORCH_CHECK(pool != graph_pools.end(), "Could not find private pool id");
PrivatePool* private_pool = pool->second.get();
freeBlocksAllocatedToPool(private_pool, rr);
std::unordered_map<void*, Block*> ptrs_to_blocks;
// at this point, all of the blocks should be free, so they will all be in
// the block set
for (Block* block : private_pool->small_blocks.blocks) {
ptrs_to_blocks[block->ptr] = block;
}
for (Block* block : private_pool->large_blocks.blocks) {
ptrs_to_blocks[block->ptr] = block;
}
for (auto& segment : pps.segments) {
auto ptr = segment.blocks.at(0).ptr;
TORCH_CHECK(ptrs_to_blocks.count(ptr), "Could not find ", ptr);
auto block = ptrs_to_blocks[ptr];
setSegmentStateToCheckpoint(block, segment, context, rr);
}
return rr;
}
/** Dump a complete snapshot of the memory held by the allocator. Potentially
* VERY expensive. **/
std::vector<SegmentInfo> snapshot() {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::unordered_map<PrivatePool*, MempoolId_t> pool_to_id;
pool_to_id.reserve(graph_pools.size() + graph_pools_freeable.size());
std::vector<Block*> all_blocks;
MempoolId_t mempool_id = {0, 0};
auto active_mempool = MemPoolContext::getActiveMemPool();
if (active_mempool) {
mempool_id = active_mempool->id();
}
if (mempool_id.first != 0 || mempool_id.second != 0) {
// If there is an active mempool, we find the corresponding PrivatePool
// in graph_pools and only return the blocks from it.
auto pool = graph_pools.find(mempool_id);
if (pool != graph_pools.end()) {
pool_to_id[pool->second.get()] = pool->first;
all_blocks = get_private_pool_head_blocks(pool->second.get());
}
auto pool_freeable = graph_pools_freeable.find(mempool_id);
if (pool_freeable != graph_pools_freeable.end()) {
pool_to_id[pool_freeable->second] = pool_freeable->first;
}
} else {
// When snapshot is called outside a MemPoolContext, we return
// all the blocks in the CUDACachingAllocator (as returned by
// get_all_blocks).
for (const auto& pair : graph_pools) {
pool_to_id[pair.second.get()] = pair.first;
}
for (const auto& pair : graph_pools_freeable) {
pool_to_id[pair.second] = pair.first;
}
all_blocks = get_all_blocks();
}
size_t total_active = 0;
std::vector<SegmentInfo> result;
for (const Block* const head_block : all_blocks) {
// For expandable segments, we report one segment for each contiguous
// mapped range of memory
if (head_block->prev && head_block->prev->mapped) {
continue;
}
result.emplace_back();
SegmentInfo& segment_info = result.back();
segment_info.device = head_block->device;
segment_info.address = reinterpret_cast<size_t>(head_block->ptr);
segment_info.stream = head_block->stream;
segment_info.is_large = (!head_block->pool->is_small);
segment_info.is_expandable = head_block->expandable_segment_;
segment_info.context_when_allocated =
head_block->context_when_segment_allocated;
auto id = pool_to_id.find(head_block->pool->owner_PrivatePool);
if (id != pool_to_id.end()) {
segment_info.owner_private_pool_id = id->second;
}
const Block* block = head_block;
while (block != nullptr && block->mapped) {
segment_info.blocks.emplace_back();
BlockInfo& block_info = segment_info.blocks.back();
block_info.size = block->size;
block_info.requested_size = block->requested_size;
block_info.allocated = block->allocated;
block_info.active = block->allocated || (block->event_count > 0) ||
!block->stream_uses.empty();
segment_info.total_size += block_info.size;
if (block_info.allocated) {
segment_info.allocated_size += block_info.size;
}
if (block_info.active) {
segment_info.active_size += block_info.size;
segment_info.requested_size += block_info.requested_size;
}
block_info.context_when_allocated = block->context_when_allocated;
block = block->next;
}
total_active += segment_info.active_size;
}
std::sort(
result.begin(),
result.end(),
[](const SegmentInfo& a, const SegmentInfo& b) {
return a.address < b.address;
});
record_trace(TraceEntry::SNAPSHOT, 0, total_active, nullptr, 0, nullptr);
return result;
}
std::vector<TraceEntry> trace(
const std::function<time_t(approx_time_t)>& tsc_to_us) {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<TraceEntry> result;
alloc_buffer.getEntries(result);
// Convert all the timestamps from tsc to epoch time in microseconds.
for (auto& te : result) {
te.time_.t_ = tsc_to_us(te.time_.approx_t_);
}
return result;
}
// This function takes the size and number of divisions arguments and rounds
// the size up to the nearest power-of-2 division.
// For example, to round up 1200 with 4 divisions: 1200 lies between 1024
// and 2048, and dividing that range into 4 gives 1024, 1280, 1536, and
// 1792. So the function returns 1280 as the nearest power-of-2 division
// ceiling.
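// Step-by-step trace of that example through the code below (assuming a
// 64-bit size_t, so countLeadingZeros(4) == 61):
//   power-of-2 floor  = PowerOf2Floor(1200)   = 1024
//   division step     = 1024 >> (63 - 61)     = 256
//   rounded-down size = 1200 & ~(256 - 1)     = 1024
//   1024 != 1200, so the result is 1024 + 256 = 1280
// Note that 63 - countLeadingZeros(divisions) is floor(log2(divisions)),
// so a non-power-of-2 divisions value (e.g. 6) behaves like the next lower
// power of 2 (e.g. 4).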
static size_t roundup_power2_next_division(size_t size, size_t divisions) {
if (llvm::isPowerOf2_64(size)) {
return size;
}
TORCH_CHECK(divisions >= 2, "Only 2 or more divisions are supported");
// Divide the space between consecutive powers of 2 into equal divisions.
// If the division step is zero, return the power-of-2 ceiling.
size_t power2_floor = llvm::PowerOf2Floor(size);
size_t power2_division =
power2_floor >> (63 - llvm::countLeadingZeros(divisions));
if (C10_UNLIKELY(power2_division == 0)) {
return (power2_floor << 1);
}
size_t round_size_floor = size & (~(power2_division - 1));
return (round_size_floor == size) ? size
: round_size_floor + power2_division;
}
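// Worked examples for round_size, assuming kMinBlockSize is 512 bytes (it is
// defined elsewhere in this file) and hypothetical roundup_power2_divisions
// settings:
//   round_size(100) -> 512, the minimum block size.
//   round_size(1000) with divisions <= 1 -> 512 * ((1000 + 511) / 512)
//     = 512 * 2 = 1024.
//   round_size(9000) with divisions == 4 -> 9000 > 512 * 4, so the
//     power-of-2 path is taken: 8192 + 2048 = 10240 (plain 512-byte
//     rounding would have given 9216).
//   round_size(1200) with divisions == 4 -> 1200 <= 512 * 4, so plain
//     512-byte rounding applies: 512 * 3 = 1536, not the 1280 that
//     roundup_power2_next_division(1200, 4) returns on its own.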
static size_t round_size(size_t size) {
if (size < kMinBlockSize) {
return kMinBlockSize;
} else {
auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
if (divisions > 1 && size > (kMinBlockSize * divisions)) {
return roundup_power2_next_division(size, divisions);
} else {
return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}
}
}
void ensureExistsAndIncrefPool(MempoolId_t mempool_id) {
// Create a PrivatePool object if it does not exist yet
// and increment its use_count
std::lock_guard<std::recursive_mutex> lock(mutex);
ensure_exists_and_incref_pool(mempool_id);
}
// See Note [Interaction with CUDA graph capture]
// Called by CUDAGraph::capture_begin
void beginAllocateToPool(
MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) {
std::lock_guard<std::recursive_mutex> lock(mutex);
ensure_exists_and_incref_pool(mempool_id);
for (auto it2 = captures_underway.begin(); it2 != captures_underway.end();
++it2) {
TORCH_CHECK(
it2->first != mempool_id,
"beginAllocateToPool: already recording to mempool_id");
}
captures_underway.emplace_back(mempool_id, std::move(filter));
}
// Called by CUDAGraph::capture_end
void endAllocateToPool(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
for (auto it = captures_underway.begin(); it != captures_underway.end();
++it) {
if (it->first == mempool_id) {
captures_underway.erase(it);
return;
}
}
TORCH_CHECK(
false, "endAllocateToPool: not currently recording to mempool_id");
}
// Called by CUDAGraph::reset and MemPool::~MemPool()
void releasePool(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
// The instantiated cudaGraphExec_t has been destroyed. We can't blindly
// delete and cudaFree the mempool its capture used, because
// 1. other graph(s) might share the same pool
// 2. the user might still hold references to output tensors allocated
// during capture.
// To handle 1 and 2, we track the number of graphs using this particular
// mempool. When the count reaches 0, we tell free_cached_blocks it may now
// cudaFree blocks from this graph's pool when it discovers they're unused
// (unsplit).
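//
// Illustrative lifecycle (hypothetical sequence): two graphs captured into
// the same mempool_id raise use_count from 1 (pool created) to 2. The first
// releasePool() drops it to 1 and the pool is kept; the second drops it to
// 0, which inserts the pool into graph_pools_freeable so free_cached_blocks
// may later cudaFree its unused, unsplit blocks.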
auto pp = get_private_pool(mempool_id);
auto uc = --(pp->use_count);
TORCH_INTERNAL_ASSERT(uc >= 0);
if (uc == 0) {
// Allows free_cached_blocks to begin cudaFreeing this pool's memory,
// and makes sure this pool wasn't somehow made freeable already.
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
bool inserted = graph_pools_freeable.insert({mempool_id, pp}).second;
TORCH_INTERNAL_ASSERT(inserted);
}
}
int getPoolUseCount(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
auto pp = get_private_pool(mempool_id);
return pp->use_count;
}
void addPeerAccess(c10::DeviceIndex dev_to_access) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (std::find(
devices_with_peer_access_.begin(),
devices_with_peer_access_.end(),
dev_to_access) != devices_with_peer_access_.end()) {
return;
}
devices_with_peer_access_.push_back(dev_to_access);
for (auto& es : expandable_segments_) {
es->addPeer(dev_to_access);
}
}
std::vector<c10::DeviceIndex> peers() const {
std::lock_guard<std::recursive_mutex> lock(mutex);
return devices_with_peer_access_;
}
bool hasAllocatedExpandableSegments() const {
return !expandable_segments_.empty();
}
private:
// None of the private methods acquire the allocator mutex.
std::vector<Block*> get_all_blocks() const {
std::vector<Block*> blocks;
blocks.insert(
blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end());
blocks.insert(
blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end());
for (const auto& gp : graph_pools) {
blocks.insert(
blocks.end(),
gp.second->small_blocks.blocks.begin(),
gp.second->small_blocks.blocks.end());
blocks.insert(
blocks.end(),
gp.second->large_blocks.blocks.begin(),
gp.second->large_blocks.blocks.end());
}
blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end());
return blocks;
}
std::vector<Block*> get_private_pool_head_blocks(PrivatePool* pool) const {
std::vector<Block*> blocks;
for (Block* b : active_blocks) {
if ((b->pool == &pool->small_blocks || b->pool == &pool->large_blocks) &&
b->prev == nullptr) {
blocks.push_back(b);
}
}
for (Block* b : pool->small_blocks.blocks) {
if (b->prev == nullptr) {
blocks.push_back(b);
}
}
for (Block* b : pool->large_blocks.blocks) {
if (b->prev == nullptr) {
blocks.push_back(b);
}
}
return blocks;
}
void ensure_exists_and_incref_pool(MempoolId_t mempool_id) {
auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool.
// Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool
// usage. use_count is initially 1, which means the pool is
// being used since somebody called ensureExistsAndIncrefPool.
graph_pools.emplace(mempool_id, std::make_unique<PrivatePool>());
} else {
// mempool_id references an existing pool, which the current CUDAGraph
// capture or torch.cuda.use_mem_pool will
// share. Check this pool is live (at least one other capture already
// references it). Increment it to establish the usage.
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
it->second->use_count++;
}
}
PrivatePool* get_private_pool(MempoolId_t mempool_id) {
auto it = graph_pools.find(mempool_id);
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
return it->second.get();
}
// Returns the block at the smallest possible address, in any segment, from
// which there is enough free address space to fit size. That space may be
// composed of free and unmapped blocks.
Block* find_expandable_block(
c10::DeviceIndex device,
cudaStream_t stream,
BlockPool* pool,
size_t size) {
Block key(device, stream, 0);
auto allocatable = [](Block* b) {
return b && !b->allocated && b->event_count == 0 &&
b->stream_uses.empty();
};
auto has_available_address_space = [&](Block* b) {
size_t bytes = 0;
while (bytes < size && allocatable(b)) {
bytes += b->size;
b = b->next;
}
return bytes >= size;
};
for (auto it = pool->unmapped.lower_bound(&key);
it != pool->unmapped.end() && (*it)->stream == stream;
++it) {
Block* c = *it;
// we found the lowest address of an unmapped segment
// but there might be a free segment we can also use
// right before it
if (allocatable(c->prev)) {
c = c->prev;
}
if (has_available_address_space(c)) {
return c;
}
}
auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
cudaDeviceProp prop{};
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
// We allocate enough address space for 1 1/8 times the total memory on the GPU.
// This allows for some cases where we have to unmap pages earlier in the
// segment to put them at the end.
size_t address_space_size = prop.totalGlobalMem + prop.totalGlobalMem / 8;
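// For example (hypothetical device), totalGlobalMem = 80 GiB gives
// address_space_size = 80 GiB + 80 GiB / 8 = 90 GiB of address space;
// physical memory is mapped into it later, on demand, via map_block.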
expandable_segments_.emplace_back(new ExpandableSegment(
device,
stream,
address_space_size,
segment_size,
devices_with_peer_access_));
ExpandableSegment* es = expandable_segments_.back();
Block* candidate = new Block(device, stream, es->size(), pool, es->ptr());
candidate->mapped = false;
candidate->expandable_segment_ = es;
pool->unmapped.insert(candidate);
return candidate;
}
bool map_block(
Block* to_map,
size_t size,
const std::shared_ptr<GatheredContext>& ctx) {
TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
TORCH_INTERNAL_ASSERT(
!to_map->context_when_allocated); // unmapped blocks should not keep
// history
auto mapped_range =
to_map->expandable_segment_->map(SegmentRange{to_map->ptr, size});
// failed to map the memory
if (mapped_range.size == 0) {
return false;
}
TORCH_INTERNAL_ASSERT(
mapped_range.ptr == to_map->ptr && mapped_range.size >= size);
BlockPool& pool = *to_map->pool;
pool.unmapped.erase(to_map);
to_map->mapped = true;
if (mapped_range.size < to_map->size) {
// to_map -> remaining -> to_map->next(?)
Block* remaining = new Block(
to_map->device,
to_map->stream,
to_map->size - mapped_range.size,
&pool,
static_cast<char*>(to_map->ptr) + mapped_range.size);
remaining->mapped = false;
remaining->expandable_segment_ = to_map->expandable_segment_;
remaining->splice(to_map, to_map->next);
pool.unmapped.insert(remaining);
to_map->size = mapped_range.size;
}
try_merge_blocks(to_map, to_map->prev, pool);
try_merge_blocks(to_map, to_map->next, pool);
pool.insert_into_blocks(to_map);
// update statistics
total_allocated_memory += mapped_range.size;
StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].increase(mapped_range.size);
});
auto reserved_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
reserved_bytes_gauge.record(
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);
stats.num_device_alloc++;
record_trace(
TraceEntry::SEGMENT_MAP,
int64_t(mapped_range.ptr),
mapped_range.size,
to_map->stream,
to_map->device,
ctx);
if (!to_map->prev && !to_map->context_when_segment_allocated) {
to_map->context_when_segment_allocated = ctx;
}
return true;
}
Block* try_allocate_expandable_block(
c10::DeviceIndex device,
cudaStream_t stream,
BlockPool* pool,
size_t size,
const std::shared_ptr<GatheredContext>& ctx) {
Block* candidate = find_expandable_block(device, stream, pool, size);
// Candidate is now a list of free/unmapped blocks with at least size room:
// unmapped -> null
// unmapped -> free -> *
// free -> unmapped -> *
if (!candidate->mapped &&
!map_block(candidate, std::min(candidate->size, size), ctx)) {
return nullptr;
}
TORCH_INTERNAL_ASSERT(candidate->mapped);
while (candidate->size < size) {
// invariant: free -> unmapped -> *
// map_block will map some of unmapped and merge with free
auto remaining = size - candidate->size;
auto new_candidate = candidate->next;
if (!map_block(
new_candidate, std::min(remaining, candidate->next->size), ctx)) {
return nullptr;
}
candidate = new_candidate;
}
pool->blocks.erase(candidate);
return candidate;
}
/** moves a block into a pool of cached free blocks */
void free_block(
Block* block,
const std::shared_ptr<GatheredContext>& context) {
TORCH_INTERNAL_ASSERT(
!block->allocated && block->event_count == 0 &&
block->stream_uses.empty());
record_trace(
TraceEntry::FREE_COMPLETED,
int64_t(block->ptr),
block->requested_size,
block->stream,
block->device,
context ? context : block->context_when_allocated);
block->context_when_allocated = nullptr;
size_t original_block_size = block->size;
size_t requested_size = block->requested_size;
auto& pool = *block->pool;
int64_t net_change_inactive_split_blocks = 0;
int64_t net_change_inactive_split_size = 0;
const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
for (Block* merge_candidate : merge_candidates) {
const auto subsumed_size = try_merge_blocks(block, merge_candidate, pool);
if (subsumed_size > 0) {
net_change_inactive_split_blocks -= 1;
net_change_inactive_split_size -= static_cast<int64_t>(subsumed_size);
}
}
active_blocks.erase(block);
// Makes sure the Block* isn't already present in the pool we're freeing it
// back into.
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
bool inserted = pool.insert_into_blocks(block).second;
TORCH_INTERNAL_ASSERT(inserted);
if (block->is_split()) {
net_change_inactive_split_blocks += 1;
net_change_inactive_split_size += static_cast<int64_t>(block->size);
}
StatTypes stat_types = get_stat_types_for_pool(pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
// inactive_split tries to capture the idea that blocks
// cannot be freed when requested, but fully free pages
// of expandable blocks can always be freed.
// The logic to track this as a statistic is pretty involved,
// so we simply exclude expandable segments from
// inactive_split.
if (!block->expandable_segment_) {
if (net_change_inactive_split_blocks > 0) {
stats.inactive_split[stat_type].increase(
static_cast<size_t>(net_change_inactive_split_blocks));
} else if (net_change_inactive_split_blocks < 0) {
stats.inactive_split[stat_type].decrease(
static_cast<size_t>(-net_change_inactive_split_blocks));
}
if (net_change_inactive_split_size > 0) {
stats.inactive_split_bytes[stat_type].increase(
static_cast<size_t>(net_change_inactive_split_size));
} else if (net_change_inactive_split_size < 0) {
stats.inactive_split_bytes[stat_type].decrease(
static_cast<size_t>(-net_change_inactive_split_size));
}
}
stats.active[stat_type].decrease(1);
stats.active_bytes[stat_type].decrease(original_block_size);
stats.requested_bytes[stat_type].decrease(requested_size);
});
}
/** Combines previously split blocks. Returns the size of the subsumed block,
 * or 0 on failure. */
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated || src->event_count > 0 ||
!src->stream_uses.empty() || dst->mapped != src->mapped) {
return 0;
}
AT_ASSERT(dst->is_split() && src->is_split());
if (dst->prev == src) { // [src dst]
dst->ptr = src->ptr;
dst->prev = src->prev;
if (dst->prev) {
dst->prev->next = dst;
}
dst->context_when_segment_allocated =
std::move(src->context_when_segment_allocated);
} else { // [dst src]
dst->next = src->next;
if (dst->next) {
dst->next->prev = dst;
}
}
const size_t subsumed_size = src->size;
dst->size += subsumed_size;
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
auto erased =
src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
delete src;
return subsumed_size;
}
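// Illustrative walk-through of try_merge_blocks() (a hand-traced sketch with
// hypothetical sizes, not part of the original design notes): suppose a
// segment is split as
//   [P: 2 MiB, free] <-> [B: 4 MiB, being freed] <-> [N: 6 MiB, allocated]
// free_block(B) calls try_merge_blocks(B, P, pool) and then
// try_merge_blocks(B, N, pool). The first call takes the "[src dst]" branch:
// B->ptr becomes P->ptr, B->size grows to 6 MiB, P is erased from
// pool.blocks and deleted, and 2 MiB is returned. The second call returns 0
// because N is still allocated, so the segment ends up as
//   [B: 6 MiB, free] <-> [N: 6 MiB, allocated]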
BlockPool& get_pool(size_t size, cudaStream_t stream) {
// captures_underway is a conservative guess that the current stream may be
// capturing. It's only non-empty if some thread has begun and not yet ended
// a capture, so it's usually 0, and we can short-circuit
// cudaStreamCaptureStatus (which does a TLS lookup).
if (C10_UNLIKELY(!captures_underway.empty())) {
for (auto& entry : captures_underway) {
if (entry.second(stream)) {
auto it1 = graph_pools.find(entry.first);
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
if (size <= kSmallSize) {
return it1->second->small_blocks;
} else {
return it1->second->large_blocks;
}
}
}
}
if (size <= kSmallSize) {
return small_blocks;
} else {
return large_blocks;
}
}
StatTypes get_stat_types_for_pool(const BlockPool& pool) {
StatTypes stat_types = {false};
stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
stat_types[static_cast<size_t>(
pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL)] = true;
return stat_types;
}
bool should_split(const Block* block, size_t size) {
size_t remaining = block->size - size;
if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) {
return remaining >= kMinBlockSize;
} else {
return (size < CUDAAllocatorConfig::max_split_size()) &&
(remaining > kSmallSize);
}
}
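// Worked examples for should_split() (a sketch assuming the default constants
// kMinBlockSize = 512 B and kSmallSize = 1 MiB, the default max_split_size()
// of SIZE_MAX, and expandable segments disabled):
//  - small pool, 2 MiB block, 1 MiB request: remaining = 1 MiB >= 512 B,
//    so the block is split.
//  - large pool, 20 MiB block, 5 MiB request: remaining = 15 MiB > 1 MiB,
//    so the block is split.
//  - large pool, 12 MiB block, 11.5 MiB request: remaining = 0.5 MiB is not
//    > 1 MiB, so the whole 12 MiB block is handed out unsplit.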
static size_t get_allocation_size(size_t size) {
if (size <= kSmallSize) {
return kSmallBuffer;
} else if (size < kMinLargeAlloc) {
return kLargeBuffer;
} else {
return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
}
}
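// Worked examples for get_allocation_size() (a sketch assuming the default
// constants kSmallSize = 1 MiB, kSmallBuffer = 2 MiB, kMinLargeAlloc = 10 MiB,
// kLargeBuffer = 20 MiB and kRoundLarge = 2 MiB):
//   512 KiB request -> 2 MiB segment (small buffer)
//   5 MiB request   -> 20 MiB segment (large buffer; the remainder can be
//                      split off for later requests)
//   10 MiB request  -> 10 MiB segment (already a multiple of kRoundLarge)
//   11 MiB request  -> 12 MiB segment (rounded up to the next 2 MiB)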
bool get_free_block(AllocParams& p) {
BlockPool& pool = *p.pool;
if (C10_UNLIKELY(
set_fraction &&
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
// Track block reuse interval only when garbage collection is enabled.
++pool.get_free_blocks_call_count;
}
auto it = pool.blocks.lower_bound(&p.search_key);
if (it == pool.blocks.end() || (*it)->stream != p.stream())
return false;
if ((*it)->expandable_segment_) {
if (CUDAAllocatorConfig::expandable_segments()) {
// If we are allocating from the part of the block that is expandable,
// then for the purposes of "best fit" we consider its size to be the
// size it can expand to, not the size it currently is. This means that
// we sometimes have to search for blocks with bigger 'size' before
// choosing this segment.
auto expandable_size = [](Block* b) {
return b->size + (b->next && !b->next->mapped ? b->next->size : 0);
};
auto next = it;
next++;
while ((*it)->expandable_segment_ && next != pool.blocks.end() &&
(*next)->stream == p.stream() &&
expandable_size(*next) < expandable_size(*it)) {
it = next++;
}
} else {
// Rarely, expandable segments have been turned off after we have
// already allocated some blocks as expandable. For instance,
// since we cannot share expandable memory via IPC, someone might
// temporarily disable it. In this case we need to honor this request
// by only finding non-expandable blocks.
do {
it++;
} while (it != pool.blocks.end() && (*it)->expandable_segment_ &&
(*it)->stream == p.stream());
if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
return false;
}
}
}
// Do not return an oversized block for a large request
if ((p.size() < CUDAAllocatorConfig::max_split_size()) &&
((*it)->size >= CUDAAllocatorConfig::max_split_size()))
return false;
// Allow oversized block size to be rounded up but within a limit
if ((p.size() >= CUDAAllocatorConfig::max_split_size()) &&
((*it)->size >=
p.size() + CUDAAllocatorConfig::max_non_split_rounding_size()))
return false;
p.block = *it;
pool.blocks.erase(it);
return true;
}
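// Illustrative example of the expandable "best fit" rule in get_free_block()
// (hypothetical sizes; assumes expandable segments are enabled and these are
// the only two free blocks for the stream in this pool): a mapped free block
// of 4 MiB whose next neighbor is an unmapped 16 MiB range has
// expandable_size() = 20 MiB, while a mapped free block of 6 MiB whose next
// neighbor is mapped (or absent) has expandable_size() = 6 MiB. For a 3 MiB
// request, lower_bound() first lands on the 4 MiB block, and the search then
// advances to the 6 MiB block because its expandable size is smaller,
// leaving the 4 MiB block and its unmapped 16 MiB reserve available.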
bool trigger_free_memory_callbacks(AllocParams& p) {
bool freed_memory = false;
for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) {
freed_memory |=
FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute();
}
return freed_memory;
}
void garbage_collect_cached_blocks(
const std::shared_ptr<GatheredContext>& context) {
// Free unused cached blocks to reclaim GPU memory.
// Unlike release_cached_blocks(), this does not enforce synchronization and
// therefore should have lower overhead.
size_t gc_threshold = static_cast<size_t>(
CUDAAllocatorConfig::garbage_collection_threshold() *
static_cast<double>(allowed_memory_maximum));
// No need to trigger GC yet
if (total_allocated_memory <= gc_threshold) {
return;
}
const auto target_size = total_allocated_memory - gc_threshold;
size_t gc_reclaimed = 0;
// Calculate the total age of the free-able blocks. We'll use it later to
// get "avg age" threshold.
size_t total_age = 0;
int freeable_block_count = 0;
for (auto& b : large_blocks.blocks) {
if (!b->is_split()) {
total_age += b->gc_count();
++freeable_block_count;
}
}
// No free-able blocks?
if (freeable_block_count == 0) {
return;
}
// Repeat GC until the reclaimed size reaches the target size.
bool block_freed = true;
while (gc_reclaimed < target_size && block_freed &&
freeable_block_count > 0) {
// Free blocks exceeding this age threshold first.
double age_threshold =
static_cast<double>(total_age) / freeable_block_count;
// Stop iteration if we can no longer free a block.
block_freed = false;
// Free blocks whose age is >= the average age. Don't stop upon reaching the
// target_size; we don't want this GC to be triggered frequently.
auto it = large_blocks.blocks.begin();
while (it != large_blocks.blocks.end()) {
Block* block = *it;
++it;
if (!block->is_split() && !block->expandable_segment_ &&
static_cast<double>(block->gc_count()) >= age_threshold) {
block_freed = true;
gc_reclaimed += block->size;
total_age -= block->gc_count(); // Decrement the age
freeable_block_count--; // One less block that can be freed
release_block(block, context);
}
}
}
}
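// Worked example of the GC arithmetic above (illustrative numbers only):
// with garbage_collection_threshold() = 0.6 and allowed_memory_maximum =
// 10 GiB, gc_threshold is 6 GiB. If total_allocated_memory is 8 GiB, then
// target_size is 2 GiB. Given three unsplit, non-expandable cached blocks
// whose gc_count() ages are 1, 3 and 5, the first pass uses
// age_threshold = 9 / 3 = 3 and releases the blocks aged 3 and 5. If that
// has not yet reclaimed 2 GiB, the next pass uses age_threshold = 1 / 1 = 1
// and also releases the remaining block.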
// This function assumes that the global lock has been taken while calling
// into this function. We make a synchronous cudaMalloc call in this function,
// which can be expensive while holding the lock. Hence, we pass in the lock so
// the function can temporarily release it before the cudaMalloc call and
// reacquire it after the call, so that other threads don't get blocked.
bool alloc_block(
AllocParams& p,
bool isRetry,
const std::shared_ptr<GatheredContext>& ctx,
std::unique_lock<std::recursive_mutex>& lock) {
// Defensively checks for preexisting CUDA error state.
C10_CUDA_CHECK(cudaGetLastError());
size_t size = p.alloc_size;
void* ptr = nullptr;
if (isRetry) {
stats.num_alloc_retries += 1;
}
#ifdef FBCODE_CAFFE2
bool in_fbcode = true;
#else
bool in_fbcode = false;
#endif
if (set_fraction &&
total_allocated_memory + size > allowed_memory_maximum) {
p.err = cudaErrorMemoryAllocation;
return false;
// Temporarily disable checkpointing & cudagraphs internally
} else if (
CUDAAllocatorConfig::expandable_segments() &&
!(in_fbcode && p.pool->owner_PrivatePool)) {
p.block = try_allocate_expandable_block(
p.device(), p.stream(), p.pool, p.size(), ctx);
if (p.block) {
p.err = cudaSuccess;
if (p.pool->owner_PrivatePool) {
// The block is for a CUDA graph's PrivatePool.
p.pool->owner_PrivatePool->cudaMalloc_count++;
}
} else {
p.err = cudaErrorMemoryAllocation;
}
return bool(p.block);
} else {
auto active_pool = MemPoolContext::getActiveMemPool();
if (active_pool && active_pool->allocator() &&
p.pool->owner_PrivatePool) {
// Ensure that active_pool and p.pool are the same
auto pp = get_private_pool(active_pool->id());
TORCH_INTERNAL_ASSERT(pp == p.pool->owner_PrivatePool);
}
if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
// At scope exit, acquire the lock again. This provides safety against
// any potential exceptions in the cudaMallocMaybeCapturing function.
auto sg = c10::make_scope_exit([&]() { lock.lock(); });
lock.unlock();
p.err = cudaMallocMaybeCapturing(&ptr, size, p);
} else {
p.err = cudaMallocMaybeCapturing(&ptr, size, p);
}
if (CUDAAllocatorConfig::release_lock_on_cudamalloc()) {
TORCH_CHECK(
lock.owns_lock(), "Failed to acquire lock after cudaMalloc");
}
if (p.err != cudaSuccess) {
if (p.err == cudaErrorMemoryAllocation) {
// If this is the first attempt (!isRetry), we can forgive and clear
// CUDA's internal error state.
//
// If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH
// will take over to throw a helpful exception. The user can choose
// to catch the exception, free some stuff in their script, and
// attempt the allocation again. In this case, we can also forgive and
// clear CUDA's internal error state.
(void)cudaGetLastError();
} else {
// If the error's unrelated to memory allocation, we should throw
// immediately.
C10_CUDA_CHECK(p.err);
}
return false;
}
}
if (p.pool->owner_PrivatePool) {
// The block is for a CUDA graph's PrivatePool.
p.pool->owner_PrivatePool->cudaMalloc_count++;
}
total_allocated_memory += size;
p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
stats.segment[stat_type].increase(1);
stats.reserved_bytes[stat_type].increase(size);
});
if (size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_segments.increase(1);
auto reserved_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
reserved_bytes_gauge.record(
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);
// p.block came from new, not cudaMalloc. It should not be nullptr here.
TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
stats.num_device_alloc++;
record_trace(
TraceEntry::SEGMENT_ALLOC,
int64_t(p.block->ptr),
p.block->size,
p.stream(),
p.device(),
ctx);
p.block->context_when_segment_allocated = ctx;
return true;
}
/** Free one or more oversize blocks to the system allocator, but only enough
 * to satisfy the target size. */
bool release_available_cached_blocks(
const AllocParams& p,
const std::shared_ptr<GatheredContext>& context) {
if (CUDAAllocatorConfig::max_split_size() ==
std::numeric_limits<size_t>::max())
return false;
BlockPool& pool = *p.pool;
// because of std::unique_ptr, block cannot be trivially copied
// Use constructor for search key.
Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
key.size = (key.size < CUDAAllocatorConfig::max_split_size())
? CUDAAllocatorConfig::max_split_size()
: key.size;
auto it = pool.blocks.lower_bound(&key);
if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
(*it)->expandable_segment_) {
// No single block is large enough; free multiple oversize blocks,
// starting with the largest
if (it == pool.blocks.begin())
return false;
size_t totalReleased = 0;
--it; // Back up one item. Now on the largest block for the correct
// stream
while ((totalReleased < key.size) &&
((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
((*it)->stream == p.stream())) {
auto cur = it;
bool is_first = cur == pool.blocks.begin();
if (!is_first) {
--it;
}
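if (!(*cur)->expandable_segment_) {
// Read the size before release_block() deletes the block and erases it
// from pool.blocks (which also invalidates `cur`).
totalReleased += (*cur)->size;
release_block(*cur, context);
}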
if (is_first) {
break;
}
}
if (totalReleased < key.size)
return false;
} else {
release_block(*it, context);
}
return true;
}
bool release_cached_blocks(const std::shared_ptr<GatheredContext>& context) {
MempoolId_t mempool_id = {0, 0};
auto active_mempool = MemPoolContext::getActiveMemPool();
if (active_mempool) {
mempool_id = active_mempool->id();
}
if (mempool_id.first == 0 && mempool_id.second == 0) {
// If there is no active mempool, we work on releasing *all* blocks.
// First ensure that all blocks that can't currently be allocated due to
// outstanding events are returned to the pool.
synchronize_and_free_events(context);
// Free all non-split cached blocks to system allocator
release_blocks(large_blocks, context);
release_blocks(small_blocks, context);
}
for (auto it = graph_pools_freeable.begin();
it != graph_pools_freeable.end();) {
if (mempool_id.first != 0 || mempool_id.second != 0) {
if (it->first == mempool_id) {
// If there is an active mempool, we sync only the events
// associated with the pool
synchronize_and_free_events(context, it->second);
} else {
// otherwise we move on
++it;
continue;
}
}
// See notifyCaptureDestroy for the strategy here.
TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
release_blocks(it->second->small_blocks, context);
release_blocks(it->second->large_blocks, context);
if (it->second->cudaMalloc_count == 0) {
auto erase_count = graph_pools.erase(it->first);
TORCH_INTERNAL_ASSERT(erase_count == 1);
it = graph_pools_freeable.erase(it);
} else {
++it;
}
}
return true;
}
void release_expandable_segment(Block* block) {
TORCH_INTERNAL_ASSERT(
block->size == block->expandable_segment_->size(),
"block disagrees with segment");
TORCH_INTERNAL_ASSERT(!block->mapped);
auto it = std::find(
expandable_segments_.begin(),
expandable_segments_.end(),
block->expandable_segment_);
TORCH_INTERNAL_ASSERT(it != expandable_segments_.end());
expandable_segments_.erase(it);
block->pool->unmapped.erase(block);
delete block->expandable_segment_;
delete block;
}
void release_block(
Block* block,
const std::shared_ptr<GatheredContext>& context) {
TORCH_INTERNAL_ASSERT(!block->expandable_segment_);
stats.num_device_free++;
record_trace(
TraceEntry::SEGMENT_FREE,
int64_t(block->ptr),
block->size,
block->stream,
block->device,
context ? context : block->context_when_segment_allocated);
auto* pool = block->pool;
auto active_pool = MemPoolContext::getActiveMemPool();
if (active_pool && active_pool->allocator() && pool->owner_PrivatePool) {
// Ensure that active_pool and pool are the same
auto pp = get_private_pool(active_pool->id());
TORCH_INTERNAL_ASSERT(pp == pool->owner_PrivatePool);
// If there is an active mempool with a given allocator,
// we use the given allocator's delete function.
active_pool->allocator()->raw_delete((void*)block->ptr);
} else {
C10_CUDA_CHECK(cudaFree((void*)block->ptr));
}
total_allocated_memory -= block->size;
if (pool->owner_PrivatePool) {
// The cudaFreed block belonged to a CUDA graph's PrivatePool.
TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->cudaMalloc_count > 0);
pool->owner_PrivatePool->cudaMalloc_count--;
}
StatTypes stat_types = get_stat_types_for_pool(*pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.segment[stat_type].decrease(1);
stats.reserved_bytes[stat_type].decrease(block->size);
});
auto reserved_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
reserved_bytes_gauge.record(
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);
if (block->size >= CUDAAllocatorConfig::max_split_size())
stats.oversize_segments.decrease(1);
pool->blocks.erase(block);
delete block;
}
void unmap_block(
Block* block,
const std::shared_ptr<GatheredContext>& context) {
auto unmapped = block->expandable_segment_->unmap(
SegmentRange{block->ptr, block->size});
if (unmapped.size == 0) {
return;
}
block->pool->blocks.erase(block);
ptrdiff_t before_size =
static_cast<char*>(unmapped.ptr) - static_cast<char*>(block->ptr);
if (before_size > 0) {
// prev? -> before_free -> block
Block* before_free = new Block(
block->device, block->stream, before_size, block->pool, block->ptr);
before_free->expandable_segment_ = block->expandable_segment_;
before_free->splice(block->prev, block);
block->pool->insert_into_blocks(before_free);
}
auto after_size = block->size - (before_size + unmapped.size);
if (after_size > 0) {
// block -> after_free -> next?
Block* after_free = new Block(
block->device,
block->stream,
after_size,
block->pool,
static_cast<char*>(unmapped.ptr) + unmapped.size);
after_free->expandable_segment_ = block->expandable_segment_;
after_free->splice(block, block->next);
block->pool->insert_into_blocks(after_free);
}
block->ptr = unmapped.ptr;
block->size = unmapped.size;
block->mapped = false;
try_merge_blocks(block, block->prev, *block->pool);
try_merge_blocks(block, block->next, *block->pool);
block->pool->unmapped.insert(block);
// update statistics
total_allocated_memory -= unmapped.size;
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
stats.reserved_bytes[stat_type].decrease(unmapped.size);
});
auto reserved_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
reserved_bytes_gauge.record(
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);
if (block->pool->owner_PrivatePool) {
// The cudaFreed block belonged to a CUDA graph's PrivatePool.
TORCH_INTERNAL_ASSERT(
block->pool->owner_PrivatePool->cudaMalloc_count > 0);
block->pool->owner_PrivatePool->cudaMalloc_count--;
}
stats.num_device_free++;
record_trace(
TraceEntry::SEGMENT_UNMAP,
int64_t(unmapped.ptr),
unmapped.size,
block->stream,
block->device,
context ? context : block->context_when_segment_allocated);
}
void release_blocks(
BlockPool& pool,
const std::shared_ptr<GatheredContext>& context) {
std::vector<Block*> to_unmap;
// Frees all non-split blocks
auto it = pool.blocks.begin();
while (it != pool.blocks.end()) {
Block* block = *it;
++it;
if (block->expandable_segment_) {
// unmapping will mutate the free pool
// so just gather what needs to be freed
// to avoid invalidating the iterator
to_unmap.push_back(block);
} else if (!block->prev && !block->next) {
release_block(block, context);
}
}
for (Block* block : to_unmap) {
unmap_block(block, context);
if (!block->prev && !block->next) {
release_expandable_segment(block);
}
}
}
EventPool::Event create_event_internal(c10::DeviceIndex idx) {
// Leak the event pool to avoid shutdown issues.
static auto* event_pool = new EventPool();
return event_pool->get(idx);
}
void synchronize_and_free_events(
const std::shared_ptr<GatheredContext>& context,
PrivatePool* pool = nullptr) {
// Synchronize on outstanding events and then free associated blocks.
stats.num_sync_all_streams++;
// This function syncs, so capture should not be underway. Might as well
// make sure capture-deferred end of life events get processed too.
TORCH_INTERNAL_ASSERT(captures_underway.empty());
insert_events_deferred_until_no_capture(context);
for (auto it = cuda_events.begin(); it != cuda_events.end();) {
for (auto e = it->second.begin(); e != it->second.end();) {
Block* block = e->second;
// If a pool was passed, only synchronize the events
// that are associated with the pool, otherwise move on
if (pool && block->pool->owner_PrivatePool != pool) {
++e;
continue;
}
EventPool::Event event = std::move(e->first);
C10_CUDA_CHECK(cudaEventSynchronize(*event));
block->event_count--;
if (block->event_count == 0) {
free_block(block, context);
}
// We are done with the event, so erase it from the deque
e = it->second.erase(e);
}
// Only erase this stream's entry from the events map once
// its events deque is empty.
if (it->second.empty()) {
it = cuda_events.erase(it);
} else {
it++;
}
}
}
void remove_cudagraph_stream_uses(Block* block) {
// remove stream uses added during cudagraph capture
// (i.e., block->stream_uses - block->cudagraph_stream_uses)
if (C10_UNLIKELY(
block_to_cudagraph_stream_uses.find(block) !=
block_to_cudagraph_stream_uses.end())) {
stream_set streams(std::move(block->stream_uses));
AT_ASSERT(block->stream_uses.empty());
for (auto& stream : streams) {
if (block_to_cudagraph_stream_uses[block].find(stream) ==
block_to_cudagraph_stream_uses[block].end()) {
block->stream_uses.insert(stream);
}
}
block_to_cudagraph_stream_uses.erase(block);
}
}
void insert_events(Block* block) {
c10::DeviceIndex prev_device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&prev_device));
stream_set streams(std::move(block->stream_uses));
AT_ASSERT(block->stream_uses.empty());
for (auto& stream : streams) {
C10_CUDA_CHECK(c10::cuda::SetDevice(stream.device_index()));
EventPool::Event event = create_event_internal(stream.device_index());
C10_CUDA_CHECK(cudaEventRecord(*event, stream.stream()));
block->event_count++;
cuda_events[stream].emplace_back(std::move(event), block);
}
C10_CUDA_CHECK(c10::cuda::MaybeSetDevice(prev_device));
}
void insert_events_deferred_until_no_capture(
const std::shared_ptr<GatheredContext>& context) {
if (C10_UNLIKELY(!needs_events_deferred_until_no_capture.empty())) {
for (auto* block : needs_events_deferred_until_no_capture) {
TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
// only streams recorded before cudagraph will be used to insert events
// since we know all streams recorded during cudagraph must have
// completed (refer to Section 3.2.8.7.3.1 Cross-stream Dependencies and
// Events in CUDA Programming Guide).
remove_cudagraph_stream_uses(block);
insert_events(block);
if (block->event_count == 0) {
free_block(block, context);
}
}
needs_events_deferred_until_no_capture.clear();
}
}
void process_events(const std::shared_ptr<GatheredContext>& context) {
insert_events_deferred_until_no_capture(context);
// Process outstanding cudaEvents. Events that are completed are
// removed from the queue, and the 'event_count' for the
// corresponding allocation is decremented. We maintain a separate
// list of events per stream to avoid head-of-line delays if one
// or more streams has long-running operations.
// Iterate over different streams.
for (auto it = cuda_events.begin(); it != cuda_events.end();) {
// Iterate over this stream's (event, block) pairs.
while (!it->second.empty()) {
auto& e = it->second.front();
EventPool::Event event = std::move(e.first);
Block* block = e.second;
cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(*event));
if (err == cudaErrorNotReady) {
// ignore and clear the error if not ready
(void)cudaGetLastError();
// Return the ownership of the Event (unique ptr)
e.first = std::move(event);
break;
} else if (err != cudaSuccess) {
C10_CUDA_CHECK(err);
}
block->event_count--;
if (block->event_count == 0) {
free_block(block, context);
}
it->second.pop_front();
}
if (it->second.empty()) {
it = cuda_events.erase(it);
} else {
it++;
}
}
}
// Updates *largest with the size of the largest cached block in the given pool
void cache_info_aux(const BlockPool& pool, size_t* largest) {
for (const auto& block : pool.blocks) {
const auto blocksize = block->size;
if (blocksize > *largest) {
*largest = blocksize;
}
}
}
void record_trace(
TraceEntry::Action action,
size_t addr,
size_t size,
cudaStream_t stream,
c10::DeviceIndex device,
std::shared_ptr<GatheredContext> context) {
if (!record_history && trace_trackers_.empty())
return;
auto te = TraceEntry(
action,
device,
addr,
size,
stream,
getApproximateTime(),
record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr);
// Callbacks should not include any PyTorch calls
for (const auto& cb : trace_trackers_) {
cb(te);
}
if (record_history) {
alloc_buffer.insertEntries(te);
}
}
};
// Returns whether to force all allocations to bypass the caching allocator and
// go straight to cudaMalloc. This setting is useful when debugging GPU memory
// errors, since the caching allocator foils cuda-memcheck.
static bool forceUncachedAllocator() {
// Allow either the CUDA or the HIP name for the env var, for maximum user
// comfort. The CUDA env var name avoids being hipified in
// cuda_to_hip_mappings.py.
static bool has_cuda_env =
c10::utils::has_env("PYTORCH_NO_CUDA_MEMORY_CACHING");
static bool has_rocm_env =
c10::utils::has_env("PYTORCH_NO_HIP_MEMORY_CACHING");
static bool force_uncached = has_cuda_env || has_rocm_env;
return force_uncached;
}
static void* uncached_allocate(size_t size) {
void* devPtr = nullptr;
// Deliberately don't use cudaMallocMaybeCapturing here, to force an error
// if someone tries to use forceUncachedAllocator while capturing.
C10_CUDA_CHECK(cudaMalloc(&devPtr, size));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(devPtr));
}
return devPtr;
}
static void uncached_delete(void* ptr) {
if (TORCH_SDT_IS_ENABLED(free)) {
TORCH_SDT_WITH_SEMAPHORE(free, ptr);
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(ptr));
}
C10_CUDA_CHECK(cudaFree(ptr));
}
static void local_raw_delete(void* ptr);
class NativeCachingAllocator : public CUDAAllocator {
private:
// allows this allocator to be turned on and off programmatically
bool enable_ = true;
// Shard allocation region to have independent mutexes to reduce contention.
static constexpr size_t kNumMutexShard = 67;
// TODO: use std::hardware_destructive_interference_size once available
struct alignas(64) AlignedMutex {
std::mutex m;
};
std::array<AlignedMutex, kNumMutexShard> mutex;
// allocated blocks by device pointer
std::array<ska::flat_hash_map<void*, Block*>, kNumMutexShard>
allocated_blocks;
static size_t get_mutex_shard_id(void* ptr) {
return twang_mix64((size_t)ptr) % kNumMutexShard;
}
void add_allocated_block(Block* block) {
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
const auto mutex_shard_id = get_mutex_shard_id(block->ptr);
std::lock_guard<std::mutex> lock(mutex[mutex_shard_id].m);
allocated_blocks[mutex_shard_id][block->ptr] = block;
}
// Variables used by memory snapshot
c10::ApproximateClockToUnixTimeConverter clock_converter;
bool record_history = false;
RingBuffer<AnnotationEntry> annotation_buffer;
public:
std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator;
Block* get_allocated_block(void* ptr, bool remove = false) {
const auto mutex_shard_id = get_mutex_shard_id(ptr);
std::lock_guard<std::mutex> lock(mutex[mutex_shard_id].m);
auto it = allocated_blocks[mutex_shard_id].find(ptr);
if (it == allocated_blocks[mutex_shard_id].end()) {
return nullptr;
}
Block* block = it->second;
if (remove) {
allocated_blocks[mutex_shard_id].erase(it);
}
return block;
}
void init(int device_count) override {
const auto size = static_cast<int64_t>(device_allocator.size());
if (size < device_count) {
device_allocator.resize(device_count);
for (const auto i : c10::irange(size, device_count)) {
device_allocator[i] = std::make_unique<DeviceCachingAllocator>();
}
}
}
bool initialized() override {
return !device_allocator.empty();
}
/** allocates a block which is safe to use from the provided stream */
void malloc(
void** devPtr,
c10::DeviceIndex device,
size_t size,
cudaStream_t stream) {
TORCH_INTERNAL_ASSERT(
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
"Allocator not initialized for device ",
device,
": did you call init?");
Block* block = device_allocator[device]->malloc(device, size, stream);
add_allocated_block(block);
*devPtr = (void*)block->ptr;
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(*devPtr));
}
}
void free(void* ptr) {
if (!ptr) {
return;
}
Block* block = get_allocated_block(ptr, true /* remove */);
if (!block) {
TORCH_CHECK(false, "invalid device pointer: ", ptr);
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(
c10::kCUDA, reinterpret_cast<uintptr_t>(block->ptr));
}
device_allocator[block->device]->free(block);
}
double getMemoryFraction(c10::DeviceIndex device) override {
TORCH_INTERNAL_ASSERT(
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
"Allocator not initialized for device ",
device,
": did you call init?");
C10_CUDA_CHECK(c10::cuda::SetDevice(device));
return device_allocator[device]->getMemoryFraction();
}
void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
TORCH_INTERNAL_ASSERT(
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
"Allocator not initialized for device ",
device,
": did you call init?");
TORCH_INTERNAL_ASSERT(
0 <= fraction && fraction <= 1,
"invalid fraction: ",
fraction,
". Please set within [0, 1].");
C10_CUDA_CHECK(c10::cuda::SetDevice(device));
device_allocator[device]->setMemoryFraction(fraction);
}
void recordHistory(
bool enabled,
CreateContextFn context_recorder,
size_t alloc_buffer_max_entries,
RecordContext when) override {
record_history = enabled;
annotation_buffer.setMaxEntries(alloc_buffer_max_entries);
annotation_buffer.clear();
for (auto& allocator : device_allocator) {
allocator->recordHistory(
enabled, context_recorder, alloc_buffer_max_entries, when);
}
}
void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) override {
if (!record_history) {
return;
}
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
auto ae = AnnotationEntry(
/*device=*/device,
/*time=*/getApproximateTime());
for (const auto& md_pair : md) {
ae.recordUserMetadata(md_pair.first, md_pair.second);
}
annotation_buffer.insertEntries(ae);
}
bool isHistoryEnabled() override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
return device_allocator[device]->isHistoryEnabled();
}
bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) override {
return device_allocator[device]->checkPoolLiveAllocations(
mempool_id, expected_live_allocations);
}
void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
for (auto& allocator : device_allocator) {
allocator->attachOutOfMemoryObserver(observer);
}
}
void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override {
for (auto& allocator : device_allocator) {
allocator->attachAllocatorTraceTracker(tracker);
}
}
void emptyCache() override {
for (auto& da : device_allocator)
da->emptyCache();
}
void enable(bool value) override {
enable_ = value;
}
bool isEnabled() const override {
return enable_;
}
void* getBaseAllocation(void* ptr, size_t* outSize) override {
Block* block = get_allocated_block(ptr);
if (!block) {
TORCH_CHECK(false, "invalid device pointer: ", ptr);
}
return device_allocator[block->device]->getBaseAllocation(block, outSize);
}
ShareableHandle shareIpcHandle(void* ptr) override {
Block* block = get_allocated_block(ptr);
if (!block) {
TORCH_CHECK(false, "invalid device pointer: ", ptr);
}
return device_allocator[block->device]->shareIpcHandle(block);
}
void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override {
// An empty tensor's storage().data() may be a null pointer. Since no
// blocks are associated with such tensors, it is fine to do nothing here.
if (!ptr.get()) {
return;
}
// If a tensor was not allocated by this instance, simply skip it.
// This usually happens when CUDA tensors are shared across processes;
// we have implemented a reference-counting-based sharing mechanism to
// guarantee that tensors won't be accidentally freed by one process while
// they are still being used in another.
if (ptr.get_deleter() != &local_raw_delete)
return;
Block* block = get_allocated_block(ptr.get());
// block must not be null at this point
TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found");
device_allocator[block->device]->recordStream(block, stream);
}
SnapshotInfo snapshot() override {
// Set up a converter from TSC timestamps to microseconds.
auto tsc_to_ns = clock_converter.makeConverter();
auto tsc_to_us = [=](approx_time_t t_approx) {
return tsc_to_ns(t_approx) / 1000;
};
SnapshotInfo result;
// Get AnnotationEntry list and convert the timestamps.
annotation_buffer.getEntries(result.external_annotations);
for (auto& ae : result.external_annotations) {
ae.time_.t_ = tsc_to_us(ae.time_.approx_t_);
}
// Get the device_traces' TraceEntry lists.
for (auto& da : device_allocator) {
result.device_traces.emplace_back(da->trace(tsc_to_us));
auto snap = da->snapshot();
result.segments.insert(result.segments.end(), snap.begin(), snap.end());
}
auto& md = result.config_metadata;
md.garbage_collection_threshold =
CUDAAllocatorConfig::garbage_collection_threshold();
md.max_split_size = CUDAAllocatorConfig::max_split_size();
md.pinned_num_register_threads =
CUDAAllocatorConfig::pinned_num_register_threads();
md.expandable_segments = CUDAAllocatorConfig::expandable_segments();
md.release_lock_on_malloc =
CUDAAllocatorConfig::release_lock_on_cudamalloc();
md.pinned_use_host_register =
CUDAAllocatorConfig::pinned_use_cuda_host_register();
md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
md.roundup_power2_divisions =
CUDAAllocatorConfig::roundup_power2_divisions();
return result;
}
std::shared_ptr<AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) override {
return device_allocator[device]->getCheckpointState(id);
}
/**
* @brief Reset the private pool identified in `as` to the prior state
* captured in `as` (i.e. restore the checkpoint)
*
* @param device - device of the pool to manipulate
* @param as - allocator state
* @param stale_live_storages - storages of tensors which are currently
* allocated but which will not be allocated after the checkpoint is set.
* For these storages we will remove their deleter function.
* @return CheckpointDelta - freed pointers, and DataPtrs that contain deleter
* functions for all allocated blocks in the new checkpoint state.
*/
CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<AllocatorState> as) override {
std::shared_ptr<PrivatePoolState> pps =
std::dynamic_pointer_cast<PrivatePoolState>(as);
TORCH_CHECK(pps, "Expected PrivatePoolState");
auto rr = device_allocator[device]->setCheckpointPoolState(*pps);
CheckpointDelta cpd;
for (void* ptr : rr.allocations_freed) {
get_allocated_block(ptr, /*remove*/ true);
cpd.ptrs_freed.push_back(ptr);
}
for (Block* block : rr.allocations_created) {
add_allocated_block(block);
cpd.dataptrs_allocd.emplace_back(
block->ptr,
block->ptr,
&local_raw_delete,
Device(DeviceType::CUDA, device));
}
return cpd;
}
DataPtr allocate(size_t size) override {
// 2^60 bytes, i.e. 1 EiB (the "1EB" in the error message below).
constexpr size_t one_exa_bytes = 1152921504606846976ULL;
TORCH_CHECK_WITH(
OutOfMemoryError,
size < one_exa_bytes,
"CUDA out of memory. Tried to allocate more than 1EB memory.");
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
void* devPtr = nullptr;
void (*deleteFunc)(void*) = &local_raw_delete;
CUDAStream stream = cuda::getCurrentCUDAStream(device);
if (forceUncachedAllocator() || !isEnabled()) {
deleteFunc = &uncached_delete;
devPtr = uncached_allocate(size);
} else {
if (size != 0) {
this->malloc(&devPtr, device, size, stream);
}
}
if (size && TORCH_SDT_IS_ENABLED(malloc)) {
TORCH_SDT_WITH_SEMAPHORE(malloc, devPtr, device, size, stream.id());
}
return {devPtr, devPtr, deleteFunc, Device(DeviceType::CUDA, device)};
}
DeleterFnPtr raw_deleter() const override {
if (forceUncachedAllocator() || !isEnabled()) {
return &uncached_delete;
} else {
return &local_raw_delete;
}
}
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
device_allocator[device]->cacheInfo(largestBlock);
}
void assertValidDevice(c10::DeviceIndex device) {
const auto device_num = device_allocator.size();
TORCH_CHECK(
0 <= device && device < static_cast<int64_t>(device_num),
"Invalid device argument ",
device,
": did you call init?");
}
DeviceStats getDeviceStats(c10::DeviceIndex device) override {
assertValidDevice(device);
return device_allocator[device]->getStats();
}
void resetAccumulatedStats(c10::DeviceIndex device) override {
assertValidDevice(device);
device_allocator[device]->resetAccumulatedStats();
}
void resetPeakStats(c10::DeviceIndex device) override {
assertValidDevice(device);
device_allocator[device]->resetPeakStats();
}
void ensureExistsAndIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id) override {
assertValidDevice(device);
device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id));
}
// CUDAGraph interactions
void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(cudaStream_t)> filter) override {
assertValidDevice(device);
device_allocator[device]->beginAllocateToPool(
std::move(mempool_id), std::move(filter));
}
void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id)
override {
assertValidDevice(device);
device_allocator[device]->endAllocateToPool(mempool_id);
}
void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
assertValidDevice(device);
device_allocator[device]->releasePool(std::move(mempool_id));
}
int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id)
override {
assertValidDevice(device);
return device_allocator[device]->getPoolUseCount(std::move(mempool_id));
}
void* raw_alloc(size_t nbytes) override {
if (nbytes == 0) {
return nullptr;
}
void* r = nullptr;
if (forceUncachedAllocator() || !isEnabled()) {
r = uncached_allocate(nbytes);
} else {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
}
return r;
}
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override {
if (nbytes == 0) {
return nullptr;
}
void* r = nullptr;
if (forceUncachedAllocator() || !isEnabled()) {
r = uncached_allocate(nbytes);
} else {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
malloc(&r, device, nbytes, stream);
}
return r;
}
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access)
override {
c10::cuda::CUDAGuard device_guard(dev);
cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0);
if (err == cudaErrorPeerAccessAlreadyEnabled) {
// ignore and clear the error if access was already enabled
(void)cudaGetLastError();
} else {
C10_CUDA_CHECK(err);
}
device_allocator[dev_to_access]->addPeerAccess(dev);
std::lock_guard<std::mutex> lock(IpcMutex);
for (auto& entry : ipcMemHandle_to_devptr) {
if (entry.second.device_ == dev_to_access &&
entry.second.expandable_segment_) {
entry.second.expandable_segment_->addPeer(dev);
}
}
}
cudaError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
cudaStream_t stream,
bool p2p_enabled) override {
if (p2p_enabled || // memcpy ok because memory is mapped in both devices
srcDevice == dstDevice || // memcpy ok on a single device
// memcpy ok because both dst and src must have come from cudaMalloc
(!device_allocator[dstDevice]->hasAllocatedExpandableSegments() &&
!device_allocator[srcDevice]->hasAllocatedExpandableSegments())) {
return cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream);
}
// when p2p is not enabled, only cudaMemcpyPeerAsync correctly handles
// memory not allocated via cudaMalloc
return cudaMemcpyPeerAsync(dst, dstDevice, src, srcDevice, count, stream);
}
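// Summary of the branch above (a restatement of the code, not a separate
// API). memcpyAsync uses:
//   cudaMemcpyAsync      if p2p_enabled, or srcDevice == dstDevice, or
//                        neither device's allocator has allocated any
//                        expandable segments;
//   cudaMemcpyPeerAsync  otherwise, i.e. a cross-device copy without P2P
//                        where at least one side may live in an expandable
//                        (cuMemCreate-backed) segment.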
void raw_delete(void* ptr) override {
if (forceUncachedAllocator() || !isEnabled()) {
uncached_delete(ptr);
} else {
this->free(ptr);
}
}
// In CUDA IPC, the sender shares a tensor with the receiver via
// shareIpcHandle; getIpcDevPtr is called by the receiving process to map the
// CUDA memory from the sending process into its own address space.
// When the memory was allocated with cudaMalloc we use the cudaIpcMemHandle_t
// APIs. These APIs only allow sharing a big memory block associated with a
// cudaIpcMemHandle_t, and it can be opened only **once** per context per
// process. There can be multiple types of storage in the same IPC mem block,
// so we must cache the device ptr to construct typed storage as it comes.
// When using cuMemCreate, via expandable segments, we use
// cuMemExportToShareableHandle to create a file descriptor that can be sent
// to the other process to share the object. Then we recreate the part of the
// expandable segment necessary to load the allocation.
// ipcMemHandle_to_devptr caches the mapping from shareable handle to
// this process's memory mapping information for that share to ensure we do
// not create it twice. When the shared_ptr is no longer in use we clean up
// the cache.
std::mutex IpcMutex;
struct MemHandleCacheEntry {
MemHandleCacheEntry(
c10::DeviceIndex device,
std::string& handle,
const DeviceCachingAllocator& allocator)
: device_(device) {
int type = SHAREABLE_CUDA_MALLOC;
std::istringstream ss(handle);
if (handle.size() != CUDA_IPC_HANDLE_SIZE) {
auto version = ss.get();
TORCH_CHECK(
version <= SHAREABLE_HANDLE_VERSION,
"received shareable handle from a future version of torch that this version does not know how to handle");
type = ss.get();
} // otherwise this comes from an older PyTorch, where it has to be a raw
// SHAREABLE_CUDA_MALLOC handle
if (type == SHAREABLE_CUDA_MALLOC) {
cudaIpcMemHandle_t cuda_handle;
ss.read((char*)&cuda_handle, CUDA_IPC_HANDLE_SIZE);
C10_CUDA_CHECK(cudaIpcOpenMemHandle(
&cuda_ipc_ptr_, cuda_handle, cudaIpcMemLazyEnablePeerAccess));
} else if (type == SHAREABLE_CUDA_EXPANDABLE_SEGMENT) {
expandable_segment_ =
ExpandableSegment::fromShared(device, allocator.peers(), ss)
.release();
} else {
TORCH_INTERNAL_ASSERT(
false, "unexpected or ill-formed shareable handle type");
}
}
// This struct expects clear() to be called explicitly to
// free resources, because we only want this code running when
// the shared pointer to this entry is destroyed, not during
// deinitialization, when CUDA may already have been shut down.
// This replicates the previous behavior of this map when it
// stored raw cuda_ipc_ptr_ handles.
void clear() {
if (cuda_ipc_ptr_) {
cuda::CUDAGuard device_guard(device_);
C10_CUDA_CHECK(cudaIpcCloseMemHandle(cuda_ipc_ptr_));
cuda_ipc_ptr_ = nullptr;
}
if (expandable_segment_) {
delete expandable_segment_;
expandable_segment_ = nullptr;
}
}
void* ptr() {
if (cuda_ipc_ptr_) {
return cuda_ipc_ptr_;
} else {
return expandable_segment_->ptr();
}
}
c10::DeviceIndex device_;
ExpandableSegment* expandable_segment_{nullptr};
void* cuda_ipc_ptr_{nullptr}; // nullptr if expandable_segment_ is not null
std::weak_ptr<void> wp_;
};
ska::flat_hash_map<std::string, MemHandleCacheEntry> ipcMemHandle_to_devptr;
std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
std::lock_guard<std::mutex> lock(IpcMutex);
auto iter = ipcMemHandle_to_devptr.find(handle);
if (iter != ipcMemHandle_to_devptr.end()) {
auto devptr = iter->second.wp_.lock();
// The weak_ptr should always be valid: the shared_ptr's deleter erases
// the entry from the cache, so an expired weak_ptr should never be
// observed here.
TORCH_INTERNAL_ASSERT(devptr, "entry in cache has missing shared_ptr");
return devptr;
}
c10::DeviceIndex curr_device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
auto inserted = ipcMemHandle_to_devptr.insert(
iter,
{handle,
MemHandleCacheEntry(
curr_device, handle, *device_allocator[curr_device])});
auto sp = std::shared_ptr<void>(
inserted->second.ptr(), [handle, this](void* ptr) {
std::lock_guard<std::mutex> deleter_lock(IpcMutex);
auto it = ipcMemHandle_to_devptr.find(handle);
TORCH_INTERNAL_ASSERT(it != ipcMemHandle_to_devptr.end());
it->second.clear();
ipcMemHandle_to_devptr.erase(it);
});
inserted->second.wp_ = sp;
return sp;
}
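// Usage sketch (hypothetical caller code, not part of this file). It assumes
// `ptr` is a live block from this allocator in the sending process, that
// ShareableHandle exposes `handle` and `offset` fields, and that the handle
// string reaches the receiving process over some IPC channel not shown here:
//   // sender
//   auto sh = c10::cuda::CUDACachingAllocator::get()->shareIpcHandle(ptr);
//   // ... transmit sh.handle and sh.offset to the receiver ...
//   // receiver
//   auto* alloc = c10::cuda::CUDACachingAllocator::get();
//   std::shared_ptr<void> p1 = alloc->getIpcDevPtr(handle);
//   std::shared_ptr<void> p2 = alloc->getIpcDevPtr(handle); // cache hit
//   // p1 and p2 share one control block, so the handle is opened only once;
//   // the storage itself starts sh.offset bytes past p1.get(). When the last
//   // copy is reset, the deleter calls clear() and erases the cache entry.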
std::string name() override {
return "native";
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
C10_CUDA_CHECK(
cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice));
}
};
static NativeCachingAllocator allocator;
void local_raw_delete(void* ptr) {
if (TORCH_SDT_IS_ENABLED(free)) {
TORCH_SDT_WITH_SEMAPHORE(free, ptr);
}
allocator.free(ptr);
}
} // namespace Native
namespace CudaMallocAsync {
// If this is put in its own header file, it gets incorrectly renamed in HIPify.
CUDAAllocator* allocator();
} // namespace CudaMallocAsync
struct BackendStaticInitializer {
// Parses env for backend at load time, duplicating some logic from
// CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at
// runtime). Defers verbose exceptions and error checks, including CUDA
// version checks, to CUDAAllocatorConfig's runtime double-check. If this
// works, maybe we should move all of CUDAAllocatorConfig here?
CUDAAllocator* parseEnvForBackend() {
const auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
if (val.has_value()) {
const std::string& config = val.value();
std::regex exp("[\\s,]+");
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);
for (const auto& option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
std::sregex_token_iterator end2;
std::vector<std::string> kv(it2, end2);
if (kv.size() >= 2) {
if (kv[0] == "backend") {
if (kv[1] == "cudaMallocAsync")
return CudaMallocAsync::allocator();
if (kv[1] == "native")
return &Native::allocator;
}
}
}
}
return &Native::allocator;
}
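// Illustrative trace (a sketch; the option values are examples, not
// defaults). With
//   PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True,backend:cudaMallocAsync"
// the split on [\s,]+ yields {"expandable_segments:True",
// "backend:cudaMallocAsync"}. Splitting each option on [:]+ gives
// {"expandable_segments", "True"}, which is ignored here, and
// {"backend", "cudaMallocAsync"}, so parseEnvForBackend() returns
// CudaMallocAsync::allocator(). An unrecognized backend value is skipped
// (validation is left to CUDAAllocatorConfig), and if no option selects a
// backend the native allocator is returned.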
BackendStaticInitializer() {
auto r = parseEnvForBackend();
allocator.store(r);
}
};
std::atomic<CUDAAllocator*> allocator;
static BackendStaticInitializer backend_static_initializer;
} // namespace cuda::CUDACachingAllocator
} // namespace c10
namespace c10::cuda {
// uid_ is incremented when a user creates a MemPool,
// for example: using graph_pool_handle() or c10::cuda::MemPool().
//
// uuid_ is incremented when CUDAGraph creates a MemPool
// as a result of a user not providing a pool.
//
// MempoolId_t of {0, 0} is used to denote when no MemPool has been
// passed to a function, either by user or CUDAGraphs. For example,
// default value of MempoolId_t for capture_begin function is {0, 0}.
// That's why uid_ and uuid_ start at 1.
std::atomic<CaptureId_t> MemPool::uid_{1};
std::atomic<CaptureId_t> MemPool::uuid_{1};
MemPool::MemPool(
CUDACachingAllocator::CUDAAllocator* allocator,
bool is_user_created)
: allocator_(allocator), is_user_created_(is_user_created) {
if (is_user_created_) {
id_ = {0, uid_++};
} else {
id_ = {uuid_++, 0};
}
device_ = c10::cuda::current_device();
CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_);
}
MemPool::~MemPool() {
TORCH_INTERNAL_ASSERT(use_count() == 1);
CUDACachingAllocator::releasePool(device_, id_);
auto ctx = MemPoolContext(this);
c10::cuda::CUDACachingAllocator::emptyCache();
}
MempoolId_t MemPool::id() {
return id_;
}
CUDACachingAllocator::CUDAAllocator* MemPool::allocator() {
return allocator_;
}
int MemPool::use_count() {
return CUDACachingAllocator::getPoolUseCount(device_, id_);
}
c10::DeviceIndex MemPool::device() {
return device_;
}
MempoolId_t MemPool::graph_pool_handle(bool is_user_created) {
if (is_user_created) {
return {0, uid_++};
}
return {uuid_++, 0};
}
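// Worked example of the numbering described above, assuming no pools or
// handles have been created in this process yet:
//   MemPool::graph_pool_handle(/*is_user_created=*/true);   // -> {0, 1}
//   MemPool::graph_pool_handle(/*is_user_created=*/true);   // -> {0, 2}
//   MemPool::graph_pool_handle(/*is_user_created=*/false);  // -> {1, 0}
// The MemPool constructor draws from the same two counters, so a
// user-created MemPool constructed next would receive the id {0, 3}.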
// Note that active_mempool_ is a global variable here
// and not inside the MemPoolContext class, because on Windows we
// can't use __declspec(dllexport) and __declspec(thread)
// together: https://stackoverflow.com/a/50967977
static thread_local MemPool* active_mempool_ = nullptr;
MemPoolContext::MemPoolContext(MemPool* mempool)
: prev_mempool_(active_mempool_) {
active_mempool_ = mempool;
}
MemPoolContext::~MemPoolContext() {
active_mempool_ = prev_mempool_;
}
MemPool* MemPoolContext::getActiveMemPool() {
return active_mempool_;
}
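// Nesting sketch (pool_a and pool_b are hypothetical MemPool* values). Each
// MemPoolContext saves the previously active pool and restores it in its
// destructor, so nested contexts behave like a per-thread stack:
//   {
//     MemPoolContext ctx_a(pool_a);    // getActiveMemPool() == pool_a
//     {
//       MemPoolContext ctx_b(pool_b);  // getActiveMemPool() == pool_b
//     }                                // restored to pool_a
//   }                                  // restored to the pool active before
//                                      // ctx_a (nullptr if there was none)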
} // namespace c10::cuda