1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171
|
# Owner(s): ["module: optimizer"]
import functools
import math
import tempfile
import unittest
from copy import deepcopy
from typing import Any, Dict, Tuple
from unittest.mock import patch
from optim.test_lrscheduler import TestLRScheduler # noqa: F401
from optim.test_optim import TestDifferentiableOptimizer # noqa: F401
from optim.test_swa_utils import TestSWAUtils # noqa: F401
import torch
from torch.nn import Parameter
from torch.optim import Optimizer, SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.optimizer import (
register_optimizer_step_post_hook,
register_optimizer_step_pre_hook,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
largeTensorTest,
onlyCPU,
onlyCUDA,
onlyNativeDeviceTypes,
skipMPS,
TEST_WITH_ROCM,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_optimizers import (
_get_device_type,
_get_optim_inputs_including_global_cliquey_kwargs,
optim_db,
OptimizerErrorEnum,
optims,
TensorTracker,
)
from torch.testing._internal.common_utils import (
markDynamoStrictTest,
parametrize,
run_tests,
TEST_WITH_TORCHDYNAMO,
TestCase,
xfailIfS390X,
)
# Looser absolute/relative tolerances; by its name, presumably meant to be
# splatted into assertEqual-style comparisons of reduced-precision (fp16)
# results -- its users are outside this view, TODO confirm.
FP16_REDUCED_PRECISION = dict(atol=1e-5, rtol=1e-4)
def rosenbrock(tensor):
    """Evaluate the 2-D Rosenbrock function at ``tensor``.

    ``tensor`` must hold exactly two scalars ``(x, y)``. The value returned is
    ``(1 - x)**2 + 100 * (y - x**2)**2``, which is 0 at ``(1, 1)``.

    Raises:
        AssertionError: if ``tensor`` does not have size ``[2]``.
    """
    assert tensor.size() == torch.Size([2]), (
        f"Requires tensor with 2 scalars but got {tensor.size()}"
    )
    x, y = tensor
    valley_gap = y - x**2
    return (1 - x) ** 2 + 100 * valley_gap**2
def drosenbrock(tensor):
    """Return the analytic gradient of the 2-D Rosenbrock function.

    For ``tensor = (x, y)`` the result is the 2-element tensor
    ``(-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2))``.

    Raises:
        AssertionError: if ``tensor`` does not have size ``[2]``.
    """
    assert tensor.size() == torch.Size([2]), (
        f"Requires tensor with 2 scalars but got {tensor.size()}"
    )
    x, y = tensor
    valley_gap = y - x**2
    d_dx = -400 * x * valley_gap - 2 * (1 - x)
    d_dy = 200 * valley_gap
    return torch.stack((d_dx, d_dy))
@markDynamoStrictTest
class TestOptimRenewed(TestCase):
"""
This test class validates the core optimizers and is structured as the correctness of:
- The update algorithms (forloop implementation)
* Every optimizer's algorithm is most readably implemented through a big for-loop
over all the parameters, which is what we refer to as the forloop or single tensor
implementation. These algorithms are manually validated by comparing to the paper
and systematically validated by assuring that the loss goes the right direction
when the optimizer has been applied.
* This implementation should compose with optimizer hyperparameters well, such as
supporting Tensor LRs, the capturable API, and sparse and complex parameters.
- Each varying implementation
* We then have implementations that improve upon the performance of the forloop
implementation by leveraging fusion, namely our foreach (multi_tensor) and fused
implementations.
* These variations are validated numerically by comparing with the forloop version
of the optimizer. In fact, we test most variations this way--we see the forloop
implementation as the ground truth and expect that improvements to it in any way
should be just as correct.
* Both params and optimizer states should be validated numerically.
- state_dict APIs
* The optimizer instance should be serializable
* Calling save and load should be deterministic
* Moving between devices should be seamless
* BC - load_state_dict should be able to handle older optimizer states
- Hook APIs (everything should fire in the right order)
- LR Scheduler integration (composing should not error + should go the right direction)
- Parameter groups (should be equivalent to having multiple optimizers)
- Erroring (what should error should error)
We also cover different ways of generating parameters and grads:
- With parameters, we either generate them randomly given specific shapes or we take
them from a sample NN module.
* Variety is important here because NN modules have type Parameter and randomly
generated tensors have type Tensor.
* Parameters can be sparse for a subset of the optimizers (check out OptimizerInfo)
* Complex parameters should be handled using view_as_real
* Parameters can be spread across different devices and different dtypes for any
given optimizer
* Parameters can be contiguous and noncontiguous
- With grads, we follow suit from the parameters.
* Grads can also be None, empty, or zero-valued, and this should not disrupt training.
"""
@onlyCPU
@optims(optim_db)
def test_optim_infos_do_not_specify_global_cliquey_kwargs(
    self, device, dtype, optim_info
):
    """No OptimizerInfo input may set foreach/fused/differentiable itself.

    These flags appear to be layered on separately (other tests here pull
    them in via ``_get_optim_inputs_including_global_cliquey_kwargs``), so
    they must be absent from every input's kwargs.
    """
    cliquey_flags = ("foreach", "fused", "differentiable")
    for optim_input in optim_info.optim_inputs_func(device=device):
        self.assertFalse(any(flag in optim_input.kwargs for flag in cliquey_flags))
@optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
def test_errors(self, device, dtype, optim_info):
    """Each declared error input must raise (or warn) where it says it will.

    An error input targets either optimizer construction or the first
    ``step()``. ``error_type`` may be an exception or a ``Warning`` subclass,
    and the matching ``assert*Regex`` context is chosen accordingly.

    Raises:
        NotImplementedError: for an ``error_on`` value that is neither
            CONSTRUCTION_ERROR nor STEP_ERROR.
    """
    optim_cls = optim_info.optim_cls
    for case in optim_info.optim_error_inputs_func(device=device, dtype=dtype):
        optim_input = case.optimizer_error_input
        params, kwargs = optim_input.params, optim_input.kwargs
        fails_on_construction = case.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR
        fails_on_step = case.error_on == OptimizerErrorEnum.STEP_ERROR
        if not (fails_on_construction or fails_on_step):
            raise NotImplementedError(f"Unknown error type {case.error_on}")

        if fails_on_construction:
            expect = (
                self.assertWarnsRegex
                if issubclass(case.error_type, Warning)
                else self.assertRaisesRegex
            )
            with expect(case.error_type, case.error_regex):
                optim_cls(params, **kwargs)
        else:
            # Construction must succeed; only the step is expected to fail.
            optim = optim_cls(params, **kwargs)
            expect = (
                self.assertWarnsRegex
                if issubclass(case.error_type, Warning)
                else self.assertRaisesRegex
            )
            with expect(case.error_type, case.error_regex):
                optim.step()
@parametrize("contiguous", [True, False])
@parametrize("with_lrsched", [True, False])
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction(
    self, device, dtype, optim_info, contiguous, with_lrsched
):
    """The single-tensor (forloop) implementation should reduce the loss.

    A tiny linear model ``weight @ inpt + bias`` with a squared-sum loss is
    trained for 20 steps, optionally alongside LR schedulers. The loss must
    fall, or rise when ``maximize=True``. Params are either contiguous or
    strided views, depending on ``contiguous``.
    """
    optim_cls = optim_info.optim_cls
    scheduler_sets = optim_info.scheduler_inputs if with_lrsched else [None]

    def make_param(*shape):
        # Non-contiguous params are index 0 of a trailing size-2 dim, which
        # gives a strided view. The RNG is drawn once per param either way.
        if contiguous:
            return Parameter(torch.randn(shape, device=device, dtype=dtype))
        backing = torch.randn((*shape, 2), device=device, dtype=dtype)
        return Parameter(backing[..., 0])

    for scheduler_ctors in scheduler_sets:
        # Fresh inputs per scheduler set: a Tensor LR would otherwise carry
        # scheduler mutations over from the previous set.
        for optim_input in optim_info.optim_inputs_func(device=device):
            if "foreach" in optim_info.supported_impls:
                optim_input.kwargs["foreach"] = False  # pin the forloop impl
            weight = make_param(10, 5)
            bias = make_param(10)
            inpt = torch.randn(5, device=device, dtype=dtype)
            optimizer = optim_cls([weight, bias], **optim_input.kwargs)
            schedulers = [ctor(optimizer) for ctor in scheduler_ctors or []]

            def closure():
                optimizer.zero_grad()
                loss = (weight.mv(inpt) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # Naive layout conversion: NOT the intended use case of
                    # sparse-only optimizers such as SparseAdam.
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                return loss

            initial_value = closure().item()
            for _ in range(20):
                if optim_info.step_requires_closure:
                    loss = optimizer.step(closure)
                else:
                    loss = closure()
                    optimizer.step()
                for scheduler in schedulers:
                    # ReduceLROnPlateau consumes the loss; the rest take no args.
                    step_args = (
                        (loss,) if isinstance(scheduler, ReduceLROnPlateau) else ()
                    )
                    scheduler.step(*step_args)

            final_value = closure().item()
            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(final_value, initial_value)
            else:
                self.assertLess(final_value, initial_value)
@onlyCUDA
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@parametrize("with_lrsched", [True, False])
@optims(optim_db, dtypes=[torch.float32])
def test_forloop_goes_right_direction_multigpu(
    self, device, dtype, optim_info, with_lrsched
):
    """Forloop updates should reduce the loss when params span two GPUs.

    ``weight`` and the input live on cuda:0 and ``bias`` lives on cuda:1. The
    model is trained for 20 closure-driven steps, and the loss must fall, or
    rise when ``maximize=True``.
    """
    optim_cls = optim_info.optim_cls
    scheduler_sets = optim_info.scheduler_inputs if with_lrsched else [None]
    for scheduler_ctors in scheduler_sets:
        # Fresh inputs per scheduler set so a Tensor LR's mutations do not
        # leak across iterations.
        for optim_input in optim_info.optim_inputs_func(device=device):
            if "foreach" in optim_info.supported_impls:
                optim_input.kwargs["foreach"] = False  # pin the forloop impl
            weight = Parameter(torch.randn((10, 5), device="cuda:0", dtype=dtype))
            bias = Parameter(torch.randn(10, device="cuda:1", dtype=dtype))
            inpt = torch.randn(5, device="cuda:0", dtype=dtype)
            optimizer = optim_cls([weight, bias], **optim_input.kwargs)
            schedulers = [ctor(optimizer) for ctor in scheduler_ctors or []]

            def closure():
                optimizer.zero_grad()
                # The mat-vec runs on cuda:0; move it to cuda:1 to meet bias.
                loss = (weight.mv(inpt).cuda(1) + bias).pow(2).sum()
                loss.backward()
                if optim_info.only_supports_sparse_grads:
                    # Naive layout conversion: NOT the intended use case of
                    # sparse-only optimizers such as SparseAdam.
                    weight.grad = weight.grad.to_sparse()
                    bias.grad = bias.grad.to_sparse()
                return loss

            initial_value = closure().item()
            for _ in range(20):
                loss = optimizer.step(closure)
                for scheduler in schedulers:
                    step_args = (
                        (loss,) if isinstance(scheduler, ReduceLROnPlateau) else ()
                    )
                    scheduler.step(*step_args)

            final_value = closure().item()
            if optim_input.kwargs.get("maximize", False):
                self.assertGreater(final_value, initial_value)
            else:
                self.assertLess(final_value, initial_value)
@optims(optim_db, dtypes=[torch.float32])
def test_param_group_with_lrscheduler_goes_right_direction(
    self, device, dtype, optim_info
):
    """Two param groups (one with its own LR) plus schedulers reduce the loss.

    For each scheduler set in ``optim_info.scheduler_inputs``, the model is
    trained for 20 closure-driven steps with default optimizer kwargs, and the
    final loss must be below the initial one.
    """
    optim_cls = optim_info.optim_cls
    for scheduler_ctors in optim_info.scheduler_inputs:
        weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
        bias = Parameter(torch.randn(10, device=device, dtype=dtype))
        inpt = torch.randn(5, device=device, dtype=dtype)
        # When compiling, the LR is wrapped in a tensor to avoid endless
        # recompiles.
        if torch.compiler.is_compiling():
            bias_lr = torch.tensor(0.01)
        else:
            bias_lr = 0.01
        param_groups = [{"params": [weight]}, {"params": [bias], "lr": bias_lr}]
        optimizer = optim_cls(param_groups)
        schedulers = [ctor(optimizer) for ctor in scheduler_ctors]

        def closure():
            optimizer.zero_grad()
            loss = (weight.mv(inpt) + bias).pow(2).sum()
            loss.backward()
            if optim_info.only_supports_sparse_grads:
                # Naive layout conversion: NOT the intended use case of
                # sparse-only optimizers such as SparseAdam.
                weight.grad = weight.grad.to_sparse()
                bias.grad = bias.grad.to_sparse()
            return loss

        initial_value = closure().item()
        for _ in range(20):
            loss = optimizer.step(closure)
            for scheduler in schedulers:
                step_args = (loss,) if isinstance(scheduler, ReduceLROnPlateau) else ()
                scheduler.step(*step_args)
        self.assertLess(closure().item(), initial_value)
@optims(optim_db, dtypes=[torch.float32])
def test_tensor_lr(self, device, dtype, optim_info):
    """A Tensor LR must follow the same trajectory as the equal float LR.

    For every input (including the global foreach/fused variants, minus
    differentiable), two identical copies of a tiny linear model are trained
    in lockstep. One uses ``lr`` as a Python float and the other uses the same
    value wrapped in ``torch.tensor``. Params must match after every step.
    Configurations that reject a Tensor LR must raise ValueError with the
    expected message.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
    for optim_input in all_optim_inputs:
        weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
        weight_c = weight.detach().clone().requires_grad_(True)
        bias = Parameter(torch.randn((10), device=device, dtype=dtype))
        bias_c = bias.detach().clone().requires_grad_(True)
        inpt = torch.randn(5, device=device, dtype=dtype)
        kwargs = optim_input.kwargs
        # Overwrite any configured LR: closure-driven optimizers get 1.0,
        # all others 1e-3.
        kwargs["lr"] = 1.0 if optim_info.step_requires_closure else 1e-3
        optimizer_r = optim_cls([weight, bias], **kwargs)
        kwargs["lr"] = torch.tensor(kwargs["lr"])
        try:
            optimizer = optim_cls([weight_c, bias_c], **kwargs)
        except ValueError as e:
            self.assertRegex(str(e), ".*lr as a Tensor is not supported.*")
            continue

        def closure(optim, w, b, i):
            optim.zero_grad()
            loss = (w.mv(i) + b).pow(2).sum()
            loss.backward()
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                w.grad = w.grad.to_sparse()
                b.grad = b.grad.to_sparse()
            return loss

        for _ in range(5):
            if optim_info.step_requires_closure:
                optimizer_r.step(
                    functools.partial(closure, optimizer_r, weight, bias, inpt)
                )
                optimizer.step(
                    functools.partial(closure, optimizer, weight_c, bias_c, inpt)
                )
            else:
                # The closure only computes grads; both optimizers must also be
                # stepped, or the params never move and the comparison below
                # is vacuous.
                closure(optimizer_r, weight, bias, inpt)
                optimizer_r.step()
                closure(optimizer, weight_c, bias_c, inpt)
                optimizer.step()
            self.assertEqual(weight, weight_c)
            self.assertEqual(bias, bias_c)
@parametrize("with_lrsched", [True, False])
@optims(
    [o for o in optim_db if o.supports_sparse or o.only_supports_sparse_grads],
    dtypes=[torch.float64],
)
def test_rosenbrock_sparse(self, device, dtype, optim_info, with_lrsched):
    """Cyclic coordinate descent on Rosenbrock with sparse gradients.

    Each of 1800 steps feeds the optimizer a deliberately uncoalesced sparse
    gradient that covers only x (odd steps) or only y (even steps). For
    optimizers that also accept dense grads, an identical dense run is
    stepped in lockstep and must produce the same params. At the end the
    sparse run must not be farther from the minimum (1, 1) than at the start.
    With ``maximize=True`` it must instead not have lowered the loss.

    ``optim_info.metadata_for_sparse`` is ``(kwarg_updates,
    schedulers_constructors)``. Keys of ``kwarg_updates`` are stripped from
    every input, and the remaining non-empty kwargs are deduplicated. If
    ``kwarg_updates`` has an ``"lr"``, it is applied to the surviving inputs.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    # Fused impls do not support sparse gradients
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable", "fused")
    )
    kwarg_updates, schedulers_constructors = optim_info.metadata_for_sparse
    if with_lrsched and len(schedulers_constructors) == 0:
        return
    supported_inputs = []
    if len(kwarg_updates) != 0:
        seen = set()
        for candidate in all_optim_inputs:
            for k in kwarg_updates:
                if k in candidate.kwargs:
                    del candidate.kwargs[k]
            hashable_kwargs = tuple(sorted(candidate.kwargs.items()))
            if len(candidate.kwargs) > 0 and hashable_kwargs not in seen:
                supported_inputs.append(candidate)
                seen.add(hashable_kwargs)
                if "lr" in kwarg_updates:
                    candidate.kwargs["lr"] = kwarg_updates["lr"]
    else:
        supported_inputs = all_optim_inputs

    for optim_input in supported_inputs:
        kwargs = optim_input.kwargs
        multi_tensor = kwargs.get("foreach", False)
        # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
        if multi_tensor:
            params_t = [
                torch.tensor([1.5, 1.5]),
                torch.tensor([1.5, 1.5], dtype=dtype),
            ]
        else:
            params_t = [torch.tensor([1.5, 1.5])]
        params = [Parameter(param_t) for param_t in params_t]
        optimizer = optim_cls(params, **kwargs)
        schedulers = [
            s(optimizer) for s in (schedulers_constructors if with_lrsched else [])
        ]
        if not optim_info.only_supports_sparse_grads:
            params_c = [Parameter(param_t.clone()) for param_t in params_t]
            optimizer_c = optim_cls(params_c, **kwargs)
            schedulers_c = [
                s(optimizer_c)
                for s in (schedulers_constructors if with_lrsched else [])
            ]

        solution = torch.tensor([1, 1])
        with torch.no_grad():
            initial_dist = sum(param.dist(solution) for param in params)
            # Parameter(t) shares t's storage, so the optimizer's in-place
            # updates also change params_t. Snapshot the starting loss now,
            # before any step, for the maximize check at the end.
            initial_loss = sum(rosenbrock(param) for param in params)

        def get_grad(param, sparse_grad, w):
            grad = drosenbrock(param)
            # NB: We torture test the optimizer by returning an
            # uncoalesced sparse tensor
            # Depending on w, provide only the x or y gradient
            if sparse_grad:
                if w:
                    indices = torch.tensor([[0, 0]], dtype=torch.int64)
                    x = grad[0]
                    v = torch.tensor([x / 4.0, x - x / 4.0])
                else:
                    indices = torch.tensor([[1, 1]], dtype=torch.int64)
                    y = grad[1]
                    v = torch.tensor([y - y / 4.0, y / 4.0])
                grad_out = torch.sparse_coo_tensor(indices, v, (2,), dtype=v.dtype)
            else:
                if w:
                    grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
                else:
                    grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
            return grad_out

        def eval_closure(params, sparse_grad, w):
            # zero_grad on `optimizer` is harmless for params_c too: every
            # .grad is overwritten with the hand-built gradient below.
            optimizer.zero_grad()
            if multi_tensor:
                loss = sum(rosenbrock(param) for param in params)
            else:
                loss = rosenbrock(params[0])
            loss.backward()
            grads_out = [get_grad(param, sparse_grad, w) for param in params]
            with torch.no_grad():
                params[0].grad = grads_out[0]
                if multi_tensor:
                    params[1].grad = grads_out[1].to(dtype=dtype)
            return loss

        for step in range(1800):
            # Do cyclic coordinate descent
            w = step % 2
            optimizer.step(functools.partial(eval_closure, params, True, w))
            for scheduler in schedulers:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(rosenbrock(params[0]))
                else:
                    scheduler.step()
            if not optim_info.only_supports_sparse_grads:
                optimizer_c.step(functools.partial(eval_closure, params_c, False, w))
                for scheduler in schedulers_c:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(rosenbrock(params_c[0]))
                    else:
                        scheduler.step()
                # Tolerance is increased due to floating point error from different
                # code path for dense case: x v.s. x - x / 4.0 + x / 4.0
                self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)

        if not kwargs.get("maximize", False):
            self.assertLessEqual(
                sum(param.dist(solution) for param in params), initial_dist
            )
        else:
            self.assertGreaterEqual(
                sum(rosenbrock(param) for param in params), initial_loss
            )
@skipMPS
@optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
def test_complex(self, device, dtype, optim_info):
    """Complex params must be optimized exactly like their view_as_real twins.

    One optimizer gets complex params (plus one real param). A second gets
    ``view_as_real`` copies. Both are fed identical gradients and losses for
    3 steps, and the final params and every intermediate param snapshot
    (including any taken inside e.g. a line search) must agree.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    # Also skip fused, since our fused kernels do not support complex
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable", "fused")
    )
    for optim_input in all_optim_inputs:
        # The final param is deliberately real to check mixing real and complex.
        complex_params = [
            torch.randn(10, 5, device=device, dtype=dtype, requires_grad=True),
            torch.randn(10, device=device, dtype=dtype, requires_grad=True),
            torch.randn(10, 5, device=device, dtype=torch.float32, requires_grad=True),
        ]
        real_params = []
        for param in complex_params:
            source = torch.view_as_real(param) if param.is_complex() else param
            real_params.append(source.detach().clone().requires_grad_())
        complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
        real_optimizer = optim_cls(real_params, **optim_input.kwargs)
        real_steps, complex_steps = [], []
        # FIFO handoff: real_closure appends each param's grad and then the
        # loss, and complex_closure pops them off in that same order.
        handoff = []

        def real_closure():
            for param in real_params:
                grad = torch.randn_like(param)
                param.grad = grad
                real_steps.append(param.detach().clone())
                handoff.append(grad.clone())
            loss = torch.randn(1)
            handoff.append(loss.clone())
            return loss

        def complex_closure():
            for param in complex_params:
                if not torch.is_complex(param):
                    grad = handoff.pop(0)
                    complex_steps.append(param.detach().clone())
                else:
                    grad = torch.view_as_complex(handoff.pop(0))
                    complex_steps.append(torch.view_as_real_copy(param.detach()))
                param.grad = grad
            return handoff.pop(0)

        for _ in range(3):
            if optim_info.step_requires_closure:
                # e.g. LBFGS, which invokes the closure itself
                real_optimizer.step(real_closure)
                complex_optimizer.step(complex_closure)
            else:
                # Populate the grads explicitly, then step.
                real_closure()
                complex_closure()
                real_optimizer.step()
                complex_optimizer.step()

        # The final params must agree ...
        complex_params_asreal = [
            torch.view_as_real(param) if param.is_complex() else param
            for param in complex_params
        ]
        self.assertEqual(real_params, complex_params_asreal)
        # ... and so must every snapshot taken inside the closures.
        self.assertEqual(complex_steps, real_steps)
@skipMPS
@xfailIfS390X
@optims([o for o in optim_db if o.supports_complex], dtypes=[torch.complex64])
def test_complex_2d(self, device, dtype, optim_info):
    """Optimizing a complex tensor must match optimizing its real/imag parts.

    ``optim1`` updates a complex 2-element tensor ``a1``, and ``optim2``
    updates two real tensors holding ``a1.real`` and ``a1.imag``. closure1
    records values with TensorTracker, and closure2 pops and checks them in
    the same order via ``pop_check_set`` (defined in common_optimizers).
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    # Also skip fused, since our fused kernels do not support complex
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable", "fused")
    )
    for optim_input in all_optim_inputs:
        if optim_info.step_requires_closure:
            # Complex params are handled internally as view_as_real, i.e. an
            # (M, N) tensor becomes (M, N, 2), whereas this test splits it
            # into two (M, N) tensors. Pointwise optimizers don't care, but
            # closure-driven ones like LBFGS flatten all grads and reduce
            # over them. Floating-point addition is not associative, so that
            # ordering changes the result. A fixed seed keeps the discrepancy
            # controlled. test_complex needs neither a seed nor extra
            # tolerance, since its results should be bitwise equal.
            torch.manual_seed(2024)
        a1 = torch.randn(2, device=device, dtype=dtype, requires_grad=True)
        a1_real = a1.real.detach().clone().requires_grad_()
        a1_imag = a1.imag.detach().clone().requires_grad_()
        optim1 = optim_cls([a1], **optim_input.kwargs)
        optim2 = optim_cls([a1_real, a1_imag], **optim_input.kwargs)
        a1_reals, a1_imags = TensorTracker(), TensorTracker()
        a1_grad_reals, a1_grad_imags = TensorTracker(), TensorTracker()
        losses = TensorTracker()

        def closure1():
            optim1.zero_grad()
            loss = rosenbrock(a1).abs()
            loss.backward()
            # Record every intermediate so closure2 can check it in order.
            a1_reals.add(a1.real)
            a1_imags.add(a1.imag)
            a1_grad_reals.add(a1.grad.real)
            a1_grad_imags.add(a1.grad.imag)
            losses.add(loss)
            return loss

        def closure2():
            optim2.zero_grad()
            a1_reals.pop_check_set(a1_real, self)
            a1_imags.pop_check_set(a1_imag, self)
            recombined = torch.complex(a1_real, a1_imag)
            loss = rosenbrock(recombined).abs()
            losses.pop_check_set(loss, self)
            loss.backward()
            a1_grad_reals.pop_check_set(a1_real.grad, self)
            a1_grad_imags.pop_check_set(a1_imag.grad, self)
            return loss

        for _ in range(3):
            if not optim_info.step_requires_closure:
                closure1()
                closure2()
                optim1.step()
                optim2.step()
            else:
                # e.g. LBFGS, which invokes the closure itself
                optim1.step(closure1)
                optim2.step(closure2)
            self.assertEqual(a1.real, a1_real)
            self.assertEqual(a1.imag, a1_imag)

        for tracker in (a1_reals, a1_imags, a1_grad_reals, a1_grad_imags, losses):
            self.assertTrue(tracker.all_popped())
def _compare_between(
    self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None
):
    """
    Step two (model, optimizer) pairs in lockstep and check that they agree.

    ``inputs`` is either one input shared by both models or a list holding one
    input per model. After every iteration the params of the two models and every
    entry of their per-param optimizer state are compared through a
    ``TensorTracker`` built with ``assert_eq_kwargs``. When ``assert_step_dtype``
    is given, any tensor-valued ``"step"`` state entry must have that dtype.
    """
    tracker = TensorTracker({} if assert_eq_kwargs is None else assert_eq_kwargs)
    # A single shared input is duplicated so that each model receives one.
    per_model_inputs = inputs if isinstance(inputs, list) else [inputs, inputs]
    # why 7? iteration 7 is where we start to see differences for RAdam
    # params interacting with the small eps value, because that's right
    # after rho_t becomes greater than 5 in step 6.
    num_iters = 7
    for it in range(num_iters):
        states, param_iters = [], []
        for inp, model, opt in zip(per_model_inputs, models, optimizers):
            opt.zero_grad()
            if it == 3:
                # Freeze a layer to test if the step of this layer in 'fused' or
                # 'foreach' is same as the step in 'forloop'.
                model[2].requires_grad_(False)
            elif it == 5:
                # Unfreeze the layer after 2 iters.
                model[2].requires_grad_(True)
            # On iteration 2 no backward runs, so grads stay None after
            # zero_grad: this checks that step behaves as a no-op then.
            if it != 2:
                model(inp).sum().backward()
            opt.step()
            states.append(opt.state)
            param_iters.append(model.parameters())
        ref_state, test_state = states
        for ref_p, test_p in zip(*param_iters):
            tracker.add(ref_p)
            tracker.pop_check_set(test_p, self)
            # check that optimizer states are the same
            ref_p_state = ref_state[ref_p]
            test_p_state = test_state[test_p]
            if assert_step_dtype is not None:
                for p_state in (ref_p_state, test_p_state):
                    step_val = p_state.get("step", None)
                    if torch.is_tensor(step_val):
                        self.assertEqual(step_val.dtype, assert_step_dtype)
            for key in ref_p_state:
                tracker.add(ref_p_state[key])
                tracker.pop_check_set(test_p_state[key], self)
    self.assertTrue(tracker.all_popped())
def _test_derived_optimizers(
    self,
    device,
    dtype,
    optim_info,
    flag,
    reduced_precision=False,
    assert_step_dtype=None,
):
    """
    Check parity between the for-loop implementation and the one chosen by ``flag``.

    ``flag`` is "foreach" or "fused". For each optimizer config from ``optim_info``,
    two identical small models (seeded with ``torch.manual_seed(1)``) are built.
    One is optimized with ``flag=False`` and the other with ``flag=True``, and
    ``_compare_between`` checks that their params and optimizer states match.
    ``reduced_precision`` switches the comparison to ``FP16_REDUCED_PRECISION``
    tolerances. ``assert_step_dtype`` is forwarded to ``_compare_between``.
    """
    assert flag in ("foreach", "fused")
    tolerance_kwargs = FP16_REDUCED_PRECISION if reduced_precision else {}
    optim_cls = optim_info.optim_cls
    for optim_input in optim_info.optim_inputs_func(device=device, dtype=dtype):
        models, optimizers = [], []
        config = deepcopy(optim_input.kwargs)
        if config.get("capturable", False) and _get_device_type(device) == "cpu":
            # capturable is not supported on CPU
            continue
        for enabled in (False, True):
            config[flag] = enabled
            sample = torch.tensor(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
            ).reshape(3, 2)
            torch.manual_seed(1)
            net = torch.nn.Sequential(
                torch.nn.Linear(2, 3),
                torch.nn.Sigmoid(),
                torch.nn.Linear(3, 1),
                torch.nn.Sigmoid(),
            )
            net.to(dtype=dtype, device=device)
            # foreach/fused optimizers should be tested with a zero_size tensor
            # as their last param, ref: https://github.com/pytorch/pytorch/issues/100701
            # NOTE(review): torch.empty(()) is 0-dim with one element rather than
            # zero-sized; confirm which case the issue needs.
            trailing = torch.empty((), device=device, dtype=dtype, requires_grad=True)
            trailing.grad = torch.rand_like(trailing)
            net_optimizer = optim_cls([*net.parameters(), trailing], **config)
            models.append(net)
            optimizers.append(net_optimizer)
        # `sample` is identical for both flag values, so the last one is passed.
        self._compare_between(
            sample, models, optimizers, tolerance_kwargs, assert_step_dtype
        )
@skipMPS  # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
@optims(
    [info for info in optim_db if "foreach" in info.supported_impls],
    dtypes=[torch.float64],
)
def test_foreach_matches_forloop(self, device, dtype, optim_info):
    """foreach=True must yield the same params and state as the for-loop path."""
    self._test_derived_optimizers(device, dtype, optim_info, flag="foreach")
@onlyCUDA
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
@parametrize("impl", ["foreach", "fused"])
@optims(
    [
        optim
        for optim in optim_db
        if "foreach" in optim.supported_impls or "fused" in optim.supported_impls
    ]
)
def test_mixed_device_dtype(self, device, dtype, optim_info, impl):
    """
    Similar in essence to _test_derived_optimizers above. The main difference is that
    _test_derived_optimizers uses model parameters whereas we randomly pass in
    parameters of different dtypes and devices here. We need multiple GPUs (vs just a
    CPU and GPU) because fused adam only works on GPUs. (Thus we only run the tests
    that call into this helper when TEST_MULTIGPU.)

    The params span float64/float32/float16/bfloat16 on cuda:0 and cuda:1, plus an
    int64 tensor without grad. For each optimizer config, clones of them are stepped
    ``kIterations`` times with ``impl`` disabled and then enabled. The resulting
    params and per-param states are compared with tolerances 5x the single-op
    defaults for each dtype.
    """
    assert impl in ("foreach", "fused")
    optim_name = optim_info.optim_cls.__name__
    # self.skipTest raises, so the runner reports a skip. Returning
    # unittest.skip(...) from a test body skips nothing: the test simply returns
    # early and is reported as passed.
    if impl == "foreach" and "foreach" not in optim_info.supported_impls:
        self.skipTest(f"foreach not supported for {optim_name}")
    elif impl == "fused" and "cuda" not in optim_info.supports_fused_on:
        self.skipTest(f"fused not supported for {optim_name} on cuda")
    params = [
        torch.rand(2, 3, dtype=torch.float64, device="cuda:0", requires_grad=True),
        torch.rand(2, 3, dtype=torch.float32, device="cuda:0", requires_grad=True),
        torch.rand(2, 3, dtype=torch.float16, device="cuda:0", requires_grad=True),
        torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:0", requires_grad=True),
        torch.rand(2, 3, dtype=torch.float64, device="cuda:1", requires_grad=True),
        torch.rand(2, 3, dtype=torch.float32, device="cuda:1", requires_grad=True),
        torch.rand(2, 3, dtype=torch.float16, device="cuda:1", requires_grad=True),
        torch.rand(2, 3, dtype=torch.bfloat16, device="cuda:1", requires_grad=True),
        torch.randint(
            1024, (2, 3), dtype=torch.int64, device="cuda:1", requires_grad=False
        ),
    ]
    for p in params:
        if p.requires_grad:
            # rand_like already inherits p's device and dtype.
            p.grad = torch.rand_like(p)
    kIterations = 7 if impl == "foreach" else 1
    optim_inputs = optim_info.optim_inputs_func(device=device)
    optim_cls = optim_info.optim_cls
    for optim_input in optim_inputs:
        updated_params, state = [], []
        kwargs = deepcopy(optim_input.kwargs)
        if kwargs.get("capturable", False) and _get_device_type(device) == "cpu":
            # capturable is not supported on CPU
            continue
        for use_impl in (False, True):
            kwargs[impl] = use_impl
            # Each run gets fresh copies of the shared params and their grads.
            params_clone = []
            for p in params:
                p_clone = p.detach().clone()
                if p.requires_grad:
                    p_clone.requires_grad = True
                    p_clone.grad = p.grad.detach().clone()
                params_clone.append(p_clone)
            optimizer = optim_cls(params_clone, **kwargs)
            for _ in range(kIterations):
                optimizer.step()
            state.append(optimizer.state)
            updated_params.append(params_clone)
        og_state, new_state = state
        for og_p, new_p in zip(updated_params[0], updated_params[1]):
            # Increasing the tolerance as we are collating lots of ops together for
            # optimizers and the designated tolerances are for single op only.
            single_rtol, single_atol = torch.testing._comparison.get_tolerances(
                new_p.dtype, rtol=None, atol=None
            )
            rtol = 5 * single_rtol
            atol = 5 * single_atol
            self.assertEqual(og_p, new_p, rtol=rtol, atol=atol)
            # check that optimizer states are the same
            og_p_state = og_state[og_p]
            new_p_state = new_state[new_p]
            for k in og_p_state:
                actual = new_p_state[k]
                self.assertEqual(og_p_state[k], actual, rtol=rtol, atol=atol)
@onlyCUDA
@optims(
    [info for info in optim_db if "foreach" in info.supported_impls],
    dtypes=[torch.float64],
)
def test_set_default_dtype_works_with_foreach(self, device, dtype, optim_info):
    """
    foreach must keep parity with the for-loop path under a non-default dtype.

    See https://github.com/pytorch/pytorch/issues/110940. We coerce step to always
    be float32 unless the default dtype is higher prec float64. The default dtype is
    restored after every sub-run, even when it fails.
    """
    # default dtype under test -> expected dtype of tensor `step` state
    expected_step_dtype = {
        torch.float64: torch.float64,
        torch.float16: torch.float32,
    }
    saved_default = torch.get_default_dtype()
    for default_dtype, step_dtype in expected_step_dtype.items():
        try:
            torch.set_default_dtype(default_dtype)
            self._test_derived_optimizers(
                device,
                dtype,
                optim_info,
                "foreach",
                reduced_precision=(default_dtype == torch.float16),
                assert_step_dtype=step_dtype,
            )
        finally:
            torch.set_default_dtype(saved_default)
@onlyCUDA
@largeTensorTest("72GB", "cuda")
@optims(
    [info for info in optim_db if "foreach" in info.supported_impls],
    dtypes=[torch.float16],
)
def test_foreach_large_tensor(self, device, dtype, optim_info):
    """Smoke-test one foreach step on a single param with 2**32 elements."""
    for config in optim_info.optim_inputs_func(device=device):
        big_param = torch.ones(2**32, device=device, dtype=dtype)
        big_param.grad = torch.zeros_like(big_param)
        optimizer = optim_info.optim_cls([big_param], foreach=True, **config.kwargs)
        optimizer.step()
@onlyCUDA
@optims(
    [optim for optim in optim_db if "foreach" in optim.supported_impls],
    dtypes=[torch.float32],
)
def test_peak_memory_foreach(self, device, dtype, optim_info):
    """
    Bound the extra peak CUDA memory that foreach=True uses over foreach=False.

    For each optimizer config, ``nparams`` params are stepped once. Peak memory
    stats are then reset, and the peak of a second step is recorded with
    foreach=False (``st_max_mem``) and foreach=True (``mt_max_mem``). The foreach
    peak may exceed the for-loop peak by at most ``nintermediates`` copies of the
    whole param list. The per-optimizer counts are explained in the comments below.
    """
    # Imported once per test instead of on every inner-loop iteration.
    import gc

    nparams = 10
    optim_inputs = optim_info.optim_inputs_func(device=device)
    optim_cls = optim_info.optim_cls
    for optim_input in optim_inputs:
        kwargs = deepcopy(optim_input.kwargs)
        max_mems = []
        for flag_value in (False, True):
            kwargs["foreach"] = flag_value
            # The 16 * 8 = 128 is critical here! Our CUDACachingAllocator allocates in
            # blocks of 512, meaning any tensor that occupies <512 bytes of memory will
            # allocate a whole 512 bytes anyway. We use 128 (cuz datasize would be 4
            # bytes) so that param is size 512 exactly, making our later calculations
            # for intermediate_size easy.
            param = torch.rand(16, 8, device=device, dtype=dtype)
            params = [torch.rand_like(param) for _ in range(nparams)]
            optimizer = optim_cls(params, **kwargs)
            for p in params:
                p.grad = torch.rand_like(p)
            # Warm-up step. Peak stats are reset after it, so only the second
            # step's peak is recorded.
            optimizer.step()
            gc.collect()
            torch.cuda.reset_peak_memory_stats()
            optimizer.step()
            gc.collect()
            max_mems.append(torch.cuda.max_memory_allocated())
        st_max_mem, mt_max_mem = max_mems
        intermediate_size = nparams * param.nelement() * param.element_size()
        nintermediates = 1  # we expect a budget of 1 intermediate most of the time
        # Check the param group directly to handle if the compiler set capturable
        if optimizer.param_groups[0].get(
            "capturable", False
        ) or optim_cls.__name__ in ["Adadelta", "ASGD", "RAdam"]:
            # with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
            # with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
            # ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
            nintermediates = 3
            if optim_cls.__name__ == "NAdam":
                # with capturable in NAdam, we have 3 extra intermediates for the
                # bias_correction, mus, and mu_nexts
                if TEST_WITH_TORCHDYNAMO:
                    # With dynamo, the eager/FX backend appears to hold memory longer than
                    # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
                    nintermediates = 8
                else:
                    nintermediates = 5
            if optim_cls.__name__ == "RAdam":
                # RAdam has four intermediates (num, unrect_step_size, buffer,
                # grouped_grads); this branch is entered for RAdam whether or not
                # capturable is set.
                if TEST_WITH_TORCHDYNAMO:
                    # With dynamo, the eager/FX backend appears to hold memory longer than
                    # vanilla eager: https://github.com/pytorch/pytorch/issues/125511
                    nintermediates = 6
                else:
                    nintermediates = 4
        elif optim_cls.__name__ in ["NAdam", "Adagrad", "RMSprop", "Adafactor"]:
            # NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
            # Adagrad uses std and grads at the same time
            # RMSprop uses avg and grads
            # Adafactor uses row/col var and its mean
            nintermediates = 2
            if optim_cls.__name__ == "Adafactor" and kwargs.get("maximize", False):
                # When maximize is True, Adafactor also tracks device_grad
                nintermediates = 3
        # Dynamo ST uses less mem than eager in the case of Adam/Adagrad/Nadam/RAdam
        # which makes the foreach memory check fail
        if TEST_WITH_TORCHDYNAMO:
            st_max_mem += 6000
        expected_max_mem = st_max_mem + intermediate_size * nintermediates
        # hipcc currently can't generate efficient code for the small buffer optimization
        # code path (see Note [small buffer optimization] for details), thus we always
        # dynamically allocate the tensor metadata for ROCM. Adjusting the expected max
        # memory usage to account for this.
        if TEST_WITH_ROCM:
            expected_max_mem *= 1.02
        self.assertLessEqual(mt_max_mem, expected_max_mem)
@optims(
    [info for info in optim_db if "fused" in info.supported_impls],
    dtypes=floating_types_and(
        torch.bfloat16,
        torch.float16,
    ),
)
def test_fused_matches_forloop(self, device, dtype, optim_info):
    """fused=True must yield the same params and state as the for-loop path."""
    device_type = _get_device_type(device)
    if device_type not in optim_info.supports_fused_on:
        self.skipTest(
            f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
        )
    mps_fused_dtypes = (torch.float16, torch.float32, torch.bfloat16)
    if device_type == "mps" and dtype not in mps_fused_dtypes:
        self.skipTest(
            "MPS supports only torch.float16, torch.float32 and torch.bfloat16"
        )
    self._test_derived_optimizers(device, dtype, optim_info, flag="fused")
@optims(
    [info for info in optim_db if "fused" in info.supported_impls],
    dtypes=(torch.float32,),
)
def test_fused_error_on_params_on_meta(self, device, dtype, optim_info):
    """
    A fused optimizer over meta-device params must raise on step().

    Once the module is materialized on ``device`` with ``to_empty``, the same
    optimizer must step without error.
    """
    if _get_device_type(device) not in optim_info.supports_fused_on:
        self.skipTest(
            f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
        )

    def fill_random_grads(module):
        # Give every parameter of `module` a random grad shaped like it.
        for weight in module.parameters():
            weight.grad = torch.rand_like(weight)

    with torch.device("meta"):
        net = torch.nn.Sequential(
            torch.nn.Linear(2, 3),
            torch.nn.Sigmoid(),
            torch.nn.Linear(3, 1),
            torch.nn.Sigmoid(),
        ).to(dtype)
    optimizer = optim_info.optim_cls(net.parameters(), fused=True)
    with torch.device("meta"):
        fill_random_grads(net)
    with self.assertRaisesRegex(
        RuntimeError,
        "`fused=True` requires all the params to be floating point Tensors",
    ):
        optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    net.to_empty(device=device)
    fill_random_grads(net)
    optimizer.step()
@onlyNativeDeviceTypes
@largeTensorTest("64GB")
@optims(
    [optim for optim in optim_db if "fused" in optim.supported_impls],
    dtypes=[torch.float16],
)
def test_fused_large_tensor(self, device, dtype, optim_info):
    """Smoke-test one fused step on a single param with 2**32 elements."""
    # supports_fused_on lists device types (e.g. "cuda"). Normalize `device` with
    # _get_device_type, as the other fused tests do: comparing the raw device
    # string would skip whenever it carries an index such as "cuda:0".
    if _get_device_type(device) not in optim_info.supports_fused_on:
        self.skipTest(
            f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
        )
    optim_cls = optim_info.optim_cls
    optim_inputs = optim_info.optim_inputs_func(device=device)
    for optim_input in optim_inputs:
        params = [torch.ones(2**32, device=device, dtype=dtype)]
        params[0].grad = torch.zeros_like(params[0])
        optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
        optimizer.step()
@onlyCUDA
@optims(
    [optim for optim in optim_db if "fused" in optim.supported_impls],
    dtypes=[torch.float32],
)
def test_fused_does_not_step_if_foundinf(self, device, dtype, optim_info):
    """
    A fused step with ``found_inf`` set must leave params unchanged.

    This is checked both with and without a ``grad_scale``. Any ``step`` state
    entry must also remain zero.
    """
    # supports_fused_on lists device types (e.g. "cuda"). Normalize `device` with
    # _get_device_type, as the other fused tests do: comparing the raw device
    # string would skip whenever it carries an index such as "cuda:0".
    if _get_device_type(device) not in optim_info.supports_fused_on:
        self.skipTest(
            f"{device} is not supported for fused on {optim_info.optim_cls.__name__}"
        )
    optim_cls = optim_info.optim_cls
    optim_inputs = optim_info.optim_inputs_func(device=device)
    num_params = 5
    for optim_input in optim_inputs:
        for no_grad_scale in (False, True):
            params = [
                torch.ones((1,), device=device, dtype=dtype)
                for _ in range(num_params)
            ]
            params_c = [param.detach().clone() for param in params]
            for p in params:
                p.grad = torch.ones_like(p)
            optimizer = optim_cls(params, fused=True, **optim_input.kwargs)
            optimizer.grad_scale = (
                None
                if no_grad_scale
                else torch.ones((1,), dtype=dtype, device=device)
            )
            # Signal that an inf/nan was found, so the fused step should not apply.
            optimizer.found_inf = torch.ones((), dtype=dtype, device=device)
            optimizer.step()
            for p in params:
                if "step" in optimizer.state[p]:
                    self.assertEqual(
                        torch.zeros((), dtype=dtype, device=device),
                        optimizer.state[p]["step"],
                    )
            self.assertEqual(params, params_c)
@parametrize("impl", ["fused", "capturable"])
@optims(
    [optim for optim in optim_db if "fused" in optim.supported_impls],
    dtypes=[torch.float32],
)
def test_cpu_load_state_dict(self, device, dtype, impl, optim_info):
    """
    Load a CPU state_dict flagged with ``impl`` into an optimizer on ``device``.

    ``impl`` is "fused" or "capturable". The state_dict comes from a CPU optimizer
    and has ``impl=True`` set in its param group. After loading it into an
    optimizer whose param lives on ``device``, that optimizer must be able to step.
    """
    # NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256
    # How do we get there? Users typically create CUDA models on fused optimizers and then
    # store checkpoints on CPU as CUDA memory is limited with torch.load(...map_location="cpu").
    # Since this is a unit test, it is more expedient to simulate what the state_dict
    # would look like, which is basically CPU tensors with fused/capturable flag = True.
    optim_cls = optim_info.optim_cls
    opt_name = optim_cls.__name__
    if opt_name in ("SGD", "Adagrad") and impl == "capturable":
        # Capturable SGD/Adagrad does not exist; name the actual optimizer in the
        # skip reason rather than always reporting "SGD".
        self.skipTest(f"{opt_name} does not currently support capturable")
    device_type = _get_device_type(device)
    if device_type == "cpu":
        self.skipTest("Test is only for non-cpu devices")
    elif impl == "fused" and device_type not in optim_info.supports_fused_on:
        self.skipTest(f"{device} is not supported for fused on {opt_name}")
    elif impl == "capturable" and device_type == "mps":
        self.skipTest("MPS does not support capturable")
    cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
    for optim_input in cpu_optim_inputs:
        param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu")
        optimizer = optim_cls([param], **optim_input.kwargs)
        param.grad = torch.rand_like(param)
        optimizer.step()
        optim_state_dict_cpu = deepcopy(optimizer.state_dict())
        optim_state_dict_cpu["param_groups"][0][impl] = True
        # load into a device optimizer built from a copy of the config, so the
        # optim_input's own kwargs are left untouched
        device_kwargs = {**optim_input.kwargs, impl: True}
        param_device = param.detach().clone().to(device=device)
        optimizer_device = optim_cls([param_device], **device_kwargs)
        optimizer_device.load_state_dict(optim_state_dict_cpu)
        optimizer_device.zero_grad()
        param_device.grad = torch.rand_like(param_device)
        optimizer_device.step()
@optims(optim_db, dtypes=[torch.float32])
def test_param_groups_weight_decay(self, device, dtype, optim_info):
    """
    Check that per-group weight_decay settings coexist with normal optimization.

    The weight's group uses the config as-is, while the bias's group overrides
    weight_decay to 0.0. After 20 steps on a squared-error loss, the loss must
    have moved in the expected direction: up when ``maximize`` is set, down
    otherwise.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
    for optim_input in all_optim_inputs:
        weight_group_kwargs = optim_input.kwargs
        bias_group_kwargs = {**deepcopy(optim_input.kwargs), "weight_decay": 0.0}
        weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
        bias = Parameter(torch.randn((10), device=device, dtype=dtype))
        x = torch.randn(5, device=device, dtype=dtype)
        optimizer = optim_cls(
            [
                {"params": [weight], **weight_group_kwargs},
                {"params": [bias], **bias_group_kwargs},
            ]
        )

        def compute_loss():
            return (weight.mv(x) + bias).pow(2).sum()

        initial_value = compute_loss().item()
        for _ in range(20):
            optimizer.zero_grad()
            loss = compute_loss()
            loss.backward()
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                weight.grad = weight.grad.to_sparse()
                bias.grad = bias.grad.to_sparse()
            optimizer.step()
        # `loss` is the last pre-step loss; it must have moved the right way.
        check_direction = (
            self.assertGreater
            if optim_input.kwargs.get("maximize", False)
            else self.assertLess
        )
        check_direction(loss.item(), initial_value)
@optims(optim_db, dtypes=[torch.float32])
def test_param_groups_lr(self, device, dtype, optim_info):
    """
    Check that a param group's own lr overrides the near-zero constructor lr (1e-28).

    The group holding weight/bias carries the config's lr, forced to 1e-3 when it
    is missing or zero, and must move the loss in the expected direction. A second
    group that inherits the constructor defaults must leave its param unchanged,
    even though it receives random grads.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
    for optim_input in all_optim_inputs:
        group_kwargs = optim_input.kwargs
        # The group under test needs a positive lr of its own.
        if group_kwargs.get("lr", 0) == 0:
            group_kwargs["lr"] = 1e-3
        outer_kwargs = {"lr": 1e-28}
        if optim_cls.__name__ == "Rprop":
            # Allow min step size to be 0
            outer_kwargs["step_sizes"] = (0, 50)
        weight = Parameter(torch.randn((10, 5), device=device, dtype=dtype))
        bias = Parameter(torch.randn((10), device=device, dtype=dtype))
        irrelevant = Parameter(torch.randn(2, device=device, dtype=dtype))
        irrelevant_clone = irrelevant.clone()
        x = torch.randn(5, device=device, dtype=dtype)
        optimizer = optim_cls(
            [
                {"params": [weight, bias], **group_kwargs},
                {"params": [irrelevant]},
            ],
            **outer_kwargs,
        )
        initial_value = (weight.mv(x) + bias).pow(2).sum().item()
        for _ in range(20):
            optimizer.zero_grad()
            loss = (weight.mv(x) + bias).pow(2).sum()
            loss.backward()
            irrelevant.grad = torch.rand_like(irrelevant)
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                for tensor in (weight, bias, irrelevant):
                    tensor.grad = tensor.grad.to_sparse()
            optimizer.step()
        # Test that the direction of loss moved appropriately
        if group_kwargs.get("maximize", False):
            self.assertGreater(loss.item(), initial_value)
        else:
            self.assertLess(loss.item(), initial_value)
        # Test that irrelevant parameters were not updated since lr was almost 0
        self.assertEqual(irrelevant, irrelevant_clone)
@optims(optim_db, dtypes=[torch.float32])
def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
    """
    Check that step() on params whose ``.grad`` is None leaves them unchanged.

    The params are created with requires_grad=False and never receive a grad.
    Every optimizer config is stepped with a constant-loss closure, and the params
    must still equal their initial snapshot afterwards.
    """
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    params = [
        torch.randn(2, 3, requires_grad=False, device=device, dtype=dtype)
        for _ in range(2)
    ]
    old_params = [p.detach().clone() for p in params]

    def closure():
        return torch.tensor([1], device=device, dtype=dtype)

    for optim_input in all_optim_inputs:
        optimizer = optim_cls(params, **optim_input.kwargs)
        optimizer.step(closure)
        # Assert the no-op property against the snapshot taken above.
        self.assertEqual(params, old_params)
@optims(optim_db, dtypes=[torch.float32])
def test_step_is_noop_for_zero_grads(self, device, dtype, optim_info):
    """
    Check that step() with all-zero grads leaves the param unchanged.

    Dense optimizers get ``zeros_like`` grads. Sparse-only optimizers get an empty
    sparse COO grad. Configs with a non-zero ``weight_decay`` are skipped, since
    decay moves params even when grads are zero.
    """
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    param = torch.randn((5, 1), device=device, dtype=dtype, requires_grad=True)
    old_param = param.detach().clone()

    def closure():
        return torch.tensor([1], device=device, dtype=dtype)

    for optim_input in all_optim_inputs:
        kwargs = optim_input.kwargs
        # params will decay even if grads are empty if weight_decay != 0
        if kwargs.get("weight_decay", 0) != 0:
            continue
        if optim_cls.__name__ == "AdamW":
            # AdamW params will be updated regardless of grads due to lr
            # (presumably via AdamW's own default weight_decay, which the check
            # above does not see), so make lr smaller.
            small_lr = 1e-5
            if isinstance(kwargs.get("lr", small_lr), torch.Tensor):
                kwargs["lr"] = torch.tensor(small_lr)
            else:
                kwargs["lr"] = small_lr
        target = param.clone() if kwargs.get("differentiable", False) else param
        optimizer = optim_cls([target], **kwargs)
        if optim_info.only_supports_sparse_grads:
            # Intentionally construct a multidimensional empty v for the sparse grad
            # Single dim v passes the test while multidim correctly repros the issue
            # https://github.com/pytorch/pytorch/issues/82486
            indices = torch.empty((1, 0), device=device, dtype=dtype)
            values = torch.empty((0, 1), device=device, dtype=dtype)
            target.grad = torch.sparse_coo_tensor(
                indices, values, (5, 1), device=device, dtype=dtype
            )
        else:
            target.grad = torch.zeros_like(target)
        optimizer.step(closure)
        self.assertEqual(old_param, target)
@optims(optim_db, dtypes=[torch.float32])
def test_optimizer_can_be_printed(self, device, dtype, optim_info):
    """Build every optimizer config and check that its repr() does not raise."""
    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    two_params = [
        Parameter(torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype))
        for _ in range(2)
    ]
    for config in configs:
        repr(optim_info.optim_cls(two_params, **config.kwargs))
@parametrize("is_named_optim0", [True, False])
@parametrize("is_named_optim1", [True, False])
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_deterministic(
    self, device, dtype, optim_info, is_named_optim0, is_named_optim1
):
    """
    Check that a loaded state_dict makes a second optimizer evolve identically.

    ``optimizer`` is primed for 10 steps. Its state_dict is then loaded into
    ``optimizer_c``, which is built over clones of the params. Both are stepped 10
    more times in parallel, after which the params must match, and so must the
    state_dicts (ignoring "param_names"). ``is_named_optim0`` and
    ``is_named_optim1`` choose whether each optimizer is built from
    ("nameN", param) pairs or from bare params.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
    weight = Parameter(
        torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
    )
    bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
    input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
    params = [weight, bias]

    def make_named_param(param, is_named):
        # When is_named is set, turn a list of params into
        # [("name0", p0), ("name1", p1), ...]; otherwise return it unchanged.
        if not is_named:
            return param
        return [(f"name{i}", p) for i, p in enumerate(param)]

    def without_param_names(state_dict):
        # Deep copy of state_dict with each param group's "param_names" entry
        # removed, so named and unnamed optimizers can be compared.
        new_state_dict = deepcopy(state_dict)
        for pg in new_state_dict["param_groups"]:
            pg.pop("param_names", None)
        return new_state_dict

    def fwd_bwd(optim, w, b, i):
        # Zero grads, compute sum((w @ i + b) ** 2) and backprop. Sparse-only
        # optimizers get their dense grads naively converted to sparse layout.
        optim.zero_grad()
        loss = (w.mv(i) + b).pow(2).sum()
        loss.backward()
        if optim_info.only_supports_sparse_grads:
            if w.grad is not None:
                w.grad = w.grad.to_sparse()
            if b.grad is not None:
                b.grad = b.grad.to_sparse()
        return loss

    for optim_input in all_optim_inputs:
        params_in = make_named_param(params, is_named=is_named_optim0)
        optimizer = optim_cls(params_in, **optim_input.kwargs)
        closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)

        # Prime the optimizer
        for _ in range(10):
            if optim_info.step_requires_closure:
                optimizer.step(closure)
            else:
                closure()
                optimizer.step()

        # Clone the weights and construct a new optimizer for them
        with torch.no_grad():
            weight_c = Parameter(weight.clone())
            bias_c = Parameter(bias.clone())

        params_c = make_named_param([weight_c, bias_c], is_named=is_named_optim1)
        optimizer_c = optim_cls(params_c, **optim_input.kwargs)
        closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input)

        # Load the state dict from the original optimizer into the new one
        optimizer_c.load_state_dict(deepcopy(optimizer.state_dict()))

        # Run both optimizers in parallel
        for _ in range(10):
            if optim_info.step_requires_closure:
                optimizer.step(closure)
                optimizer_c.step(closure_c)
            else:
                closure()
                closure_c()
                optimizer.step()
                optimizer_c.step()

            self.assertEqual(weight, weight_c)
            self.assertEqual(bias, bias_c)

        # Make sure state dict is deterministic with equal (not identical) parameters.
        # Param names are optional and need not be consistent, so they are dropped
        # before comparing.
        self.assertEqual(
            without_param_names(optimizer.state_dict()),
            without_param_names(optimizer_c.state_dict()),
        )

        # Make sure repeated parameters have identical representation (see #36831).
        # Extending param_groups with itself duplicates every group in optimizer_c.
        optimizer_c.param_groups.extend(optimizer_c.param_groups)
        self.assertEqual(
            without_param_names(optimizer.state_dict())["param_groups"][-1],
            without_param_names(optimizer_c.state_dict())["param_groups"][-1],
        )
@optims(optim_db, dtypes=[torch.float32])
def test_can_load_older_state_dict(self, device, dtype, optim_info):
    """
    Check that a state_dict missing the optimizer's newer flags still loads.

    The flags listed in ``optim_info.not_og_supported_flags`` are stripped from the
    param groups, presumably simulating a state_dict saved before they existed.
    The stripped state_dict must load, and the optimizer must still step afterwards.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )

    def fwd_bwd(optim, mod, i):
        optim.zero_grad()
        loss = mod(i).sum()
        loss.backward()
        return loss

    def take_step(optim, mod, i):
        # Closure-driven optimizers (e.g. LBFGS) run fwd_bwd inside step().
        if optim_info.step_requires_closure:
            optim.step(functools.partial(fwd_bwd, optim, mod, i))
        else:
            fwd_bwd(optim, mod, i)
            optim.step()

    for optim_input in all_optim_inputs:
        torch.manual_seed(1)
        model = torch.nn.Sequential(
            torch.nn.Conv2d(4, 2, 1, stride=2),
            torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
        )
        model.to(dtype=dtype, device=device)
        sample = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
        optimizer = optim_cls(model.parameters(), **optim_input.kwargs)
        for _ in range(3):
            take_step(optimizer, model, sample)
        # Build the "older" state_dict by dropping every newer flag.
        old_state_dict = deepcopy(optimizer.state_dict())
        for group in old_state_dict["param_groups"]:
            for flag in optim_info.not_og_supported_flags:
                group.pop(flag, None)
        optimizer.load_state_dict(old_state_dict)
        # Make sure we can still step
        take_step(optimizer, model, sample)
@parametrize("is_named_optim0", [True, False])
@parametrize("is_named_optim1", [True, False])
@optims(
    [o for o in optim_db if not o.only_supports_sparse_grads],
    dtypes=[torch.float32],
)
def test_can_load_from_to_named_state_dict(
    self, device, dtype, optim_info, is_named_optim0, is_named_optim1
):
    """
    Check that state_dicts transfer between named and unnamed optimizers.

    The source optimizer is built from ``parameters()`` or ``named_parameters()``
    (``is_named_optim0``). Its state_dict must load into a fresh optimizer built
    either way (``is_named_optim1``), and that optimizer must still step. When
    either side used named parameters, the loaded optimizer's state_dict must
    report the module's parameter names.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )

    def fwd_bwd(optim, mod, i):
        optim.zero_grad()
        loss = mod(i).sum()
        loss.backward()
        return loss

    def take_step(optim, mod, i):
        # Closure-driven optimizers (e.g. LBFGS) run fwd_bwd inside step().
        if optim_info.step_requires_closure:
            optim.step(functools.partial(fwd_bwd, optim, mod, i))
        else:
            fwd_bwd(optim, mod, i)
            optim.step()

    def params_of(mod, named):
        return mod.named_parameters() if named else mod.parameters()

    for optim_input in all_optim_inputs:
        torch.manual_seed(1)
        model = torch.nn.Sequential(
            torch.nn.Conv2d(4, 2, 1, stride=2),
            torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
        )
        model.to(dtype=dtype, device=device)
        sample = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
        optimizer = optim_cls(params_of(model, is_named_optim0), **optim_input.kwargs)
        for _ in range(3):
            take_step(optimizer, model, sample)
        # Snapshot the trained optimizer's state and load it into a new optimizer.
        saved_state_dict = deepcopy(optimizer.state_dict())
        optimizer2 = optim_cls(params_of(model, is_named_optim1), **optim_input.kwargs)
        optimizer2.load_state_dict(saved_state_dict)
        # Make sure we can still step
        take_step(optimizer2, model, sample)
        # param_names must survive when at least one optimizer was given them
        if is_named_optim0 or is_named_optim1:
            self.assertEqual(
                optimizer2.state_dict()["param_groups"][0]["param_names"],
                ["0.weight", "0.bias", "1.weight", "1.bias"],
            )
@parametrize("is_named_optim", [True, False])
@optims(optim_db, dtypes=[torch.float32])
def test_save_load_equality_with_weights_only(
    self, device, dtype, optim_info, is_named_optim
):
    """
    Check that an optimizer state_dict round-trips through torch.save/torch.load.

    The loaded copy must equal the original both with the full unpickler
    (``weights_only=False``) and with ``weights_only=True``. ``is_named_optim``
    chooses whether the optimizer is built from ("nameN", param) pairs or from
    bare params.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
    weight = Parameter(
        torch.randn(2, 3, requires_grad=True, device=device, dtype=dtype)
    )
    bias = Parameter(torch.randn(2, requires_grad=True, device=device, dtype=dtype))
    input = torch.randn(3, requires_grad=True, device=device, dtype=dtype)
    params = [weight, bias]

    def make_named_param(param, is_named):
        if not is_named:
            return param
        return [(f"name{i}", p) for i, p in enumerate(param)]

    def fwd_bwd(optim, w, b, i):
        optim.zero_grad()
        loss = (w.mv(i) + b).pow(2).sum()
        loss.backward()
        if optim_info.only_supports_sparse_grads:
            # Convert the grads of the tensors passed in (w, b), not of
            # whatever the enclosing scope happens to bind.
            w.grad = w.grad.to_sparse()
            b.grad = b.grad.to_sparse()
        return loss

    for optim_input in all_optim_inputs:
        params_in = make_named_param(params, is_named=is_named_optim)
        optimizer = optim_cls(params_in, **optim_input.kwargs)
        closure = functools.partial(fwd_bwd, optimizer, weight, bias, input)

        # Prime the optimizer
        for _ in range(3):
            optimizer.step(closure)

        sd = optimizer.state_dict()

        # === Check saved/loaded state_dict are the same (including weights_only load). ===
        with tempfile.TemporaryFile() as f:
            torch.save(sd, f)
            f.seek(0)
            # weights_only=False is spelled out because recent PyTorch releases
            # default torch.load to weights_only=True. Without it, this load would
            # duplicate the one below. The file was just written by this test, so
            # full unpickling is safe.
            sd_copy = torch.load(f, weights_only=False)
            self.assertEqual(sd_copy, sd)
            del sd_copy
            f.seek(0)
            sd_copy_wo = torch.load(f, weights_only=True)
            self.assertEqual(sd_copy_wo, sd)
@optims(optim_db, dtypes=[torch.float32])
def test_load_nontensor_step(self, device, dtype, optim_info):
    """
    Check that a state_dict with Python-number ``step`` entries still loads.

    Tensor ``step`` entries are replaced with their ``.item()`` values. The
    modified state_dict must load, and the optimizer must still be able to step
    afterwards.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
    params = [
        Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
    ]
    for p in params:
        p.grad = torch.rand_like(p)
        if optim_info.only_supports_sparse_grads:
            # For this test, we naively convert the Tensor layout, which we know does
            # NOT represent the expected use case for optims like SparseAdam!
            p.grad = p.grad.to_sparse()

    # Needed for second order optims like LBFGS
    closure_loss = torch.rand(1, device=device, dtype=dtype)

    def closure():
        return closure_loss if optim_info.step_requires_closure else None

    for optim_input in all_optim_inputs:
        optimizer = optim_cls(params, **optim_input.kwargs)
        for _ in range(3):
            optimizer.step(closure)
        state_dict = deepcopy(optimizer.state_dict())
        # Replace tensor `step` entries with plain Python numbers.
        for p_state in state_dict["state"].values():
            if "step" in p_state and torch.is_tensor(p_state["step"]):
                p_state["step"] = p_state["step"].item()
        optimizer.load_state_dict(state_dict)
        optimizer.step(closure)
@onlyCUDA
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_with_cuda_params(self, device, dtype, optim_info):
    """Check that a CPU optimizer's state_dict loads into an optimizer over CUDA params.

    Each config is primed for three steps on CPU params. Its state_dict is loaded
    into a fresh optimizer over CUDA copies of the same params, and both optimizers
    are then stepped in lockstep. Params and state_dicts must stay equal, and a
    tensor ``step`` must stay on CPU unless the config is capturable or fused.
    """
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    # We limit our configs to CPU only, because we will be moving them to CUDA later
    cpu_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        "cpu", dtype, optim_info, skip=("differentiable",)
    )
    # Needed for second order optims like LBFGS
    closure_loss = torch.rand(1, device=device, dtype=dtype)

    def closure():
        return closure_loss if optim_info.step_requires_closure else None

    for optim_input in cpu_optim_inputs:
        if (
            "fused" in optim_input.kwargs
            and "cuda" not in optim_info.supports_fused_on
        ):
            # This config cannot be rebuilt on CUDA. Skip only this config:
            # self.skipTest() here would abort the whole test and silently drop
            # every remaining (non-fused) config for this optimizer.
            continue

        params = [
            Parameter(torch.randn(2, 3, device="cpu", dtype=dtype))
            for _ in range(2)
        ]
        for p in params:
            p.grad = torch.randn_like(p)
            if optim_info.only_supports_sparse_grads:
                # For this test, we naively convert the Tensor layout, which we know does
                # NOT represent the expected use case for optims like SparseAdam!
                p.grad = p.grad.to_sparse()

        optimizer = optim_cls(params, **optim_input.kwargs)
        for _ in range(3):
            optimizer.step(closure)

        with torch.no_grad():
            params_cuda = [p.to(device="cuda") for p in params]
            for i, p in enumerate(params_cuda):
                p.grad = params[i].grad.to(device="cuda")
        optimizer_cuda = optim_cls(params_cuda, **optim_input.kwargs)

        state_dict_cpu = deepcopy(optimizer.state_dict())
        state_dict_cuda = deepcopy(optimizer.state_dict())
        optimizer_cuda.load_state_dict(state_dict_cuda)

        # Make sure state_dict_cuda isn't modified by merely calling load_state_dict
        self.assertEqual(state_dict_cpu, state_dict_cuda)

        # Make sure that device of state['step'] is still CPU _unless_ torch.compile() added a capturable!
        capturable = state_dict_cpu["param_groups"][0].get("capturable", False)
        fused = state_dict_cpu["param_groups"][0].get("fused", False)
        new_state_dict = optimizer_cuda.state_dict()
        for state_cpu, state_cuda in zip(
            state_dict_cpu["state"].values(), new_state_dict["state"].values()
        ):
            if "step" in state_cpu and torch.is_tensor(state_cpu["step"]):
                self.assertEqual(
                    state_cuda["step"].device.type,
                    "cuda" if capturable or fused else "cpu",
                )

        for _ in range(5):
            optimizer.step(closure)
            optimizer_cuda.step(closure)
            self.assertEqual(params, params_cuda)
            self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
@staticmethod
def _state_dict_pre_hook(optimizer: Optimizer) -> None:
    """State-dict pre-hook used by the tests below.

    Plants the sentinel integer ``1`` under the key ``"test"`` in
    ``optimizer.state``. Returns nothing.
    """
    live_state = optimizer.state
    live_state["test"] = 1
@staticmethod
def _state_dict_post_hook(
    optimizer: Optimizer, state_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """State-dict post-hook used by the tests below.

    Removes the ``"test"`` sentinel from ``state_dict["state"]`` when it is present,
    and records whether it was found as a bool under
    ``state_dict["ran_state_dict_pre_hook"]``. Returns the same, mutated dict.
    """
    packed_state = state_dict["state"]
    sentinel_found = "test" in packed_state
    if sentinel_found:
        packed_state.pop("test")
    state_dict["ran_state_dict_pre_hook"] = sentinel_found
    return state_dict
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_pre_hook(self, device, dtype, optim_info):
    """A state_dict pre-hook's edit to ``optimizer.state`` shows up in the state_dict."""
    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for config in configs:
        weight = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        opt = optim_info.optim_cls([weight], **config.kwargs)
        opt.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
        # The pre-hook writes state["test"] = 1 before the state is packed.
        packed = opt.state_dict()
        self.assertEqual(packed["state"]["test"], 1)
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_post_hook(self, device, dtype, optim_info):
    """A post-hook registered alone runs, but finds no sentinel to strip."""
    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for config in configs:
        weight = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        opt = optim_info.optim_cls([weight], **config.kwargs)
        opt.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
        packed = opt.state_dict()
        # No pre-hook is registered, so the post-hook must report it found no sentinel.
        self.assertFalse(packed["ran_state_dict_pre_hook"])
@optims(optim_db, dtypes=[torch.float32])
def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
    """Pre- and post-hooks compose: the post-hook strips the pre-hook's sentinel.

    The pre-hook writes ``state["test"] = 1``. The post-hook removes that key from
    the packed state and sets ``ran_state_dict_pre_hook`` to True because it found it.
    """
    optim_cls = optim_info.optim_cls
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for optim_input in all_optim_inputs:
        param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        optim = optim_cls([param], **optim_input.kwargs)
        optim.register_state_dict_pre_hook(self.__class__._state_dict_pre_hook)
        optim.register_state_dict_post_hook(self.__class__._state_dict_post_hook)
        state_dict = optim.state_dict()
        # assertNotIn reports the offending container on failure, unlike assertFalse(x in y).
        self.assertNotIn("test", state_dict["state"])
        self.assertTrue(state_dict["ran_state_dict_pre_hook"])
@staticmethod
def _load_state_dict_pre_hook1(
    optimizer: Optimizer, state_dict: Dict[str, Any]
) -> None:
    """Load-state-dict pre-hook that edits the incoming dict in place.

    Sets the first param group's ``lr`` to ``0.002``. Returns ``None``, so it
    supplies no replacement dict.
    """
    leading_group = state_dict["param_groups"][0]
    leading_group["lr"] = 0.002
@staticmethod
def _load_state_dict_pre_hook2(
    optimizer: Optimizer, state_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """Load-state-dict pre-hook that returns a replacement dict.

    Returns a deep copy of ``state_dict`` whose first param group has ``lr``
    set to ``0.003``. The caller's dict is left untouched.
    """
    # Returning a state dict is typically done to modify it drastically. Simulate
    # that with a deep copy, so the tests can tell whether the returned copy was used.
    replacement = deepcopy(state_dict)
    replacement["param_groups"][0].update(lr=0.003)
    return replacement
@staticmethod
def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
    """Load-state-dict post-hook that records what happened during loading.

    Writes two booleans into ``optimizer.state``:
    ``"ran_load_state_dict_pre_hook2"`` (whether the first param group's lr equals
    0.003) and ``"ran_load_state_dict_post_hook"`` (always True).
    """
    current_lr = optimizer.param_groups[0]["lr"]
    optimizer.state["ran_load_state_dict_pre_hook2"] = current_lr == 0.003
    optimizer.state["ran_load_state_dict_post_hook"] = True
@optims(optim_db, dtypes=[torch.float32])
def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
    """Check that load_state_dict pre-hooks run and that ``prepend=True`` runs a hook first.

    Hook1 sets lr to 0.002 in place. Hook2 returns a copy with lr 0.003. With hook2
    prepended, hook1 still runs after it, so 0.002 is the lr that ends up loaded.
    """
    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for config in configs:
        weight = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        opt = optim_info.optim_cls([weight], **config.kwargs)
        snapshot = opt.state_dict()

        # Usually one would load into a new optimizer instance; reusing this one is equivalent here.
        opt.register_load_state_dict_pre_hook(
            self.__class__._load_state_dict_pre_hook1
        )
        opt.load_state_dict(snapshot)
        self.assertEqual(opt.param_groups[0]["lr"], 0.002)

        opt.register_load_state_dict_pre_hook(
            self.__class__._load_state_dict_pre_hook2, prepend=True
        )
        opt.load_state_dict(snapshot)
        # Appended, hook2 would leave 0.003; prepended, hook1 runs after it and wins.
        self.assertEqual(opt.param_groups[0]["lr"], 0.002)
@optims(optim_db, dtypes=[torch.float32])
def test_load_state_dict_post_hook(self, device, dtype, optim_info):
    """A load_state_dict post-hook registered alone runs after loading."""
    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for config in configs:
        weight = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        opt = optim_info.optim_cls([weight], **config.kwargs)
        opt.register_load_state_dict_post_hook(
            self.__class__._load_state_dict_post_hook
        )
        opt.load_state_dict(opt.state_dict())
        flags = opt.state
        # No lr-changing pre-hook is registered, so this flag is expected to be False.
        self.assertFalse(flags["ran_load_state_dict_pre_hook2"])
        self.assertTrue(flags["ran_load_state_dict_post_hook"])
@optims(optim_db, dtypes=[torch.float32])
def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
    """The dict returned by a load pre-hook is what gets loaded, and the post-hook then runs."""
    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for config in configs:
        weight = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
        opt = optim_info.optim_cls([weight], **config.kwargs)
        hooks_cls = self.__class__
        opt.register_load_state_dict_pre_hook(hooks_cls._load_state_dict_pre_hook2)
        opt.register_load_state_dict_post_hook(hooks_cls._load_state_dict_post_hook)
        opt.load_state_dict(opt.state_dict())
        flags = opt.state
        # pre_hook2 returns a copy with lr 0.003; the post-hook checks for that lr.
        self.assertTrue(flags["ran_load_state_dict_pre_hook2"])
        self.assertTrue(flags["ran_load_state_dict_post_hook"])
@optims(optim_db, dtypes=[torch.float32])
def test_step_post_hook(self, device, dtype, optim_info):
    """A step post-hook fires once per ``step()`` until its handle is removed."""
    data = 0

    def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        nonlocal data
        data += 2

    params = [torch.tensor([1, 1], device=device, dtype=dtype)]
    closure = (lambda: 1) if optim_info.step_requires_closure else None

    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for config in configs:
        optim = optim_info.optim_cls(params, **config.kwargs)
        data = 2
        handle = optim.register_step_post_hook(post_hook)
        for _ in range(2):
            optim.step(closure)
        # Two hooked steps: 2 + 2 + 2.
        self.assertEqual(data, 6)

        # Once the handle is removed, another step must leave the counter alone.
        handle.remove()
        optim.step(closure)
        self.assertEqual(data, 6)
@optims(optim_db, dtypes=[torch.float32])
def test_step_pre_hook(self, device, dtype, optim_info):
    """A step pre-hook fires once per ``step()`` until its handle is removed."""
    data = 0

    def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        nonlocal data
        data += 2

    params = [torch.tensor([1, 1], device=device, dtype=dtype)]
    closure = (lambda: 1) if optim_info.step_requires_closure else None

    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for config in configs:
        optim = optim_info.optim_cls(params, **config.kwargs)
        data = 5
        handle = optim.register_step_pre_hook(pre_hook)
        for _ in range(2):
            optim.step(closure)
        # Two hooked steps: 5 + 2 + 2.
        self.assertEqual(data, 9)

        # Once the handle is removed, another step must leave the counter alone.
        handle.remove()
        optim.step(closure)
        self.assertEqual(data, 9)
@optims(optim_db, dtypes=[torch.float32])
def test_step_all_hooks(self, device, dtype, optim_info):
    """Check the firing order of global and per-optimizer step hooks, and their removal.

    Markers appended per step: 0 = global pre, 1 = local pre, 2 = local post,
    5 = global post. The global hooks also fire for a plain ``SGD`` over the same
    params, which is why they are always removed in a ``finally`` below.
    """
    data = []

    # The hooks only mutate ``data``, so plain closure reads suffice (no nonlocal needed).
    def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        data.append(0)

    def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        data.append(5)

    def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        data.append(1)

    def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        data.append(2)

    params = [torch.tensor([1, 1], device=device, dtype=dtype)]

    def dummy_closure():
        return 1

    closure = dummy_closure if optim_info.step_requires_closure else None
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info
    )
    for optim_input in all_optim_inputs:
        optim = optim_info.optim_cls(params, **optim_input.kwargs)
        optim2 = SGD(params)
        data = []
        handles = []
        try:
            # register global hooks to both optimizers
            handles.append(register_optimizer_step_pre_hook(global_pre_hook))
            handles.append(register_optimizer_step_post_hook(global_post_hook))
            # register local hooks
            handles.append(optim.register_step_pre_hook(local_pre_hook))
            handles.append(optim.register_step_post_hook(local_post_hook))
            handles.append(optim2.register_step_pre_hook(local_pre_hook))
            handles.append(optim2.register_step_post_hook(local_post_hook))

            optim.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5])
            optim2.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
            optim.step(closure)
            self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
        finally:
            # Global hooks are not tied to these optimizers; remove every handle even
            # if an assertion above failed, otherwise the global hooks would keep
            # firing for other optimizers stepped later in the same process.
            for handle in handles:
                handle.remove()

        # With every hook removed, further steps must not append anything.
        optim.step(closure)
        optim2.step(closure)
        self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
@optims(optim_db, dtypes=[torch.float32])
def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
    """``deepcopy(optimizer)`` must carry over the same set of public instance attributes."""
    optim_cls = optim_info.optim_cls
    # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
    configs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )
    params = [
        Parameter(torch.randn(2, 3, device=device, dtype=dtype)) for _ in range(2)
    ]
    for param in params:
        grad = torch.rand_like(param)
        if optim_info.only_supports_sparse_grads:
            # For this test, we naively convert the Tensor layout, which we know does
            # NOT represent the expected use case for optims like SparseAdam!
            grad = grad.to_sparse()
        param.grad = grad

    # Needed for second order optims like LBFGS
    def closure():
        return 1 if optim_info.step_requires_closure else None

    def public_attr_names(obj):
        return {name for name in vars(obj) if not name.startswith("_")}

    for config in configs:
        optimizer = optim_cls(params, **config.kwargs)
        # Make some state
        for _ in range(3):
            if not optim_info.step_requires_closure:
                closure()
                optimizer.step()
            else:
                optimizer.step(closure)
        copied = deepcopy(optimizer)
        self.assertEqual(public_attr_names(optimizer), public_attr_names(copied))
@optims(
    [optim for optim in optim_db if optim.step_requires_closure],
    dtypes=[torch.float32],
)
def test_second_order_optims_return_consistent_types(
    self, device, dtype, optim_info
):
    """``step(closure)`` must return the same type whether or not the optimizer exits early.

    ``tolerance_grad=inf`` and ``-inf`` push the optimizer down different code
    paths; the two results must have matching Python types.
    """
    # Motivated by #7586
    optim_cls = optim_info.optim_cls
    params = [
        torch.randn(10, 5, device=device, dtype=dtype),
        torch.randn(10, device=device, dtype=dtype),
    ]

    def closure():
        return torch.tensor([10], device=device, dtype=dtype)

    for optim_input in optim_info.optim_inputs_func(device=device):
        # Currently, the only second order optim is LBFGS, so we just go ahead and modify
        # "tolerance_grad", but this may not scale if we add second order optims in the future
        # Build fresh kwargs dicts rather than mutating the shared optim_input.kwargs in place.
        optim_inf = optim_cls(
            params, **{**optim_input.kwargs, "tolerance_grad": math.inf}
        )
        optim_neg_inf = optim_cls(
            params, **{**optim_input.kwargs, "tolerance_grad": -math.inf}
        )
        res1 = optim_inf.step(closure)
        res2 = optim_neg_inf.step(closure)
        self.assertEqual(type(res1), type(res2))
@onlyCUDA
@optims(
    [
        optim
        for optim in optim_db
        if "cpu" in optim.supports_fused_on and "cuda" in optim.supports_fused_on
    ],
    dtypes=floating_types_and(
        torch.bfloat16,
        torch.float16,
    ),
)
def test_fused_cpu_matches_cuda(self, device, dtype, optim_info):
    """Build identically seeded models with fused optimizers on CPU and CUDA and compare them.

    For each config, one model/optimizer pair is built per device. The lists are
    handed to ``self._compare_between`` (defined elsewhere; presumably it trains
    both pairs and checks that they agree).
    """
    optim_cls = optim_info.optim_cls
    optim_inputs = optim_info.optim_inputs_func(device="cpu")
    for optim_input in optim_inputs:
        # Copy rather than mutate optim_input.kwargs in place; the same fused
        # kwargs are reused for both the CPU and the CUDA optimizer.
        kwargs = {**optim_input.kwargs, "fused": True}
        inpts, models, optimizers = [], [], []
        for dev in ("cpu", "cuda"):
            inpt = torch.tensor(
                [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=dev
            ).reshape(3, 2)

            # Same seed on both iterations so both models start from identical weights.
            torch.manual_seed(1)
            model = torch.nn.Sequential(
                torch.nn.Linear(2, 3),
                torch.nn.Sigmoid(),
                torch.nn.Linear(3, 1),
                torch.nn.Sigmoid(),
            )
            model.to(dtype=dtype, device=dev)

            # foreach/fused optimizers should be tested with a
            # zero_size tensor as its last param.
            # ref: https://github.com/pytorch/pytorch/issues/100701
            empty_param = torch.empty(
                (), device=dev, dtype=dtype, requires_grad=True
            )
            empty_param.grad = torch.rand_like(empty_param)
            params = list(model.parameters()) + [empty_param]

            optimizer = optim_cls(params, **kwargs)
            inpts.append(inpt)
            models.append(model)
            optimizers.append(optimizer)
        self._compare_between(inpts, models, optimizers)
@onlyCUDA
@optims(
    [
        o
        for o in optim_db
        if ("foreach" in o.supported_impls and o.optim_cls.__name__ != "Adafactor")
    ],
    dtypes=[torch.float32],
)
def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
    """With default kwargs on CUDA, ``step()`` must dispatch to the foreach implementation.

    Adafactor is excluded: it defaults to the single-tensor impl for memory efficiency.
    """
    import inspect

    optim_cls = optim_info.optim_cls
    # Patch the module-level ``_multi_tensor_<lowercased class name>`` function in
    # the module that defines the optimizer class, so the call can be observed.
    impl_module = inspect.getmodule(optim_cls)
    foreach_impl_name = f"_multi_tensor_{optim_cls.__name__.lower()}"

    model = torch.nn.Linear(5, 5)
    model.to(dtype=dtype, device=device)
    inpt = torch.rand(2, 5, dtype=dtype, device=device)
    for optim_input in optim_info.optim_inputs_func(device=device):
        optim = optim_cls(model.parameters(), **optim_input.kwargs)
        optim.zero_grad()
        model(inpt).sum().backward()
        with patch.object(impl_module, foreach_impl_name) as mocked_foreach_impl:
            optim.step()
            self.assertTrue(mocked_foreach_impl.called)
@optims(optim_db, dtypes=[torch.float32])
def test_non_empty_state(self, device, dtype, optim_info):
    """After a step, every per-parameter entry in ``optimizer.state`` must be non-empty."""
    # There are internal tests that check that the state is not empty
    optim_cls = optim_info.optim_cls
    model = torch.nn.Linear(5, 5)
    model.to(dtype=dtype, device=device)
    inpt = torch.rand(2, 5, dtype=dtype, device=device)
    for optim_input in optim_info.optim_inputs_func(device=device):
        optim = optim_cls(model.parameters(), **optim_input.kwargs)
        optim.zero_grad()
        model(inpt).sum().backward()

        if optim_info.only_supports_sparse_grads:
            for param in model.parameters():
                if param.grad is None:
                    continue
                param.grad = param.grad.to_sparse()

        step_args = (lambda: 1.0,) if optim_info.step_requires_closure else ()
        optim.step(*step_args)

        for state in optim.state.values():
            self.assertGreater(len(state), 0)
# Generate the per-device variants of TestOptimRenewed into this module's namespace
# (MPS variants included).
instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)


if __name__ == "__main__":
    run_tests()