#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/llvmMathExtras.h>
#include <cuda_runtime_api.h>
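#include <algorithm>
#include <array>
#include <atomic>
#include <bitset>
#include <cstdlib>
#include <deque>
#include <functional>
#include <iterator>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <regex>
#include <set>
#include <sstream>
#include <string>
#include <vector>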
namespace c10 {
C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
namespace cuda {
namespace CUDACachingAllocator {
//
// Yet another caching allocator for CUDA device allocations.
//
// - Allocations are associated with a stream. Once freed, blocks can be
// re-allocated on the same stream, but not on any other stream.
// - The allocator attempts to find the smallest cached block that will fit the
// requested size. If the block is larger than the requested size, it may be
// split. If no block is found, the allocator will delegate to cudaMalloc.
// - If the cudaMalloc fails, the allocator will attempt to free one cached
// block of sufficient size that is not split and retry the allocation.
// If this also fails, the allocator will attempt to free all cached blocks
// that are not split and retry the allocation.
// - Large (>1MB) and small allocations are stored in separate pools.
// Small requests are packed into 2MB buffers. Large requests will use the
// smallest available free block or allocate a new block using cudaMalloc.
// - To reduce fragmentation, requests between 1MB and 10MB will allocate and
// split a 20MB block, if no free block of sufficient size is available.
// - To further reduce fragmentation, blocks >= 200MB are not allowed to be
// split. These oversize cached blocks will still satisfy requests within
// 20MB of the oversize cached block size.
//
// With this allocator, allocations and frees should logically be considered
// "usages" of the memory segment associated with streams, just like kernel
// launches. The programmer must insert the proper synchronization if memory
// segments are used from multiple streams.
//
// The library provides a recordStream() function to help insert the correct
// synchronization when allocations are used on multiple streams. This will
// ensure that the block is not reused before each recorded stream completes
// work.
//
/**
* Note [Interaction with CUDA graph capture]
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
* Graph capture performs a dry run of a region of execution, freezing all CUDA
* work (and virtual addresses used during that work) into a "graph." The graph
* may be "replayed" like a single giant kernel, with greatly reduced CPU
* overhead as well as modestly improved GPU performance.
*
* Because capture bakes in memory addresses, the memory used during capture
* must be available for the graph to use during replay. DeviceCachingAllocator
* assigns and frees memory eagerly and dynamically, so if we're not careful
* about managing graphs' memory, at replay time those memory addresses could be
* used by other tensors.
*
* To guarantee a graph's baked in addresses are safe to reuse in replay,
* DeviceCachingAllocator satisfies allocations from a graph-private memory pool
* during capture, and doesn't begin cudaFreeing those addresses until the graph
* is destroyed.
*
* Within the private pool, allocations are freed and reassigned as usual during
* capture. Memory regions will be used in a consistent order during replay. So
* a private pool doesn't use memory more wastefully than the default pools
* during capture, but it does reserve its high-water mark of used memory away
* from the default pools as long as the capture(s) it served survive
* (regardless of whether those captures are idle or replaying).
*
* CUDAGraph's requests for private pools are mediated by
* DeviceCachingAllocator::notifyCaptureBegin, notifyCaptureEnd, and
* notifyCaptureDestroy.
*/
namespace {
using stream_set = ska::flat_hash_set<cuda::CUDAStream>;
constexpr size_t kMinBlockSize =
512; // all sizes are rounded to at least 512 bytes
constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB
constexpr size_t kSmallBuffer =
2097152; // "small" allocations are packed in 2 MiB blocks
constexpr size_t kLargeBuffer =
20971520; // "large" allocations may be packed in 20 MiB blocks
constexpr size_t kMinLargeAlloc =
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB
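/*
A minimal sketch of how the constants above could drive size rounding and
segment sizing, following the pool description at the top of this file. The
namespace `sketch` and the helper names `round_size` and `allocation_size` are
assumptions made for illustration, as is the exact handling of the 1 MiB and
10 MiB boundaries; the allocator's own helpers are not shown in this part of
the file.

```cpp
#include <cstddef>

namespace sketch {
// Values mirror the constants defined above.
constexpr std::size_t kMinBlockSize = 512;
constexpr std::size_t kSmallSize = 1048576;
constexpr std::size_t kSmallBuffer = 2097152;
constexpr std::size_t kLargeBuffer = 20971520;
constexpr std::size_t kMinLargeAlloc = 10485760;
constexpr std::size_t kRoundLarge = 2097152;

// Round a request up to a multiple of 512 bytes, with a floor of 512.
inline std::size_t round_size(std::size_t size) {
  if (size < kMinBlockSize) {
    return kMinBlockSize;
  }
  return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}

// Choose how many bytes to request from the driver when no cached block fits.
inline std::size_t allocation_size(std::size_t size) {
  if (size <= kSmallSize) {
    return kSmallBuffer; // small requests share a 2 MiB segment
  }
  if (size < kMinLargeAlloc) {
    return kLargeBuffer; // mid-sized requests split a 20 MiB segment
  }
  // Larger requests get their own segment, rounded up to a 2 MiB multiple.
  return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
}
} // namespace sketch
```

Under these assumptions a 513-byte request is rounded to 1024 bytes and served
from a 2 MiB segment, while an 11 MiB request gets a dedicated 12 MiB segment.
*/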
using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>;
void update_stat(Stat& stat, int64_t amount) {
stat.current += amount;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
stat.current >= 0,
"Negative tracked stat in CUDA allocator (likely logic error).");
stat.peak = std::max(stat.current, stat.peak);
if (amount > 0) {
stat.allocated += amount;
}
if (amount < 0) {
stat.freed += -amount;
}
}
void reset_accumulated_stat(Stat& stat) {
stat.allocated = 0;
stat.freed = 0;
}
void reset_peak_stat(Stat& stat) {
stat.peak = stat.current;
}
template <typename Func>
void for_each_selected_stat_type(const StatTypes& stat_types, Func f) {
for (const auto stat_type : c10::irange(stat_types.size())) {
if (stat_types[stat_type]) {
f(stat_type);
}
}
}
void update_stat_array(
StatArray& stat_array,
int64_t amount,
const StatTypes& stat_types) {
for_each_selected_stat_type(
stat_types, [&stat_array, amount](size_t stat_type) {
update_stat(stat_array[stat_type], amount);
});
}
struct Block;
struct PrivatePool;
typedef bool (*Comparison)(const Block*, const Block*);
struct BlockPool {
BlockPool(
Comparison comparator,
bool small,
PrivatePool* private_pool = nullptr)
: blocks(comparator), is_small(small), owner_PrivatePool(private_pool) {}
std::set<Block*, Comparison> blocks;
const bool is_small;
PrivatePool* owner_PrivatePool;
};
struct Block {
int device; // gpu
cudaStream_t stream; // allocation stream
stream_set stream_uses; // streams on which the block was used
size_t size; // block size in bytes
BlockPool* pool; // owning memory pool
void* ptr; // memory address
bool allocated; // in-use flag
Block* prev; // prev block if split from a larger allocation
Block* next; // next block if split from a larger allocation
int event_count; // number of outstanding CUDA events
int gc_count; // counter for prioritizing older / less useful blocks for
// garbage collection
std::unique_ptr<History> history;
History* history_last;
Block(
int device,
cudaStream_t stream,
size_t size,
BlockPool* pool,
void* ptr)
: device(device),
stream(stream),
stream_uses(),
size(size),
pool(pool),
ptr(ptr),
allocated(0),
prev(nullptr),
next(nullptr),
event_count(0),
gc_count(0) {}
// constructor for search key
Block(int device, cudaStream_t stream, size_t size)
: device(device),
stream(stream),
stream_uses(),
size(size),
pool(nullptr),
ptr(nullptr),
allocated(0),
prev(nullptr),
next(nullptr),
event_count(0),
gc_count(0) {}
bool is_split() const {
return (prev != nullptr) || (next != nullptr);
}
};
static bool BlockComparator(const Block* a, const Block* b) {
if (a->stream != b->stream) {
return (uintptr_t)a->stream < (uintptr_t)b->stream;
}
if (a->size != b->size) {
return a->size < b->size;
}
return (uintptr_t)a->ptr < (uintptr_t)b->ptr;
}
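/*
The ordering above (stream, then size, then address) is what makes a best-fit
lookup cheap: a search key holding the requested stream and size can be handed
to std::set::lower_bound. The sketch below illustrates the idea with a
hypothetical `MiniBlock` type and `find_best_fit` helper; neither name exists
in this file, and the allocator's real lookup is not shown in this section.

```cpp
#include <cstddef>
#include <cstdint>
#include <set>

namespace sketch {
// Hypothetical, trimmed-down block: only the fields the comparator reads.
struct MiniBlock {
  std::uintptr_t stream;
  std::size_t size;
  std::uintptr_t ptr;
};

// Same ordering idea as BlockComparator: stream, then size, then address.
struct MiniBlockLess {
  bool operator()(const MiniBlock* a, const MiniBlock* b) const {
    if (a->stream != b->stream) {
      return a->stream < b->stream;
    }
    if (a->size != b->size) {
      return a->size < b->size;
    }
    return a->ptr < b->ptr;
  }
};

using MiniPool = std::set<MiniBlock*, MiniBlockLess>;

// Return the smallest cached block on `stream` holding at least `size` bytes,
// or nullptr when the pool has no such block.
inline MiniBlock* find_best_fit(
    const MiniPool& pool,
    std::uintptr_t stream,
    std::size_t size) {
  MiniBlock key{stream, size, 0};
  auto it = pool.lower_bound(&key);
  if (it == pool.end() || (*it)->stream != stream) {
    return nullptr;
  }
  return *it;
}
} // namespace sketch
```

Because the stream is the most significant key, a lookup never returns a block
cached for a different stream, which matches the reuse rule described at the
top of this file.
*/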
static std::string format_size(uint64_t size) {
std::ostringstream os;
os.precision(2);
os << std::fixed;
if (size <= 1024) {
os << size << " bytes";
} else if (size <= 1048576) {
os << (size / 1024.0);
os << " KiB";
} else if (size <= 1073741824ULL) {
os << size / 1048576.0;
os << " MiB";
} else {
os << size / 1073741824.0;
os << " GiB";
}
return os.str();
}
struct AllocParams {
AllocParams(
int device,
size_t size,
cudaStream_t stream,
BlockPool* pool,
size_t alloc_size,
DeviceStats& stats)
: search_key(device, stream, size),
pool(pool),
alloc_size(alloc_size),
block(nullptr),
err(cudaSuccess) {}
int device() const {
return search_key.device;
}
cudaStream_t stream() const {
return search_key.stream;
}
size_t size() const {
return search_key.size;
}
Block search_key;
BlockPool* pool;
size_t alloc_size;
Block* block;
StatTypes stat_types = {false};
cudaError_t err;
};
int trimHistoryBefore(Block* block, void* point) {
int n = 0;
while (block->history && block->history->addr < point) {
block->history = std::move(block->history->next);
++n;
}
if (!block->history) {
block->history_last = nullptr;
}
return n;
}
// Note: cudaEventCreate when concurrently invoked from multiple threads can be
// very expensive (at least on certain device/driver combinations). Thus, we a)
// serialize event creation at a per-device level, and b) pool the events to
// avoid constantly calling cudaEventCreate/cudaEventDestroy. This results in
// significant improvements in multithreaded workloads with high allocation
// rates.
class EventPool {
public:
using Event = std::unique_ptr<cudaEvent_t, std::function<void(cudaEvent_t*)>>;
// TODO: Explicit device count
EventPool() : pools_(at::cuda::device_count()) {}
Event get(int device) {
TORCH_INTERNAL_ASSERT(0 <= device);
TORCH_INTERNAL_ASSERT(device < static_cast<int>(pools_.size()));
auto& pool = pools_[device];
auto destructor = [&pool](cudaEvent_t* event) {
std::lock_guard<std::mutex> g(pool.mutex_);
pool.event_pool_.push_back(std::unique_ptr<cudaEvent_t>(event));
};
// Try to acquire an event from the per-device pool.
{
std::lock_guard<std::mutex> g(pool.mutex_);
if (!pool.event_pool_.empty()) {
auto* event = pool.event_pool_.back().release();
pool.event_pool_.pop_back();
return Event(event, destructor);
}
}
// otherwise, allocate a new event that will be returned to the pool on
// destruction.
auto new_ptr = std::make_unique<cudaEvent_t>();
C10_CUDA_CHECK(
cudaEventCreateWithFlags(new_ptr.get(), cudaEventDisableTiming));
return Event(new_ptr.release(), destructor);
}
void empty_cache() {
for (auto& pool : pools_) {
std::lock_guard<std::mutex> g(pool.mutex_);
pool.event_pool_.clear();
}
}
private:
struct PerDevicePool {
alignas(64) std::mutex mutex_;
std::vector<std::unique_ptr<cudaEvent_t>> event_pool_;
};
std::vector<PerDevicePool> pools_;
};
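/*
The recycling trick used by EventPool (a unique_ptr whose deleter pushes the
object back into the pool instead of destroying it) can be tried without CUDA.
In the sketch below `IntPool` is a hypothetical stand-in in which `int` plays
the role of cudaEvent_t; it is an illustration of the pattern, not a drop-in
replacement, and it assumes the pool outlives every handle it gives out.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <vector>

namespace sketch {
class IntPool {
 public:
  using Handle = std::unique_ptr<int, std::function<void(int*)>>;

  Handle get() {
    // The deleter returns the object to the pool rather than freeing it.
    auto recycle = [this](int* p) {
      std::lock_guard<std::mutex> g(mutex_);
      free_.push_back(std::unique_ptr<int>(p));
    };
    {
      std::lock_guard<std::mutex> g(mutex_);
      if (!free_.empty()) {
        int* p = free_.back().release();
        free_.pop_back();
        return Handle(p, recycle);
      }
    }
    // Pool is empty: create a fresh object that will be recycled later.
    return Handle(new int(0), recycle);
  }

  // Number of objects currently waiting in the pool.
  std::size_t cached() {
    std::lock_guard<std::mutex> g(mutex_);
    return free_.size();
  }

 private:
  std::mutex mutex_;
  std::vector<std::unique_ptr<int>> free_;
};
} // namespace sketch
```

Dropping a handle makes its object available again, so a steady-state workload
stops creating new objects once the pool has warmed up.
*/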
// CUDA graphs helper
struct PrivatePool {
PrivatePool()
: use_count(1),
cudaMalloc_count(0),
large_blocks(BlockComparator, /*is_small=*/false, this),
small_blocks(BlockComparator, /*is_small=*/true, this) {}
PrivatePool(const PrivatePool&) = delete;
PrivatePool(PrivatePool&&) = delete;
PrivatePool& operator=(const PrivatePool&) = delete;
// Number of live graphs using this pool
int use_count;
// Number of unfreed cudaMallocs made for this pool. When use_count and
// cudaMalloc_count drop to zero, we can delete this PrivatePool from
// graph_pools.
int cudaMalloc_count;
// Instead of maintaining private BlockPools here, I could stuff all blocks
// (private or no) into the top-level large_blocks and small_blocks, and
// distinguish private blocks by adding a "pool id" check above the stream
// check in BlockComparator. BlockComparator is performance-critical though,
// so I'd rather not add more logic to it.
BlockPool large_blocks;
BlockPool small_blocks;
};
struct MempoolIdHash {
std::size_t operator()(const MempoolId_t& mempool_id) const noexcept {
return mempool_id.first != 0 ? mempool_id.first : mempool_id.second;
}
};
cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
if (at::cuda::currentStreamCaptureStatusMayInitCtx() ==
at::cuda::CaptureStatus::None) {
#endif
return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size));
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
} else {
// It's ok to capture cudaMallocs, as long as we never cudaFree those
// addresses before replay.
// Capturing cudaMalloc behaves nicely: it gives the graph new VA,
// but is ignored (won't leakily allocate new memory) in replays.
at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed};
return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size));
}
#endif
}
} // namespace
class CachingAllocatorConfig {
public:
static size_t max_split_size() {
return instance().m_max_split_size;
}
static double garbage_collection_threshold() {
return instance().m_garbage_collection_threshold;
}
// This is used to round up the allocation size to the nearest power-of-2
// division. More description below in function roundup_power2_next_division.
// As an example, if we want 4 divisions between powers of 2, this can be done
// using the env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
static size_t roundup_power2_divisions() {
return instance().m_roundup_power2_divisions;
}
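/*
To make the "divisions between powers of 2" idea concrete, here is a small
sketch of one way such rounding can be computed. The helper name
`roundup_to_division` is hypothetical, and details such as the treatment of
very small sizes are assumptions; the allocator's own rounding function is not
shown in this part of the file and may differ.

```cpp
#include <cassert>
#include <cstddef>

namespace sketch {
// Round `size` up to the next of `divisions` equally spaced steps between the
// enclosing powers of two. `divisions` must be a power of two.
inline std::size_t roundup_to_division(
    std::size_t size,
    std::size_t divisions) {
  assert(size > 0);
  assert(divisions != 0 && (divisions & (divisions - 1)) == 0);
  // Largest power of two that does not exceed size.
  std::size_t floor_pow2 = 1;
  while (floor_pow2 <= size / 2) {
    floor_pow2 *= 2;
  }
  if (floor_pow2 == size) {
    return size; // already a power of two
  }
  const std::size_t step = floor_pow2 / divisions;
  if (step == 0) {
    return floor_pow2 * 2; // too small to subdivide: use the next power of two
  }
  return ((size + step - 1) / step) * step;
}
} // namespace sketch
```

With 4 divisions, sizes between 1024 and 2048 land on 1280, 1536, 1792 or 2048,
so a 1200-byte request is rounded to 1280 bytes.
*/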
static size_t roundup_bypass_threshold() {
return instance().m_roundup_bypass_threshold;
}
static CachingAllocatorConfig& instance() {
static CachingAllocatorConfig* s_instance = ([]() {
auto inst = new CachingAllocatorConfig();
const char* env = getenv("PYTORCH_CUDA_ALLOC_CONF");
inst->parseArgs(env);
return inst;
})();
return *s_instance;
}
void parseArgs(const char* env) {
// If empty, set the default values
m_max_split_size = std::numeric_limits<size_t>::max();
m_roundup_power2_divisions = 0;
m_roundup_bypass_threshold = std::numeric_limits<size_t>::max();
m_garbage_collection_threshold = 0;
if (env == nullptr) {
return;
}
const std::string config(env);
std::regex exp("[\\s,]+");
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);
for (const auto& option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
std::sregex_token_iterator end2;
std::vector<std::string> kv(it2, end2);
if (kv.size() >= 2) {
/* Maximum split size in MB. Limited to large size blocks */
if (kv[0].compare("max_split_size_mb") == 0) {
size_t val2 = stoi(kv[1]);
TORCH_CHECK(
val2 > kLargeBuffer / (1024 * 1024),
"CachingAllocator option max_split_size_mb too small, must be > ",
kLargeBuffer / (1024 * 1024),
"");
val2 = std::max(val2, kLargeBuffer / (1024 * 1024));
val2 = std::min(
val2, (std::numeric_limits<size_t>::max() / (1024 * 1024)));
m_max_split_size = val2 * 1024 * 1024;
} else if (kv[0].compare("roundup_power2_divisions") == 0) {
size_t val2 = stoi(kv[1]);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"For roundups, the number of divisions has to be a power of 2 ",
"");
m_roundup_power2_divisions = val2;
} else if (kv[0].compare("roundup_bypass_threshold_mb") == 0) {
size_t val2 = stoi(kv[1]);
m_roundup_bypass_threshold = val2 * 1024 * 1024;
} else if (kv[0].compare("garbage_collection_threshold") == 0) {
/*
* Perform garbage collection of GPU memory blocks to avoid
* triggering expensive sync-and-reclaim-all operation. Upon setting
* the threshold (e.g., 0.8), the allocator will start reclaiming
* blocks if GPU memory capacity usage exceeds the threshold (i.e.,
* 80% of total memory).
* Values 0.0 and 1.0 are not allowed as they are less meaningful.
*/
double val2 = stod(kv[1]);
TORCH_CHECK(
val2 > 0,
"garbage_collection_threshold too small, set it 0.0~1.0",
"");
TORCH_CHECK(
val2 < 1.0,
"garbage_collection_threshold too big, set it 0.0~1.0",
"");
m_garbage_collection_threshold = val2;
} else {
TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]);
}
}
}
}
private:
CachingAllocatorConfig()
: m_max_split_size(std::numeric_limits<size_t>::max()),
m_roundup_power2_divisions(0),
m_roundup_bypass_threshold(std::numeric_limits<size_t>::max()),
m_garbage_collection_threshold(0) {}
std::atomic<size_t> m_max_split_size;
std::atomic<size_t> m_roundup_power2_divisions;
std::atomic<size_t> m_roundup_bypass_threshold;
std::atomic<double> m_garbage_collection_threshold;
};
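/*
The option string is tokenized in two passes: first on commas or whitespace,
then each option on ':'. The sketch below isolates that tokenization so it can
be exercised on its own. `split_options` is a hypothetical helper that only
collects key/value strings; it performs none of the validation that parseArgs
does.

```cpp
#include <map>
#include <regex>
#include <string>
#include <vector>

namespace sketch {
// Split "key:value,key:value" into a map. Options may be separated by commas
// or whitespace. Options without a ':' separator are skipped.
inline std::map<std::string, std::string> split_options(
    const std::string& config) {
  std::map<std::string, std::string> out;
  const std::regex option_sep("[\\s,]+");
  const std::regex kv_sep("[:]+");
  const std::sregex_token_iterator end;
  // Index -1 asks the iterator for the text between separator matches.
  std::sregex_token_iterator it(config.begin(), config.end(), option_sep, -1);
  for (; it != end; ++it) {
    const std::string option = it->str();
    std::sregex_token_iterator kv_it(option.begin(), option.end(), kv_sep, -1);
    const std::vector<std::string> kv(kv_it, end);
    if (kv.size() >= 2) {
      out[kv[0]] = kv[1];
    }
  }
  return out;
}
} // namespace sketch
```

For example, "max_split_size_mb:128,garbage_collection_threshold:0.8" yields
two entries, one per option.
*/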
class DeviceCachingAllocator {
private:
// lock around all operations
mutable std::recursive_mutex mutex;
// device statistics
DeviceStats stats;
// unallocated cached blocks larger than 1 MB
BlockPool large_blocks;
// unallocated cached blocks 1 MB or smaller
BlockPool small_blocks;
// allocated or in use by a stream. Holds all active allocations,
// whether they came from graph_pools or one of the BlockPools above.
ska::flat_hash_set<Block*> active_blocks;
// captures_underway tracks if a capture might be underway on any stream.
// Most of the time it's zero, in which case malloc can avoid calling
// cudaStreamGetCaptureInfo in the hot path.
int captures_underway = 0;
// See free() for this thing's purpose
std::vector<Block*> needs_events_deferred_until_no_capture;
// outstanding cuda events
ska::flat_hash_map<
cuda::CUDAStream,
std::deque<std::pair<EventPool::Event, Block*>>>
cuda_events;
// record used memory.
size_t total_allocated_memory = 0;
size_t allowed_memory_maximum = 0;
bool set_fraction = false;
// Members specific to CUDA graphs
// Private pools for CUDA graphs
ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash>
graph_pools;
// Pools no longer referenced by any graph. Their BlockPools are eligible for
// free_blocks. Can't be a vector or deque because we might erase entries in
// any order. Could be an std::list, but we don't care much, access and
// insert/erase are rare.
ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash>
graph_pools_freeable;
// Maps a capturing stream to its assigned private pool,
// in case we want multiple captures to share the same pool
ska::flat_hash_map<CaptureId_t, MempoolId_t> capture_to_pool_map;
std::atomic<CreateContextFn> context_recorder_;
public:
DeviceCachingAllocator()
: large_blocks(BlockComparator, /*is_small=*/false),
small_blocks(BlockComparator, /*is_small=*/true) {
stats.max_split_size = CachingAllocatorConfig::max_split_size();
context_recorder_.store(nullptr);
}
void setContextRecorder(CreateContextFn c) {
context_recorder_.store(c);
}
// All public methods (except the above) acquire the allocator mutex.
// Thus, do not call a public method from another public method.
Block* malloc(int device, size_t orig_size, cudaStream_t stream) {
// done outside the lock because we don't know what locks the recorder needs
// to have...
CreateContextFn context_recorder = context_recorder_.load();
std::unique_ptr<Context> context =
context_recorder ? context_recorder() : nullptr;
std::unique_lock<std::recursive_mutex> lock(mutex);
if (C10_LIKELY(captures_underway == 0)) {
// Processes end-of-life events for outstanding allocations used on
// multiple streams (checks if their GPU-side uses are complete and
// recycles their memory if so)
//
// Q. Why skip process_events if a capture might be underway?
// A. process_events involves cudaEventQueries, illegal during CUDA graph
// capture.
// Dumb simple solution: defer reclaiming these allocations until after
// capture. Cross-stream memory use is uncommon, so the deferral's
// effect on memory use during capture should be small.
process_events();
}
size_t size = round_size(orig_size);
auto& pool = get_pool(size, stream);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, stream, &pool, alloc_size, stats);
params.stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
params.stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true;
// First, try to get a block from the existing pool.
bool block_found =
// Search pool
get_free_block(params)
// Trigger callbacks and retry search
|| (trigger_free_memory_callbacks(params) && get_free_block(params));
// Can't reuse an existing block; try to get a new one.
if (!block_found) {
// Do garbage collection if the flag is set.
if (C10_UNLIKELY(
set_fraction &&
CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) {
garbage_collect_cached_blocks();
}
// Attempt to allocate
block_found = alloc_block(params, false)
// Free enough available cached blocks to satisfy alloc and retry
// alloc.
|| (release_available_cached_blocks(params) &&
alloc_block(params, false))
// Free all non-split cached blocks and retry alloc.
|| (C10_LIKELY(captures_underway == 0) && release_cached_blocks() &&
alloc_block(params, true));
}
if (!block_found) {
// For any error code other than cudaErrorMemoryAllocation,
// alloc_block should have thrown an exception already.
TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation);
size_t device_free;
size_t device_total;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
std::string allowed_info;
if (set_fraction) {
allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
}
stats.num_ooms += 1;
c10::reportOutOfMemoryToProfiler(
size,
stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current,
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current,
c10::Device(c10::DeviceType::CUDA, static_cast<DeviceIndex>(device)));
// "total capacity": total global memory on GPU
// "allowed": the maximum memory the allocator may use, as set by the
// memory fraction.
// "already allocated": memory allocated by the program using the
// caching allocator
// "free": free memory as reported by the CUDA API
// "cached": memory held by the allocator but not used by the program
//
// The "allocated" amount does not include memory allocated outside
// of the caching allocator, such as memory allocated by other programs
// or memory held by the driver.
//
// The sum of "allocated" + "free" + "cached" may be less than the
// total capacity due to memory held by the driver and usage by other
// programs.
//
// Note that at this point release_cached_blocks has already returned all
// possible "cached" memory to the driver. The only remaining "cached"
// memory is split from a larger block that is partially in-use.
TORCH_CHECK_WITH(
OutOfMemoryError,
false,
"CUDA out of memory. Tried to allocate ",
format_size(alloc_size),
" (GPU ",
device,
"; ",
format_size(device_total),
" total capacity; ",
format_size(
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current),
" already allocated; ",
format_size(device_free),
" free; ",
allowed_info,
format_size(
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
.current),
" reserved in total by PyTorch)",
" If reserved memory is >> allocated memory try setting max_split_size_mb to avoid"
" fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
"");
}
TORCH_INTERNAL_ASSERT(
params.err == cudaSuccess && params.block != nullptr &&
params.block->ptr != nullptr);
Block* block = params.block;
Block* remaining = nullptr;
const bool already_split = block->is_split();
if (should_split(block, size)) {
remaining = block;
block = new Block(device, stream, size, &pool, block->ptr);
block->prev = remaining->prev;
if (block->prev) {
block->prev->next = block;
}
block->next = remaining;
remaining->prev = block;
remaining->ptr = static_cast<char*>(remaining->ptr) + size;
remaining->size -= size;
bool inserted = pool.blocks.insert(remaining).second;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
if (context) {
trimHistoryBefore(remaining, (char*)block->ptr + size);
}
if (already_split) {
// An already-split inactive block is being shrunk by size bytes.
update_stat_array(
stats.inactive_split_bytes, -block->size, params.stat_types);
} else {
// A new split inactive block is being created from a previously unsplit
// block, size remaining->size bytes.
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.inactive_split_bytes[stat_type], remaining->size);
update_stat(stats.inactive_split[stat_type], 1);
});
}
} else if (already_split) {
// An already-split block is becoming active
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.inactive_split_bytes[stat_type], -block->size);
update_stat(stats.inactive_split[stat_type], -1);
});
}
block->allocated = true;
if (context) {
trimHistoryBefore(block, (char*)block->ptr + size);
block->history = std::make_unique<History>(History{
block->ptr,
orig_size,
std::move(context),
std::move(block->history)});
if (!block->history_last) {
block->history_last = block->history.get();
}
}
bool inserted = active_blocks.insert(block).second;
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted);
for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) {
update_stat(stats.allocation[stat_type], 1);
update_stat(stats.allocated_bytes[stat_type], block->size);
update_stat(stats.active[stat_type], 1);
update_stat(stats.active_bytes[stat_type], block->size);
});
if (block->size >= CachingAllocatorConfig::max_split_size())
update_stat(stats.oversize_allocations, 1);
c10::reportMemoryUsageToProfiler(
block->ptr,
block->size,
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
c10::Device(c10::DeviceType::CUDA, device));
return block;
}
void free(Block* block) {
std::lock_guard<std::recursive_mutex> lock(mutex);
block->allocated = false;
// The logic below may modify the underlying Block and change its size,
// so save the original pointer and size up front for reporting.
auto orig_block_ptr = block->ptr;
auto orig_block_size = block->size;
StatTypes stat_types = {false};
stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
stat_types[static_cast<size_t>(get_stat_type_for_pool(*(block->pool)))] =
true;
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(stats.allocation[stat_type], -1);
update_stat(stats.allocated_bytes[stat_type], -block->size);
});
if (block->size >= CachingAllocatorConfig::max_split_size())
update_stat(stats.oversize_allocations, -1);
if (!block->stream_uses.empty()) {
if (C10_UNLIKELY(captures_underway)) {
// It's forbidden to cudaEventQuery an event recorded during CUDA graph
// capture. We conservatively defer recording end-of-life events until
// the next call to process_events() (which won't happen until no
// captures are underway)
needs_events_deferred_until_no_capture.push_back(block);
} else {
insert_events(block);
}
} else {
free_block(block);
}
c10::reportMemoryUsageToProfiler(
orig_block_ptr,
-orig_block_size,
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current,
c10::Device(c10::DeviceType::CUDA, block->device));
}
void* getBaseAllocation(Block* block, size_t* outSize) {
std::lock_guard<std::recursive_mutex> lock(mutex);
while (block->prev) {
block = block->prev;
}
void* basePtr = block->ptr;
if (outSize) {
size_t size = 0;
while (block) {
size += block->size;
block = block->next;
}
*outSize = size;
}
return basePtr;
}
void recordStream(Block* block, cuda::CUDAStream stream) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (stream.stream() == block->stream) {
// ignore uses on the allocation stream, since those don't require any
// special synchronization
return;
}
block->stream_uses.insert(stream);
}
/** set memory fraction to limit maximum allocated memory **/
void setMemoryFraction(double fraction) {
size_t device_free;
size_t device_total;
C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
set_fraction = true;
}
/** returns cached blocks to the system allocator **/
void emptyCache() {
std::lock_guard<std::recursive_mutex> lock(mutex);
release_cached_blocks();
}
/** Retrieves info (total size + largest block) of the memory cache **/
void cacheInfo(size_t* total, size_t* largest) {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (*largest == 0) {
// Make an initial guess if a zero *largest is passed in.
size_t tmp_bytes;
C10_CUDA_CHECK(cudaMemGetInfo(
largest, // Use free memory as an optimistic initial guess of *largest
&tmp_bytes));
}
cache_info_aux(large_blocks, total, largest);
cache_info_aux(small_blocks, total, largest);
for (const auto& gp : graph_pools) {
cache_info_aux(gp.second->large_blocks, total, largest);
cache_info_aux(gp.second->small_blocks, total, largest);
}
}
/** Returns a copy of the memory allocator stats **/
DeviceStats getStats() {
std::lock_guard<std::recursive_mutex> lock(mutex);
return stats;
}
/** Resets the historical accumulation stats for the device **/
void resetAccumulatedStats() {
std::lock_guard<std::recursive_mutex> lock(mutex);
for (const auto statType :
c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
reset_accumulated_stat(stats.allocation[statType]);
reset_accumulated_stat(stats.segment[statType]);
reset_accumulated_stat(stats.active[statType]);
reset_accumulated_stat(stats.inactive_split[statType]);
reset_accumulated_stat(stats.allocated_bytes[statType]);
reset_accumulated_stat(stats.reserved_bytes[statType]);
reset_accumulated_stat(stats.active_bytes[statType]);
reset_accumulated_stat(stats.inactive_split_bytes[statType]);
}
stats.num_alloc_retries = 0;
stats.num_ooms = 0;
reset_accumulated_stat(stats.oversize_allocations);
reset_accumulated_stat(stats.oversize_segments);
}
/** Resets the historical peak stats for the device **/
void resetPeakStats() {
std::lock_guard<std::recursive_mutex> lock(mutex);
for (const auto statType :
c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) {
reset_peak_stat(stats.allocation[statType]);
reset_peak_stat(stats.segment[statType]);
reset_peak_stat(stats.active[statType]);
reset_peak_stat(stats.inactive_split[statType]);
reset_peak_stat(stats.allocated_bytes[statType]);
reset_peak_stat(stats.reserved_bytes[statType]);
reset_peak_stat(stats.active_bytes[statType]);
reset_peak_stat(stats.inactive_split_bytes[statType]);
}
reset_peak_stat(stats.oversize_allocations);
reset_peak_stat(stats.oversize_segments);
}
/** Dump a complete snapshot of the memory held by the allocator. Potentially
* VERY expensive. **/
std::vector<SegmentInfo> snapshot() const {
std::lock_guard<std::recursive_mutex> lock(mutex);
std::vector<SegmentInfo> result;
const auto all_blocks = get_all_blocks();
for (const Block* const head_block : all_blocks) {
if (head_block->prev != nullptr) {
continue;
}
result.emplace_back();
SegmentInfo& segment_info = result.back();
segment_info.device = head_block->device;
segment_info.address = reinterpret_cast<int64_t>(head_block->ptr);
segment_info.stream = head_block->stream;
segment_info.is_large = (!head_block->pool->is_small);
const Block* block = head_block;
while (block != nullptr) {
segment_info.blocks.emplace_back();
BlockInfo& block_info = segment_info.blocks.back();
block_info.size = block->size;
block_info.allocated = block->allocated;
block_info.active = block->allocated || (block->event_count > 0) ||
!block->stream_uses.empty();
segment_info.total_size += block_info.size;
if (block_info.allocated) {
segment_info.allocated_size += block_info.size;
}
if (block_info.active) {
segment_info.active_size += block_info.size;
}
block_info.history = block->history.get();
block = block->next;
}
}
std::sort(
result.begin(),
result.end(),
[](const SegmentInfo& a, const SegmentInfo& b) {
return a.address < b.address;
});
return result;
}
// Rounds size up to the nearest power-of-2 division boundary, given the
// number of divisions.
// For example, to round up 1200 with 4 divisions: 1200 lies between 1024 and
// 2048, and splitting that range into 4 equal divisions gives the boundaries
// 1024, 1280, 1536, and 1792. The function returns 1280, the nearest
// power-of-2 division at or above 1200.
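//
// Illustrative trace of the bit arithmetic (a sketch using example values,
// assuming a 64-bit size_t):
//   size = 1200, divisions = 4
//     power2_floor     = 1024
//     division size    = 1024 >> (63 - countLeadingZeros(4)) = 1024 >> 2 = 256
//     round_size_floor = 1200 & ~(256 - 1) = 1024
//     result           = 1024 + 256 = 1280
//   size = 1536, divisions = 4
//     round_size_floor = 1536 & ~(256 - 1) = 1536 == size, so result = 1536
//   size = 5, divisions = 8
//     division size    = 4 >> 3 = 0, so result = power2_floor << 1 = 8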
static size_t roundup_power2_next_division(size_t size, size_t divisions) {
if (C10_UNLIKELY(size <= 4 || divisions <= 1)) {
return size;
}
if (llvm::isPowerOf2_64(size)) {
return size;
}
// Divide the space between the two enclosing powers of 2 into equal
// divisions. If the division size is zero, return the power-of-2 ceiling.
size_t power2_floor = llvm::PowerOf2Floor(size);
size_t power2_division =
power2_floor >> (63 - llvm::countLeadingZeros(divisions));
if (C10_UNLIKELY(power2_division == 0)) {
return (power2_floor << 1);
}
size_t round_size_floor = size & (~(power2_division - 1));
return (round_size_floor == size) ? size
: round_size_floor + power2_division;
}
static size_t round_size(size_t size) {
if (size < kMinBlockSize) {
return kMinBlockSize;
} else if (size > CachingAllocatorConfig::roundup_bypass_threshold()) {
return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
} else {
auto divisions = CachingAllocatorConfig::roundup_power2_divisions();
if (divisions > 0 && size > (kMinBlockSize * divisions)) {
return roundup_power2_next_division(size, divisions);
} else {
return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize);
}
}
}
// See Note [Interaction with CUDA graph capture]
// Called by CUDAGraph::capture_begin
void notifyCaptureBegin(CaptureId_t graph_id, MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
captures_underway++;
auto it = graph_pools.find(mempool_id);
if (it == graph_pools.end()) {
// mempool_id does not reference an existing pool. Make a new pool for
// this capture.
graph_pools.emplace(mempool_id, std::make_unique<PrivatePool>());
} else {
// mempool_id references an existing pool, which the current capture will
// share. Check this pool is live (at least one other capture already
// references it).
TORCH_INTERNAL_ASSERT(it->second->use_count > 0);
it->second->use_count++;
}
// Maps this graph_id to mempool_id and makes sure this graph_id wasn't
// somehow assigned a mempool_id already. Keeps essential effect (insert)
// out of macro.
bool inserted = capture_to_pool_map.insert({graph_id, mempool_id}).second;
TORCH_INTERNAL_ASSERT(inserted);
}
// Called by CUDAGraph::capture_end
void notifyCaptureEnd(CaptureId_t graph_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
captures_underway--;
auto it = capture_to_pool_map.find(graph_id);
TORCH_INTERNAL_ASSERT(it != capture_to_pool_map.end());
capture_to_pool_map.erase(it);
}
// Called by CUDAGraph::reset
void notifyCaptureDestroy(MempoolId_t mempool_id) {
std::lock_guard<std::recursive_mutex> lock(mutex);
// The instantiated cudaGraphExec_t has been destroyed. We can't blindly
// delete and cudaFree the mempool its capture used, because
// 1. other graph(s) might share the same pool
// 2. the user might still hold references to output tensors allocated
// during capture.
// To handle 1 and 2, we track the number of graphs using this particular
// mempool. When the count reaches 0, we tell free_cached_blocks it may now
// cudaFree blocks from this graph's pool when it discovers they're unused
// (unsplit).
auto it = graph_pools.find(mempool_id);
TORCH_INTERNAL_ASSERT(it != graph_pools.end());
auto uc = --(it->second->use_count);
TORCH_INTERNAL_ASSERT(uc >= 0);
if (uc == 0) {
// Allows free_cached_blocks to begin cudaFreeing this pool's memory,
// and makes sure this pool wasn't somehow made freeable already.
bool inserted =
graph_pools_freeable.insert({mempool_id, it->second.get()}).second;
TORCH_INTERNAL_ASSERT(inserted);
}
}
private:
// None of the private methods acquire the allocator mutex.
std::vector<const Block*> get_all_blocks() const {
std::vector<const Block*> blocks;
blocks.insert(
blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end());
blocks.insert(
blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end());
for (const auto& gp : graph_pools) {
blocks.insert(
blocks.end(),
gp.second->small_blocks.blocks.begin(),
gp.second->small_blocks.blocks.end());
blocks.insert(
blocks.end(),
gp.second->large_blocks.blocks.begin(),
gp.second->large_blocks.blocks.end());
}
blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end());
return blocks;
}
/** moves a block into a pool of cached free blocks */
void free_block(Block* block) {
TORCH_INTERNAL_ASSERT(
!block->allocated && block->event_count == 0 &&
block->stream_uses.empty());
size_t original_block_size = block->size;
auto& pool = *block->pool;
int64_t net_change_inactive_split_blocks = 0;
int64_t net_change_inactive_split_size = 0;
const std::array<Block*, 2> merge_candidates = {block->prev, block->next};
for (Block* merge_candidate : merge_candidates) {
const int64_t subsumed_size =
try_merge_blocks(block, merge_candidate, pool);
if (subsumed_size > 0) {
net_change_inactive_split_blocks -= 1;
net_change_inactive_split_size -= subsumed_size;
}
}
active_blocks.erase(block);
// Makes sure the Block* isn't already present in the pool we're freeing it
// back into.
bool inserted = pool.blocks.insert(block).second;
TORCH_INTERNAL_ASSERT(inserted);
if (block->is_split()) {
net_change_inactive_split_blocks += 1;
net_change_inactive_split_size += block->size;
}
StatTypes stat_types = {false};
stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true;
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(
stats.inactive_split[stat_type], net_change_inactive_split_blocks);
update_stat(
stats.inactive_split_bytes[stat_type],
net_change_inactive_split_size);
update_stat(stats.active[stat_type], -1);
update_stat(stats.active_bytes[stat_type], -original_block_size);
});
}
/** combine previously split blocks. returns the size of the subsumed block,
* or 0 on failure. */
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
if (!src || src->allocated || src->event_count > 0 ||
!src->stream_uses.empty()) {
return 0;
}
AT_ASSERT(dst->is_split() && src->is_split());
if (dst->prev == src) { // [src dst]
dst->ptr = src->ptr;
dst->prev = src->prev;
if (dst->prev) {
dst->prev->next = dst;
}
if (!dst->history) {
dst->history = std::move(src->history);
dst->history_last = src->history_last;
} else if (src->history) {
src->history_last->next = std::move(dst->history);
dst->history = std::move(src->history);
}
src->history_last = nullptr;
} else { // [dst src]
dst->next = src->next;
if (dst->next) {
dst->next->prev = dst;
}
if (!dst->history) {
dst->history = std::move(src->history);
dst->history_last = src->history_last;
} else if (src->history) {
dst->history_last->next = std::move(src->history);
dst->history_last = src->history_last;
}
src->history_last = nullptr;
}
const size_t subsumed_size = src->size;
dst->size += subsumed_size;
auto erased = pool.blocks.erase(src);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
delete src;
return subsumed_size;
}
BlockPool& get_pool(size_t size, cudaStream_t stream) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// captures_underway is a conservative guess that the current stream may be
// capturing. It's only > 0 if some thread has begun and not yet ended a
// capture, so it's usually 0, and we can short-circuit the
// cudaStreamGetCaptureInfo call (which does a TLS lookup).
if (C10_UNLIKELY(captures_underway)) {
CaptureId_t id;
cudaStreamCaptureStatus status;
C10_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id));
if (status != cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) {
TORCH_INTERNAL_ASSERT(
status !=
cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated);
// Retrieves the private pool assigned to this capture.
auto it0 = capture_to_pool_map.find(id);
TORCH_INTERNAL_ASSERT(it0 != capture_to_pool_map.end());
auto it1 = graph_pools.find(it0->second);
TORCH_INTERNAL_ASSERT(it1 != graph_pools.end());
if (size <= kSmallSize) {
return it1->second->small_blocks;
} else {
return it1->second->large_blocks;
}
}
}
#endif
if (size <= kSmallSize) {
return small_blocks;
} else {
return large_blocks;
}
}
StatType get_stat_type_for_pool(const BlockPool& pool) {
return pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL;
}
bool should_split(const Block* block, size_t size) {
size_t remaining = block->size - size;
if (block->pool->is_small) {
return remaining >= kMinBlockSize;
} else {
return (size < CachingAllocatorConfig::max_split_size()) &&
(remaining > kSmallSize);
}
}
static size_t get_allocation_size(size_t size) {
if (size <= kSmallSize) {
return kSmallBuffer;
} else if (size < kMinLargeAlloc) {
return kLargeBuffer;
} else {
return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge);
}
}
bool get_free_block(AllocParams& p) {
BlockPool& pool = *p.pool;
if (C10_UNLIKELY(
set_fraction &&
CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) {
// Track block reuse interval only when garbage collection is enabled.
for (auto& b : pool.blocks) {
++b->gc_count;
}
}
auto it = pool.blocks.lower_bound(&p.search_key);
if (it == pool.blocks.end() || (*it)->stream != p.stream())
return false;
// Do not return an oversized block for a request below max_split_size
if ((p.size() < CachingAllocatorConfig::max_split_size()) &&
((*it)->size >= CachingAllocatorConfig::max_split_size()))
return false;
// Allow oversized block size to be rounded up but within a limit
if ((p.size() >= CachingAllocatorConfig::max_split_size()) &&
((*it)->size >= p.size() + kLargeBuffer))
return false;
p.block = *it;
(*it)->gc_count = 0; // Denote this block has been used
pool.blocks.erase(it);
return true;
}
bool trigger_free_memory_callbacks(AllocParams& p) {
bool freed_memory = false;
for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) {
freed_memory |=
FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute();
}
return freed_memory;
}
void garbage_collect_cached_blocks() {
// Free unused cached blocks to reclaim GPU memory.
// Unlike release_cached_blocks(), this does not enforce synchronization and
// therefore should incur less overhead.
size_t gc_threshold = static_cast<size_t>(
CachingAllocatorConfig::garbage_collection_threshold() *
allowed_memory_maximum);
// No need to trigger GC yet
if (total_allocated_memory <= gc_threshold) {
return;
}
const auto target_size = total_allocated_memory - gc_threshold;
size_t gc_reclaimed = 0;
// Calculate the total age of the free-able blocks. We'll use it later to
// get "avg age" threshold.
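// Illustrative example (hypothetical ages): with three unsplit blocks whose
// gc_count values are 1, 3, and 8, the first pass uses a threshold of
// (1 + 3 + 8) / 3 = 4 and frees only the block with age 8. If the target is
// still not met, the next pass recomputes the threshold over the remaining
// blocks, (1 + 3) / 2 = 2, and frees the block with age 3.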
double total_age = 0.0;
int freeable_block_count = 0;
for (auto& b : large_blocks.blocks) {
if (!b->is_split()) {
total_age += b->gc_count;
++freeable_block_count;
}
}
// No free-able blocks?
if (freeable_block_count == 0) {
return;
}
// Repeat GC until the reclaimed size reaches the target size.
bool block_freed = true;
while (gc_reclaimed < target_size && block_freed == true &&
freeable_block_count > 0) {
// Free blocks exceeding this age threshold first.
double age_threshold = total_age / freeable_block_count;
// Stop iteration if we can no longer free a block.
block_freed = false;
// Free blocks of > avg age. Don't stop upon reaching the target_size,
// we don't want this GC to be triggered frequently.
auto it = large_blocks.blocks.begin();
while (it != large_blocks.blocks.end()) {
Block* block = *it;
++it;
if (!block->is_split() && block->gc_count >= age_threshold) {
block_freed = true;
gc_reclaimed += block->size;
total_age -= block->gc_count; // Decrement the age
freeable_block_count--; // One less block that can be freed
release_block(block);
}
}
}
}
bool alloc_block(AllocParams& p, bool isRetry) {
// Defensively checks for preexisting CUDA error state.
C10_CUDA_CHECK(cudaGetLastError());
size_t size = p.alloc_size;
void* ptr;
if (isRetry) {
stats.num_alloc_retries += 1;
}
if (set_fraction &&
total_allocated_memory + size > allowed_memory_maximum) {
p.err = cudaErrorMemoryAllocation;
return false;
} else {
p.err = cudaMallocMaybeCapturing(&ptr, size);
if (p.err != cudaSuccess) {
if (p.err == cudaErrorMemoryAllocation) {
// If this is the first attempt (!isRetry), we can forgive and clear
// CUDA's internal error state.
// If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH
// will take over to throw a helpful exception. The user can choose to
// catch the exception, free some stuff in their script, and attempt
// their allocation again. In this case, we can also forgive and clear
// CUDA's internal error state.
cudaGetLastError();
} else {
// If the error's unrelated to memory allocation, we should throw
// immediately.
C10_CUDA_CHECK(p.err);
}
return false;
}
}
if (p.pool->owner_PrivatePool) {
// The block is for a CUDA graph's PrivatePool.
p.pool->owner_PrivatePool->cudaMalloc_count++;
}
total_allocated_memory += size;
p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr);
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
update_stat(stats.segment[stat_type], 1);
update_stat(stats.reserved_bytes[stat_type], size);
});
if (size >= CachingAllocatorConfig::max_split_size())
update_stat(stats.oversize_segments, 1);
// p.block came from new, not cudaMalloc. It should not be nullptr here.
TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
return true;
}
/** Free one or more oversize blocks to the system allocator, but only enough
 * to satisfy the target size. **/
bool release_available_cached_blocks(const AllocParams& p) {
if (CachingAllocatorConfig::max_split_size() ==
std::numeric_limits<size_t>::max())
return false;
BlockPool& pool = *p.pool;
// because of std::unique_ptr, block cannot be trivially copied
Block key(
p.search_key.device,
p.search_key.stream,
p.search_key.size,
p.search_key.pool,
p.search_key.ptr);
key.size = (key.size < CachingAllocatorConfig::max_split_size())
? CachingAllocatorConfig::max_split_size()
: key.size;
auto it = pool.blocks.lower_bound(&key);
if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
// No single block is large enough; free multiple oversize blocks,
// starting with the largest
if (it == pool.blocks.begin())
return false;
size_t totalReleased = 0;
--it; // Back up one item. Now on the largest block for the correct
// stream
while ((totalReleased < key.size) &&
((*it)->size >= CachingAllocatorConfig::max_split_size()) &&
((*it)->stream == p.stream())) {
auto cur = it;
totalReleased += (*it)->size;
if (it != pool.blocks.begin()) {
--it;
release_block(*cur);
} else {
release_block(*cur);
break;
}
}
if (totalReleased < key.size)
return false;
} else {
release_block(*it);
}
return true;
}
bool release_cached_blocks() {
// First ensure that all blocks that can't currently be allocated due to
// outstanding events are returned to the pool.
synchronize_and_free_events();
// Free all non-split cached blocks to system allocator
release_blocks(large_blocks);
release_blocks(small_blocks);
for (auto it = graph_pools_freeable.begin();
it != graph_pools_freeable.end();) {
// See notifyCaptureDestroy for the strategy here.
TORCH_INTERNAL_ASSERT(it->second->use_count == 0);
release_blocks(it->second->small_blocks);
release_blocks(it->second->large_blocks);
if (it->second->cudaMalloc_count == 0) {
auto erase_count = graph_pools.erase(it->first);
TORCH_INTERNAL_ASSERT(erase_count == 1);
it = graph_pools_freeable.erase(it);
} else {
++it;
}
}
return true;
}
void release_block(Block* block) {
C10_CUDA_CHECK(cudaFree((void*)block->ptr));
total_allocated_memory -= block->size;
auto* pool = block->pool;
if (pool->owner_PrivatePool) {
// The cudaFreed block belonged to a CUDA graph's PrivatePool.
TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->cudaMalloc_count > 0);
pool->owner_PrivatePool->cudaMalloc_count--;
}
StatTypes stat_types = {false};
stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
stat_types[static_cast<size_t>(get_stat_type_for_pool(*pool))] = true;
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
update_stat(stats.segment[stat_type], -1);
update_stat(stats.reserved_bytes[stat_type], -block->size);
});
if (block->size >= CachingAllocatorConfig::max_split_size())
update_stat(stats.oversize_segments, -1);
pool->blocks.erase(block);
delete block;
}
void release_blocks(BlockPool& pool) {
// Frees all non-split blocks
auto it = pool.blocks.begin();
while (it != pool.blocks.end()) {
Block* block = *it;
++it;
if (!block->prev && !block->next) {
release_block(block);
}
}
}
EventPool::Event create_event_internal(int idx) {
// Leak the event pool to avoid shutdown issues.
static auto* event_pool = new EventPool();
return event_pool->get(idx);
}
void synchronize_and_free_events() {
// Synchronize on outstanding events and then free associated blocks.
// This function syncs, so capture should not be underway. Might as well
// make sure capture-deferred end of life events get processed too.
TORCH_INTERNAL_ASSERT(captures_underway == 0);
insert_events_deferred_until_no_capture();
for (auto& st : cuda_events) {
for (auto& e : st.second) {
EventPool::Event event = std::move(e.first);
Block* block = e.second;
C10_CUDA_CHECK(cudaEventSynchronize(*event));
block->event_count--;
if (block->event_count == 0) {
free_block(block);
}
}
}
cuda_events.clear();
}
void insert_events(Block* block) {
int prev_device;
C10_CUDA_CHECK(cudaGetDevice(&prev_device));
stream_set streams(std::move(block->stream_uses));
AT_ASSERT(block->stream_uses.empty());
for (auto& stream : streams) {
C10_CUDA_CHECK(cudaSetDevice(stream.device_index()));
EventPool::Event event =
create_event_internal(static_cast<int>(stream.device_index()));
C10_CUDA_CHECK(cudaEventRecord(*event, stream.stream()));
block->event_count++;
cuda_events[stream].emplace_back(std::move(event), block);
}
C10_CUDA_CHECK(cudaSetDevice(prev_device));
}
void insert_events_deferred_until_no_capture() {
if (C10_UNLIKELY(needs_events_deferred_until_no_capture.size() > 0)) {
for (auto* block : needs_events_deferred_until_no_capture) {
TORCH_INTERNAL_ASSERT(!block->stream_uses.empty());
insert_events(block);
}
needs_events_deferred_until_no_capture.clear();
}
}
void process_events() {
insert_events_deferred_until_no_capture();
// Process outstanding cudaEvents. Events that are completed are
// removed from the queue, and the 'event_count' for the
// corresponding allocation is decremented. We maintain a separate
// list of events per stream to avoid head-of-line delays if one
// or more streams have long-running operations.
// Iterate over different streams.
for (auto it = cuda_events.begin(); it != cuda_events.end();) {
// Iterate over this stream's (event, block) pairs.
while (!it->second.empty()) {
auto& e = it->second.front();
EventPool::Event event = std::move(e.first);
Block* block = e.second;
cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(*event));
if (err == cudaErrorNotReady) {
// ignore and clear the error if not ready
cudaGetLastError();
// Return the ownership of the Event (unique ptr)
e.first = std::move(event);
break;
} else if (err != cudaSuccess) {
C10_CUDA_CHECK(err);
}
block->event_count--;
if (block->event_count == 0) {
free_block(block);
}
it->second.pop_front();
}
if (it->second.empty()) {
it = cuda_events.erase(it);
} else {
it++;
}
}
}
// Accumulates sizes of all memory blocks for given device in given pool
void cache_info_aux(const BlockPool& pool, size_t* total, size_t* largest) {
for (const auto& block : pool.blocks) {
const auto blocksize = block->size;
*total += blocksize;
if (blocksize > *largest) {
*largest = blocksize;
}
}
}
};
class THCCachingAllocator {
private:
std::mutex mutex;
// allocated blocks by device pointer
ska::flat_hash_map<void*, Block*> allocated_blocks;
// lock around calls to cudaFree (to prevent deadlocks with NCCL)
mutable std::mutex cuda_free_mutex;
void add_allocated_block(Block* block) {
std::lock_guard<std::mutex> lock(mutex);
allocated_blocks[block->ptr] = block;
}
public:
std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator;
std::mutex* getCudaFreeMutex() const {
return &cuda_free_mutex;
}
Block* get_allocated_block(void* ptr, bool remove = false) {
std::lock_guard<std::mutex> lock(mutex);
auto it = allocated_blocks.find(ptr);
if (it == allocated_blocks.end()) {
return nullptr;
}
Block* block = it->second;
if (remove) {
allocated_blocks.erase(it);
}
return block;
}
void init(int device_count) {
const auto size = static_cast<int64_t>(device_allocator.size());
if (size < device_count) {
device_allocator.resize(device_count);
for (const auto i : c10::irange(size, device_count)) {
device_allocator[i] = std::make_unique<DeviceCachingAllocator>();
}
}
}
/** allocates a block which is safe to use from the provided stream */
void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) {
TORCH_INTERNAL_ASSERT(
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
"Allocator not initialized for device ",
device,
": did you call init?");
Block* block = device_allocator[device]->malloc(device, size, stream);
add_allocated_block(block);
*devPtr = (void*)block->ptr;
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(
reinterpret_cast<uintptr_t>(*devPtr));
}
}
void free(void* ptr) {
if (!ptr) {
return;
}
Block* block = get_allocated_block(ptr, true /* remove */);
if (!block) {
TORCH_CHECK(false, "invalid device pointer: ", ptr);
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(
reinterpret_cast<uintptr_t>(block->ptr));
}
device_allocator[block->device]->free(block);
}
void setMemoryFraction(double fraction, int device) {
TORCH_INTERNAL_ASSERT(
0 <= device && static_cast<size_t>(device) < device_allocator.size(),
"Allocator not initialized for device ",
device,
": did you call init?");
TORCH_INTERNAL_ASSERT(
0 <= fraction && fraction <= 1,
"invalid fraction: ",
fraction,
". Please set within [0, 1].");
int activated_device;
C10_CUDA_CHECK(cudaGetDevice(&activated_device));
if (activated_device != device) {
C10_CUDA_CHECK(cudaSetDevice(device));
}
device_allocator[device]->setMemoryFraction(fraction);
}
void setContextRecorder(CreateContextFn recorder) {
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
device_allocator[device]->setContextRecorder(std::move(recorder));
}
void emptyCache() {
for (auto& da : device_allocator)
da->emptyCache();
}
void* getBaseAllocation(void* ptr, size_t* outSize) {
Block* block = get_allocated_block(ptr);
if (!block) {
TORCH_CHECK(false, "invalid device pointer: ", ptr);
}
return device_allocator[block->device]->getBaseAllocation(block, outSize);
}
void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
// An empty tensor's storage().data() might be a null pointer. Since no
// blocks are associated with such tensors, it is fine to do nothing here.
if (!ptr.get()) {
return;
}
// If a tensor was not allocated by this instance, simply skip it.
// This usually happens when CUDA tensors are shared across processes;
// a reference-counting-based sharing mechanism guarantees that tensors
// won't be accidentally freed by one process while they are still being
// used in another.
if (ptr.get_deleter() != &raw_delete)
return;
Block* block = get_allocated_block(ptr.get());
// block must not be null at this point
TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found");
device_allocator[block->device]->recordStream(block, stream);
}
std::vector<SegmentInfo> snapshot() {
std::vector<SegmentInfo> result;
for (auto& da : device_allocator) {
auto snap = da->snapshot();
result.insert(result.end(), snap.begin(), snap.end());
}
return result;
}
};
THCCachingAllocator caching_allocator;
// Returns whether to force all allocations to bypass the caching allocator and
// go straight to cudaMalloc. This setting is useful when debugging GPU memory
// errors, since the caching allocator foils cuda-memcheck.
bool forceUncachedAllocator() {
static bool force_uncached =
getenv("PYTORCH_NO_CUDA_MEMORY_CACHING") != nullptr;
return force_uncached;
}
static void uncached_delete(void* ptr) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(reinterpret_cast<uintptr_t>(ptr));
}
C10_CUDA_CHECK(cudaFree(ptr));
}
// NB: I decided not to fold this into THCCachingAllocator, because the latter
// has a lot more methods and it wasn't altogether clear that they should
// actually be publicly exposed
struct CudaCachingAllocator : public Allocator {
DataPtr allocate(size_t size) const override {
constexpr size_t one_exa_bytes = 1152921504606846976ULL;
TORCH_CHECK_WITH(
OutOfMemoryError,
size < one_exa_bytes,
"CUDA out of memory. Tried to allocate more than 1EB memory.");
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
if (forceUncachedAllocator()) {
// Deliberately don't use cudaMallocMaybeCapturing here, to force an error
// if someone tries to use forceUncachedAllocator while capturing.
C10_CUDA_CHECK(cudaMalloc(&r, size));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(reinterpret_cast<uintptr_t>(r));
}
return {r, r, &uncached_delete, Device(DeviceType::CUDA, device)};
}
if (size != 0) {
caching_allocator.malloc(
&r, device, size, cuda::getCurrentCUDAStream(device));
}
return {r, r, &raw_delete, Device(DeviceType::CUDA, device)};
}
DeleterFnPtr raw_deleter() const override {
if (forceUncachedAllocator()) {
return &uncached_delete;
} else {
return &raw_delete;
}
}
};
CudaCachingAllocator device_allocator;
Allocator* get(void) {
return &device_allocator;
}
void init(int device_count) {
caching_allocator.init(device_count);
}
void setMemoryFraction(double fraction, int device) {
caching_allocator.setMemoryFraction(fraction, device);
}
void setContextRecorder(CreateContextFn recorder) {
caching_allocator.setContextRecorder(std::move(recorder));
}
void setAllocatorSettings(const std::string& env) {
CachingAllocatorConfig::instance().parseArgs(env.c_str());
}
void emptyCache(void) {
caching_allocator.emptyCache();
}
void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) {
caching_allocator.device_allocator[dev_id]->cacheInfo(
cachedAndFree, largestBlock);
}
void* getBaseAllocation(void* ptr, size_t* size) {
return caching_allocator.getBaseAllocation(ptr, size);
}
void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
caching_allocator.recordStream(ptr, stream);
}
std::mutex* getFreeMutex() {
return caching_allocator.getCudaFreeMutex();
}
static inline void assertValidDevice(int device) {
const auto device_num = caching_allocator.device_allocator.size();
TORCH_CHECK(
0 <= device && device < static_cast<int64_t>(device_num),
"Invalid device argument ",
device,
": did you call init?");
}
DeviceStats getDeviceStats(int device) {
assertValidDevice(device);
return caching_allocator.device_allocator[device]->getStats();
}
void resetAccumulatedStats(int device) {
assertValidDevice(device);
caching_allocator.device_allocator[device]->resetAccumulatedStats();
}
void resetPeakStats(int device) {
assertValidDevice(device);
caching_allocator.device_allocator[device]->resetPeakStats();
}
std::vector<SegmentInfo> snapshot() {
return caching_allocator.snapshot();
}
// CUDAGraph interactions
void notifyCaptureBegin(
int device,
CaptureId_t graph_id,
MempoolId_t mempool_id) {
assertValidDevice(device);
caching_allocator.device_allocator[device]->notifyCaptureBegin(
graph_id, mempool_id);
}
void notifyCaptureEnd(int device, CaptureId_t graph_id) {
assertValidDevice(device);
caching_allocator.device_allocator[device]->notifyCaptureEnd(graph_id);
}
void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
assertValidDevice(device);
caching_allocator.device_allocator[device]->notifyCaptureDestroy(mempool_id);
}
//
// In CUDA IPC, the sender sends a tensor to the receiver, and getIpcDevPtr
// is called by the receiving process to map the CUDA memory from the sending
// process into its own address space.
//
// CUDA IPC only allows sharing a big memory block associated with a
// cudaIpcMemHandle_t and it can be opened only **once** per context per
// process. There can be multiple types of storage in the same IPC mem block, so
// we must cache the device ptr to construct typed storage as it comes.
//
// ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in this
// process that can be used to access the memory block in the sender process.
// The map only stores a weak_ptr to the device pointer; the shared_ptr is
// used to reconstruct all storages in this cudaMalloc allocation, and its
// deleter calls cudaIpcCloseMemHandle when the reference count drops to 0.
//
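// A minimal usage sketch for the receiving side (hypothetical caller; the
// names `handle`, `base`, `data`, and `offset_bytes` are assumptions made
// for illustration and do not appear in this file):
//
//   // `handle` holds the raw bytes of a cudaIpcMemHandle_t from the sender.
//   std::shared_ptr<void> base = getIpcDevPtr(handle);
//   // A storage would then point at some offset inside the mapped block and
//   // keep a copy of `base`, so that the mapping stays open until the last
//   // such storage is released.
//   void* data = static_cast<char*>(base.get()) + offset_bytes;
//
// Calling getIpcDevPtr again with the same handle while `base` is still
// alive returns the cached pointer instead of reopening the handle.
//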
namespace {
std::mutex IpcMutex;
ska::flat_hash_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr;
} // namespace
std::shared_ptr<void> getIpcDevPtr(std::string handle) {
std::lock_guard<std::mutex> lock(IpcMutex);
auto iter = ipcMemHandle_to_devptr.find(handle);
if (iter != ipcMemHandle_to_devptr.end()) {
auto devptr = iter->second.lock();
if (devptr)
return devptr;
}
// This ipcMemHandle hasn't been opened yet, or it has already expired;
// open it to enable IPC access to that memory block.
void* dev = nullptr;
auto ipc_handle = reinterpret_cast<const cudaIpcMemHandle_t*>(handle.c_str());
C10_CUDA_CHECK(
cudaIpcOpenMemHandle(&dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess));
// devPtr has to be deleted on the same device on which it was created.
int curr_device;
C10_CUDA_CHECK(cudaGetDevice(&curr_device));
auto sp = std::shared_ptr<void>(dev, [handle, curr_device](void* ptr) {
cuda::CUDAGuard device_guard(curr_device);
std::lock_guard<std::mutex> deleter_lock(IpcMutex);
C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
ipcMemHandle_to_devptr.erase(handle);
});
std::weak_ptr<void> wp = sp;
// To eliminate an additional search, we can use insert().
// It doesn't overwrite when the key already exists (ptr expired),
// but since the deleter for sp erases the entry,
// this should be safe to do here.
ipcMemHandle_to_devptr.insert(iter, {handle, wp});
return sp;
}
void* raw_alloc(size_t nbytes) {
if (nbytes == 0) {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
caching_allocator.malloc(
&r, device, nbytes, cuda::getCurrentCUDAStream(device));
return r;
}
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
if (nbytes == 0) {
return nullptr;
}
int device;
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
caching_allocator.malloc(&r, device, nbytes, stream);
return r;
}
void raw_delete(void* ptr) {
caching_allocator.free(ptr);
}
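// Usage sketch for the raw interface above (illustrative only; the size is
// an arbitrary assumption and `p` is a hypothetical local):
//
//   void* p = raw_alloc(1024); // current device, current stream of that device
//   // ... enqueue work that uses p on that stream ...
//   raw_delete(p);             // hands the pointer back to the caching allocator
//
// Both raw_alloc and raw_alloc_with_stream return nullptr for a zero-byte
// request, so callers should not assume a non-null result.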
} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace c10