# @package optimizer
# Module caffe2.python.optimizer
import copy
import logging
from collections import defaultdict, namedtuple
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, scope, utils, workspace
from caffe2.python.modeling import parameter_info
from past.builtins import basestring
_LEARNING_RATE_INJECTION = "lr_injection"
AuxOptimizerParams = namedtuple("AuxOptimizerParams", ["local", "shared"])
_optimizer_instance_count = defaultdict(int)
FP16_ENGINES = ["SIMD_Q_FP16", "SIMD_Q_STOC_FP16", "SIMD_Q_STOC_MKL_FP16"]
logger = logging.getLogger(__name__)


def reset_optimizer_instance_count():
    """
    Clears _optimizer_instance_count and leaves it empty. This is needed in
    situations where the optimizer instance count would otherwise not be
    reset even though the workspace is reset.
    """
    _optimizer_instance_count.clear()
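

# Illustrative sketch (hypothetical helper, not part of the Caffe2 API): a
# minimal reproduction of how the module-level per-class counter and the
# Optimizer.get_cpu_blob_name / get_gpu_blob_name methods combine to give
# every optimizer instance its own blob names. It uses only plain Python
# string formatting and mirrors the "%s_%d_%s%s_cpu" and "%s_%d_%s%s_gpu%d"
# patterns used by those methods; no Caffe2 objects are created.
def _sketch_unique_blob_names():
    from collections import defaultdict

    counts = defaultdict(int)  # stands in for _optimizer_instance_count

    def new_instance(classname):
        # Each construction takes the current count, then bumps it.
        instance_num = counts[classname]
        counts[classname] += 1
        return instance_num

    first = new_instance("SgdOptimizer")
    second = new_instance("SgdOptimizer")
    names = [
        "%s_%d_%s%s_cpu" % ("SgdOptimizer", first, "lr", ""),
        "%s_%d_%s%s_cpu" % ("SgdOptimizer", second, "lr", ""),
        "%s_%d_%s%s_gpu%d" % ("SgdOptimizer", first, "lr", "", 0),
    ]
    counts.clear()  # the effect of reset_optimizer_instance_count()
    return names, dict(counts)


# Example: the second SgdOptimizer's CPU learning-rate blob is named
# "SgdOptimizer_1_lr_cpu", so two optimizers never collide on "lr".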


class Optimizer(object):
    def __init__(self):
        self._aux_params = AuxOptimizerParams(local=[], shared=[])
        self._instance_num = _optimizer_instance_count[self.__class__.__name__]
        _optimizer_instance_count[self.__class__.__name__] += 1
        self._lr_multiplier = None
        self._local_lr_multiplier = None
        self._local_lr_multiplier_on_gpu = False
        self._use_dedicated_lr_iteration_counter = False

    def __call__(self, net, param_init_net, param, grad=None):
        """
        Adds optimization operators to the net for the given parameter and
        its gradient.

        The parameter can be specified in one of two ways:
          * `param` is a ParameterInfo object and `grad` is omitted, in which
            case `param.grad` has to be set; or
          * `param` is a BlobReference (or a blob name) and `grad` is a
            BlobReference for its gradient.
        """
        if grad is None:
            assert isinstance(
                param, parameter_info.ParameterInfo
            ), "Expected parameter to be of type ParameterInfo, got {}".format(param)
            assert param.grad is not None
        else:
            if isinstance(param, basestring):
                param = core.BlobReference(param)
            param = parameter_info.ParameterInfo(param_id=None, param=param, grad=grad)

        self._run(net, param_init_net, param)

    def _run(self, net, param_init_net, param_info):
        raise NotImplementedError("Subclasses of Optimizer must implement _run()")

    def get_cpu_blob_name(self, base_str, node_name=""):
        classname = self.__class__.__name__
        return "%s_%d_%s%s_cpu" % (classname, self._instance_num, base_str, node_name)

    def get_gpu_blob_name(self, base_str, gpu_id, node_name):
        classname = self.__class__.__name__
        return "%s_%d_%s%s_gpu%d" % (
            classname,
            self._instance_num,
            base_str,
            node_name,
            gpu_id,
        )

    @property
    def attributes(self):
        # Return a dict that contains the attributes related to init args only.
        attr = copy.deepcopy(self.__dict__)
        del attr["_instance_num"]
        return attr

    @property
    def use_dedicated_lr_iteration_counter(self):
        return self._use_dedicated_lr_iteration_counter

    @use_dedicated_lr_iteration_counter.setter
    def use_dedicated_lr_iteration_counter(self, val):
        self._use_dedicated_lr_iteration_counter = val

    def make_unique_blob_name(self, base_str):
        """
        Returns a blob name that will be unique to the current device
        and optimizer instance.
        """
        current_scope = scope.CurrentDeviceScope()
        if current_scope is None:
            return self.get_cpu_blob_name(base_str)

        if core.IsGPUDeviceType(current_scope.device_type):
            return self.get_gpu_blob_name(
                base_str, current_scope.device_id, current_scope.node_name
            )
        else:
            return self.get_cpu_blob_name(base_str, current_scope.node_name)

    def build_lr(
        self,
        net,
        param_init_net,
        base_learning_rate,
        learning_rate_blob=None,
        policy="fixed",
        iter_val=0,
        **kwargs
    ):
        if learning_rate_blob is None:
            learning_rate_blob = self.make_unique_blob_name("lr")

        if self._use_dedicated_lr_iteration_counter:
            iteration = utils.BuildUniqueMutexIter(
                param_init_net,
                net,
                iter=utils.OPTIMIZER_ITERATION_LR_NAME,
                iter_mutex=utils.ITERATION_MUTEX_LR_NAME,
                iter_val=iter_val,
            )
            logger.info(f"Created dedicated learning rate iteration counter: {iteration}")
        else:
            iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)

        if not net.BlobIsDefined(learning_rate_blob):
            # Since we are minimizing (gradient *descent*), the learning rate
            # blob holds the negated base learning rate, so that adding
            # lr * grad to a parameter moves it against the gradient.
            lr = net.LearningRate(
                [iteration],
                learning_rate_blob,
                base_lr=-base_learning_rate,
                policy=policy,
                **kwargs
            )
        else:
            lr = net.GetBlobRef(learning_rate_blob)

        if self._lr_multiplier is not None:
            lr_multiplier = net.CopyFromCPUInput(
                self._lr_multiplier, self.make_unique_blob_name("lr_multiplier")
            )

            lr = net.Mul(
                [lr, lr_multiplier],
                self.make_unique_blob_name("scaled_lr"),
                broadcast=1,
            )

        if self._local_lr_multiplier is not None:
            current_scope = scope.CurrentDeviceScope()
            if (
                current_scope is not None
                and core.IsGPUDeviceType(current_scope.device_type)
                and not self._local_lr_multiplier_on_gpu
            ):
                local_lr_multiplier = net.CopyFromCPUInput(
                    self._local_lr_multiplier,
                    self.make_unique_blob_name("local_lr_multiplier"),
                )
            else:
                local_lr_multiplier = self._local_lr_multiplier

            lr = net.Mul(
                [lr, local_lr_multiplier],
                self.make_unique_blob_name("local_scaled_lr"),
                broadcast=1,
            )

        return lr, iteration

    def build_non_lr_iter(
        self,
        net,
        param_init_net,
        iter_val=0,
    ):
        assert (
            self._use_dedicated_lr_iteration_counter
        ), "This method should only be called when a dedicated learning rate iteration counter is used."

        iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=iter_val)
        logger.info(f"Created iteration counter for non learning rate purposes: {iteration}")

        # We create a dummy LearningRate operator so that the iteration
        # counter blob is placed on the trainer nodes. Otherwise, the
        # Automatic Device Placement (ADP) algorithm for Hierarchical
        # Training (HT) has trouble distributing blobs across group
        # parameter servers. This learning rate operator is not used for
        # any other purpose.
        learning_rate_blob = self.make_unique_blob_name("iter_placement_hint")
        if not net.BlobIsDefined(learning_rate_blob):
            net.LearningRate(
                [iteration],
                learning_rate_blob,
                base_lr=1.0,
                policy="fixed",
            )

        return iteration

    def add_lr_multiplier(self, lr_multiplier):
        """
        Set the global learning rate multiplier. If a multiplier already
        existed, this will overwrite the existing multiplier. The multiplier is
        used for all future calls to _run(), unless it is overwritten.
        """
        self._lr_multiplier = lr_multiplier

    def _add_local_lr_multiplier(self, local_lr_multiplier, is_gpu_blob=False):
        """
        Set the local learning rate multiplier. This local multiplier is
        multiplied with the global learning rate multiplier if it exists. As
        with the global learning rate multiplier, this multiplier will be
        used for all future calls to _run(), so please call
        _clear_local_lr_multiplier() at the beginning of the optimizer's _run()
        before optionally calling this function.
        """
        self._local_lr_multiplier = local_lr_multiplier
        self._local_lr_multiplier_on_gpu = is_gpu_blob

    def _clear_local_lr_multiplier(self):
        self._local_lr_multiplier = None
        self._local_lr_multiplier_on_gpu = False

    @staticmethod
    def dedup(net, sparse_dedup_aggregator, grad):
        assert isinstance(
            grad, core.GradientSlice
        ), "Dedup only works for sparse gradient, got {}".format(grad)
        if sparse_dedup_aggregator:
            return net.DeduplicateGradientSlices(
                grad, aggregator=sparse_dedup_aggregator
            )
        else:
            return grad

    def get_auxiliary_parameters(self):
        """Returns the auxiliary parameters.

        Returns:
            aux_params: An AuxOptimizerParams namedtuple.

            aux_params.local stores a list of blobs. Each blob is a local
            auxiliary parameter. A local auxiliary parameter is a parameter in
            parallel to a learning rate parameter. Take Adagrad as an example:
            the local auxiliary parameter is the squared-sum parameter, because
            every learning rate has a squared sum associated with it.

            aux_params.shared also stores a list of blobs. Each blob is a shared
            auxiliary parameter. A shared auxiliary parameter is a parameter
            that is shared across all the learning rate parameters. Take Adam as
            an example: the iteration parameter is a shared parameter, because
            all the learning rates share the same iteration parameter.
        """
        return self._aux_params

    # TODO(xlwang): In transfer learning, a parameter initialized from a
    # pretrained model might require a different learning rate than one that
    # is initialized otherwise. To this end, we implement a Python solution
    # where `base_learning_rate` is scaled by `scale` by calling
    # `scale_learning_rate`. Alternatively, the same effect could be achieved
    # by rewriting the LearningRate operator in C++.
    # Note that it is the responsibility of each specific optimizer to decide
    # what logic to use for `scale_learning_rate`.
    def scale_learning_rate(self, *args, **kwargs):
        raise NotImplementedError(
            "Optimizer needs to implement the `scale_learning_rate` method."
        )

    def create_lars_inputs(self, param_init_net, weight_decay, trust, lr_max):
        wd = param_init_net.ConstantFill(
            [], "weight_decay", shape=[1], value=weight_decay
        )
        trust = param_init_net.ConstantFill([], "trust", shape=[1], value=trust)
        lr_max = param_init_net.ConstantFill([], "lr_max", shape=[1], value=lr_max)
        return wd, trust, lr_max
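

# Illustrative sketch (hypothetical helper, not part of the Caffe2 API): the
# scalar arithmetic that Optimizer.build_lr wires into the net, assuming the
# "fixed" policy (the LearningRate operator then emits base_lr unchanged on
# every iteration). The optional global multiplier (add_lr_multiplier) and
# the per-parameter local multiplier (_add_local_lr_multiplier, e.g. a LARS
# factor) are applied as successive Mul ops; None means "not set".
def _sketch_effective_lr(base_learning_rate, lr_multiplier=None,
                         local_lr_multiplier=None):
    lr = -base_learning_rate  # negated because we do gradient *descent*
    if lr_multiplier is not None:
        lr = lr * lr_multiplier
    if local_lr_multiplier is not None:
        lr = lr * local_lr_multiplier
    return lr


# Example: a base LR of 0.1, halved globally and scaled by 0.2 locally,
# yields a signed LR of about -0.01.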


class SgdOptimizer(Optimizer):
    def __init__(
        self,
        base_learning_rate=0.01,
        policy="fixed",
        momentum=0.0,
        nesterov=True,
        sparse_dedup_aggregator=None,
        lars=None,
        **kwargs
    ):
        super(SgdOptimizer, self).__init__()
        self.base_learning_rate = base_learning_rate
        self.policy = policy
        self.momentum = momentum
        self.nesterov = nesterov
        self.sparse_dedup_aggregator = sparse_dedup_aggregator
        self.lars = lars
        self.init_kwargs = kwargs

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        grad = param_info.grad
        if self.base_learning_rate == 0:
            return
        assert (
            self.base_learning_rate > 0
        ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)

        self._clear_local_lr_multiplier()

        # TODO(zqq): support LARS for sparse parameters
        if self.lars is not None and not isinstance(grad, core.GradientSlice):
            assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
                self.lars
            )
            wd, trust, lr_max = self.create_lars_inputs(
                param_init_net, 0.0, 1.0, np.finfo(np.float32).max
            )
            lr_lars_multiplier = net.Lars(
                [param, grad, wd, trust, lr_max],
                self.make_unique_blob_name(str(param) + "_lars"),
                offset=self.lars,
                lr_min=0.0,
            )
            current_scope = scope.CurrentDeviceScope()
            self._add_local_lr_multiplier(
                lr_lars_multiplier,
                is_gpu_blob=(
                    current_scope is not None
                    and core.IsGPUDeviceType(current_scope.device_type)
                ),
            )

        # build_lr() negates the base learning rate it is given. WeightedSum
        # and ScatterWeightedSum (the no-momentum paths below) add lr * grad
        # to the parameter, so they need that negative LR. The momentum update
        # operators subtract the step themselves and need a positive LR, so
        # we pre-negate the base learning rate when momentum is used.
        lr_sign = -1 if self.momentum else 1
        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=self.base_learning_rate * lr_sign,
            policy=self.policy,
            **(self.init_kwargs)
        )

        dev = scope.CurrentDeviceScope()
        if dev is None:
            dev = core.DeviceOption(caffe2_pb2.CPU)

        # Each GPU/CPU must have its own ONE blob, so the name includes the
        # device information.
        ONE = param_init_net.ConstantFill(
            [],
            "ONE_{}_{}{}".format(dev.device_type, dev.device_id, dev.node_name),
            shape=[1],
            value=1.0,
        )

        self._aux_params.shared.append(ONE)

        if self.momentum > 0:
            momentum_data = param_init_net.ConstantFill(
                param, str(param) + "_momentum", value=0.0
            )
            self._aux_params.local.append(momentum_data)

        if isinstance(grad, core.GradientSlice):
            grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
            if self.momentum > 0.0:
                net.SparseMomentumSGDUpdate(
                    [grad.values, momentum_data, lr, param, grad.indices],
                    [grad.values, momentum_data, param],
                    momentum=self.momentum,
                    nesterov=self.nesterov,
                )
            else:
                net.ScatterWeightedSum(
                    [param, ONE, grad.indices, grad.values, lr], param
                )
        else:
            if self.momentum > 0.0:
                net.MomentumSGDUpdate(
                    [grad, momentum_data, lr, param],
                    [grad, momentum_data, param],
                    momentum=self.momentum,
                    nesterov=self.nesterov,
                )
            else:
                coeff = lr

                net.WeightedSum([param, ONE, grad, coeff], param)

    def scale_learning_rate(self, scale):
        self.base_learning_rate *= scale
        return
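

# Illustrative sketch (hypothetical helpers, not part of the Caffe2 API): a
# scalar model of the two dense update paths in SgdOptimizer._run, based on
# our understanding of the WeightedSum and MomentumSGDUpdate operator
# semantics. Without momentum, WeightedSum computes
# param = 1 * param + lr * grad with a *negative* lr. With momentum,
# MomentumSGDUpdate receives a *positive* lr and subtracts the adjusted step
# itself, in either the classic or the Nesterov form.
def _sketch_weighted_sum_sgd(param, grad, signed_lr):
    # signed_lr < 0, as produced by build_lr() for the plain SGD path.
    return 1.0 * param + signed_lr * grad


def _sketch_momentum_sgd(param, grad, moment, lr, momentum, nesterov):
    # lr > 0 here. Returns (new_param, new_moment).
    if not nesterov:
        step = lr * grad + momentum * moment
        new_moment = step
    else:
        new_moment = momentum * moment + lr * grad
        step = (1.0 + momentum) * new_moment - momentum * moment
    return param - step, new_moment


# Example: starting from zero momentum, the first classic-momentum step
# equals one plain SGD step, which is why the sign flip in _run is the
# only difference between the two LR set-ups.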


class MultiPrecisionSgdOptimizer(SgdOptimizer):
    def __init__(
        self,
        base_learning_rate=0.1,
        momentum=0.0,
        policy="fixed",
        nesterov=True,
        sparse_dedup_aggregator=None,
        **kwargs
    ):
        super(MultiPrecisionSgdOptimizer, self).__init__(
            base_learning_rate=base_learning_rate,
            policy=policy,
            momentum=momentum,
            nesterov=nesterov,
            sparse_dedup_aggregator=sparse_dedup_aggregator,
            **kwargs
        )

    def _run(self, net, param_init_net, param_info):
        param = param_info.blob
        param_fp32 = (
            param_info.blob_copy[core.DataType.FLOAT]
            if param_info.blob_copy is not None
            else None
        )

        # If we have a straight fp32 parameter, run the base class
        if param_fp32 is None:
            return SgdOptimizer._run(self, net, param_init_net, param_info)

        grad = param_info.grad
        if self.base_learning_rate == 0:
            return
        assert (
            self.base_learning_rate > 0
        ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=-self.base_learning_rate,
            policy=self.policy,
            **(self.init_kwargs)
        )

        momentum_data = param_init_net.ConstantFill(
            param_fp32, str(param) + "_momentum", value=0.0
        )
        self._aux_params.local.append(momentum_data)

        assert not isinstance(
            grad, core.GradientSlice
        ), "MultiPrecisionSgd does not support sparse gradients"

        # Copy gradient to fp32
        grad_fp32 = net.HalfToFloat(grad, grad + "_fp32")

        # update (fused) in fp32
        net.MomentumSGDUpdate(
            [grad_fp32, momentum_data, lr, param_fp32],
            [grad_fp32, momentum_data, param_fp32],
            momentum=self.momentum,
            nesterov=self.nesterov,
        )

        # Copy updated param back to fp16
        net.FloatToHalf(param_fp32, param)
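

# Illustrative sketch (hypothetical helper, not part of the Caffe2 API): why
# MultiPrecisionSgdOptimizer keeps an fp32 master copy of an fp16
# parameter. It assumes plain SGD (momentum omitted for brevity) and uses
# NumPy float16/float32 casts in place of the HalfToFloat / FloatToHalf
# operators. A step much smaller than fp16's spacing near the parameter
# value is rounded away when applied directly in fp16, but it accumulates
# in the fp32 copy.
def _sketch_fp16_vs_master_weights(steps=100, lr=1e-4):
    import numpy as np

    grad_fp16 = np.float16(1.0)

    # Naive: update the fp16 parameter in place every step.
    param_fp16 = np.float16(1.0)
    for _ in range(steps):
        param_fp16 = np.float16(param_fp16 - np.float16(lr) * grad_fp16)

    # Multi-precision: update an fp32 master copy, then cast back to fp16.
    param_fp32 = np.float32(1.0)
    for _ in range(steps):
        grad_fp32 = np.float32(grad_fp16)  # HalfToFloat
        param_fp32 = param_fp32 - np.float32(lr) * grad_fp32
    param_from_master = np.float16(param_fp32)  # FloatToHalf

    return float(param_fp16), float(param_from_master)


# Example: with the defaults, the naive fp16 parameter stays at 1.0, while
# the master-weight version moves to roughly 0.99.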


class FP16SgdOptimizer(SgdOptimizer):
    def __init__(
        self,
        base_learning_rate=0.1,
        momentum=0.0,
        policy="fixed",
        nesterov=True,
        weight_decay=0.0001,
        sparse_dedup_aggregator=None,
        **kwargs
    ):
        super(FP16SgdOptimizer, self).__init__(
            base_learning_rate=base_learning_rate,
            policy=policy,
            momentum=momentum,
            nesterov=nesterov,
            sparse_dedup_aggregator=sparse_dedup_aggregator,
            **kwargs
        )
        self.weight_decay = weight_decay

    def _run(self, net, param_init_net, param_info, fp32_update=False):
        fp32_update_flag = 0
        param_name = str(param_info.blob)

        # This should only be triggered in FP16 training by SpatialBN, which
        # requires FP32 params in cuDNN.
        if param_name.find("spatbn") != -1:
            fp32_update = True

        if fp32_update:
            # Doing a 32-bit update. We have to assume param_info.blob is
            # FP32, as there is no way (that I currently know of) to query
            # a blob's type in Python.
            fp32_update_flag = 1
            param = param_info.blob
            param_fp32 = param_info.blob
        else:
            if param_info.blob_copy is None:
                # Doing a 32-bit update. We have to assume param_info.blob is
                # FP32, as there is no way (that I currently know of) to query
                # a blob's type in Python.
                fp32_update_flag = 1
                param = param_info.blob
                param_fp32 = param_info.blob
            else:
                if core.DataType.FLOAT in param_info.blob_copy:
                    param = param_info.blob
                    param_fp32 = param_info.blob_copy[core.DataType.FLOAT]
                elif core.DataType.FLOAT16 in param_info.blob_copy:
                    param = param_info.blob_copy[core.DataType.FLOAT16]
                    param_fp32 = param_info.blob
                else:
                    raise AssertionError(
                        "Unrecognized parameter format to be updated "
                        "by FP16 Optimizer. Parameter: {}".format(param_info.name)
                    )

        grad = param_info.grad

        if self.base_learning_rate == 0:
            return
        assert (
            self.base_learning_rate > 0
        ), "Expect positive base learning rate, got {}".format(self.base_learning_rate)

        lr, _ = self.build_lr(
            net,
            param_init_net,
            base_learning_rate=-self.base_learning_rate,
            policy=self.policy,
            **(self.init_kwargs)
        )

        momentum_data_fp32 = param_init_net.ConstantFill(
            param_fp32, str(param) + "_momentum_fp32", value=0.0
        )

        momentum_data = param_init_net.FloatToHalf(
            momentum_data_fp32, str(param) + "_momentum"
        )

        self._aux_params.local.append(momentum_data)

        assert not isinstance(
            grad, core.GradientSlice
        ), "FP16Sgd does not support sparse gradients"

        if fp32_update_flag == 0:
            net.FP16MomentumSGDUpdate(
                [grad, momentum_data, lr, param],
                [grad, momentum_data, param],
                momentum=self.momentum,
                nesterov=self.nesterov,
                weight_decay=self.weight_decay,
            )
        else:
            # Flag set to 1, therefore doing an FP32 update.
            net.FP32MomentumSGDUpdate(
                [grad, momentum_data_fp32, lr, param],
                [grad, momentum_data_fp32, param],
                momentum=self.momentum,
                nesterov=self.nesterov,
                weight_decay=self.weight_decay,
            )
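

# Illustrative sketch (hypothetical helper, not part of the Caffe2 API): the
# blob-selection rules FP16SgdOptimizer._run applies, written over plain
# strings so they can be checked in isolation. `blob` stands for
# param_info.blob and `blob_copy` for param_info.blob_copy, modeled here as a
# dict keyed by "float" / "float16" instead of core.DataType values.
# Returns (param, param_fp32, fp32_update).
def _sketch_fp16_param_roles(blob, blob_copy=None, fp32_update=False):
    if "spatbn" in blob:
        # SpatialBN params stay FP32 (cuDNN requirement).
        fp32_update = True
    if fp32_update or blob_copy is None:
        # The blob itself is assumed to already be FP32.
        return blob, blob, True
    if "float" in blob_copy:
        return blob, blob_copy["float"], False
    if "float16" in blob_copy:
        return blob_copy["float16"], blob, False
    raise AssertionError(
        "Unrecognized parameter format to be updated by FP16 Optimizer. "
        "Parameter: {}".format(blob)
    )


# Example: an fp16 blob with an fp32 copy updates the copy (FP16 path);
# a SpatialBN blob always takes the FP32 path.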


class WeightDecayBuilder(Optimizer):
    def __init__(self, weight_decay):
        self.weight_decay = weight_decay

    def _run(self, net, param_init_net, param_info):
        dev = scope.CurrentDeviceScope()
        if dev is None:
            dev = core.DeviceOption(caffe2_pb2.CPU)

        ONE = param_init_net.ConstantFill(
            [], "ONE_{}_{}".format(dev.device_type, dev.device_id), shape=[1], value=1.0
        )
        WD = param_init_net.ConstantFill(
            [],
            "wd_{}_{}".format(dev.device_type, dev.device_id),
            shape=[1],
            value=self.weight_decay,
        )

        if isinstance(param_info.grad, core.GradientSlice):
            raise ValueError("Weight decay does not yet support sparse gradients")
        else:
            net.WeightedSum(
                [param_info.grad, ONE, param_info.blob, WD], param_info.grad
            )
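

# Illustrative sketch (hypothetical helpers, not part of the Caffe2 API): a
# NumPy model of the state that AdagradOptimizer._run allocates, based on
# our understanding of the dense and row-wise Adagrad operators. Dense
# Adagrad keeps one squared-sum entry per parameter element ("_squared_sum",
# same shape as the param). Row-wise Adagrad keeps one averaged entry per
# row ("_avg_squared_sum", shape [num_rows]). `lr` is the signed (negative)
# value produced by build_lr(); decay is assumed to be 1 and weight decay 0.
def _sketch_adagrad_dense(param, grad, squared_sum, lr, epsilon=1e-4):
    import numpy as np

    squared_sum = squared_sum + grad * grad
    param = param + lr * grad / (np.sqrt(squared_sum) + epsilon)
    return param, squared_sum


def _sketch_adagrad_rowwise(param, grad, avg_squared_sum, lr, epsilon=1e-4):
    import numpy as np

    # One accumulator per row: the mean of that row's squared gradient.
    avg_squared_sum = avg_squared_sum + np.mean(grad * grad, axis=1)
    step = lr / (np.sqrt(avg_squared_sum) + epsilon)
    # Broadcast each row's step size across that row's gradient.
    param = param + step[:, None] * grad
    return param, avg_squared_sum


# Example: for a [2, 3] embedding table, the dense state has 6 entries but
# the row-wise state has only 2, which is the memory saving rowWise buys.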
class AdagradOptimizer(Optimizer):
def __init__(
self,
alpha=0.01,
epsilon=1e-4,
decay=1,
weight_decay=0.0,
policy="fixed",
sparse_dedup_aggregator=None,
rowWise=False,
engine="",
lars=None,
output_effective_lr=False,
output_effective_lr_and_update=False,
pruning_options=None,
swa_options=None,
ema_options=None,
weight_scale=None,
counter_halflife=-1,
use_dedicated_lr_iteration_counter=False,
**kwargs
):
super(AdagradOptimizer, self).__init__()
self.alpha = alpha
self.epsilon = epsilon
self.decay = decay
self.weight_decay = float(weight_decay)
self.policy = policy
self.sparse_dedup_aggregator = sparse_dedup_aggregator
self.rowWise = rowWise
self.engine = engine
self.lars = lars
self.output_effective_lr = output_effective_lr
self.output_effective_lr_and_update = output_effective_lr_and_update
self.counter_halflife = counter_halflife
self.init_kwargs = kwargs
self.weight_scale = weight_scale
self.use_dedicated_lr_iteration_counter = use_dedicated_lr_iteration_counter
self._process_pruning_options(pruning_options)
self._process_swa_options(swa_options)
self._process_ema_options(ema_options)
def _process_swa_options(self, swa_options):
self.swa_enabled = True if swa_options else False
if self.swa_enabled:
self.swa_avg_start_it = swa_options.get("swa_avg_start_it", None)
self.swa_avg_end_it = swa_options.get("swa_avg_end_it", None)
self.swa_feedback_start_it = swa_options.get("swa_feedback_start_it", None)
self.swa_feedback_step = swa_options.get("swa_feedback_step", None)
self.swa_feedback_end_it = swa_options.get("swa_feedback_end_it", None)
def _process_ema_options(self, ema_options):
self.ema_enabled = True if ema_options else False
if self.ema_enabled:
self.ema_start = ema_options.get("ema_start", None)
self.ema_end = ema_options.get("ema_end", None)
self.ema_step = ema_options.get("ema_step", None)
self.ema_alpha = ema_options.get("ema_alpha", None)
def _process_pruning_options(self, pruning_options):
self.use_mask = False
if pruning_options is None:
pruning_options = {}
else:
assert isinstance(pruning_options, dict), (
"pruning_options can only "
"be provided as a dictionary, currently: {}".format(pruning_options)
)
self.mask_tensor = pruning_options.get("mask_tensor", None)
self.mask_db_path = pruning_options.get("mask_db_path", None)
self.mask_db_type = pruning_options.get("mask_db_type", None)
self.mask_blob_name = pruning_options.get("mask_blob_name", None)
self.prune_delays = pruning_options.get("prune_delays", [])
self.prune_ratios = pruning_options.get("prune_ratios", [])
self.prune_block_size = pruning_options.get("prune_block_size", 1)
if self.mask_tensor is not None:
assert (
type(self.mask_tensor) is np.ndarray
), "mask_tensor must be a numpy array!"
assert self.mask_db_path is None, (
"mask can be provided through either a numpy array "
"or a db path, not both"
)
assert self.mask_db_type is None, (
"mask can be provided through either a numpy array "
"or a db path, not both"
)
assert self.mask_blob_name is None, (
"mask can be provided through either a numpy array "
"or a db path, not both"
)
self.use_mask = True
if self.mask_db_path is not None or self.mask_db_type is not None:
assert self.mask_db_path is not None, (
"when mask is provided through db, "
"db path, db type, and blob name are all needed"
)
assert self.mask_db_type is not None, (
"when mask is provided through db, "
"db path, db type, and blob name are all needed"
)
assert self.mask_tensor is None, (
"mask can be provided through either a numpy array "
"or a db path, not both"
)
self.use_mask = True
if self.prune_delays:
assert self.prune_ratios is not None and len(self.prune_delays) == len(
self.prune_ratios
), "Prune Delays and prune ratios should be of the same length"
assert (
self.mask_tensor is None
), "Mask Tensor should be None with prune ratios"
assert (
self.mask_db_path is None
), "Mask DB Path should be None with prune ratios"
self.use_mask = True
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
if self.alpha <= 0:
return
self._clear_local_lr_multiplier()
if self.lars is not None and not isinstance(grad, core.GradientSlice):
assert (
self.weight_decay == 0
), "weight decay is not implemented for LARS yet"
assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
self.lars
)
wd, trust, lr_max = self.create_lars_inputs(
param_init_net, 0.0, 1.0, np.finfo(np.float32).max
)
lr_lars_multiplier = net.Lars(
[param, grad, wd, trust, lr_max],
self.make_unique_blob_name(str(param) + "_lars"),
offset=self.lars,
lr_min=0.0,
)
current_scope = scope.CurrentDeviceScope()
self._add_local_lr_multiplier(
lr_lars_multiplier,
is_gpu_blob=(
current_scope is not None
and core.IsGPUDeviceType(current_scope.device_type)
),
)
lr, lr_iteration = self.build_lr(
net,
param_init_net,
base_learning_rate=self.alpha,
policy=self.policy,
**(self.init_kwargs)
)
iteration = (
self.build_non_lr_iter(net, param_init_net, iter_val=0)
if self._use_dedicated_lr_iteration_counter
else lr_iteration
)
if self.counter_halflife > 0:
self._aux_params.shared.append(iteration)
if self.rowWise:
logger.debug(
"Using engine {} for rowWise Adagrad to train param {}".format(
self.engine, param
)
)
shapes, types = workspace.InferShapesAndTypes([param_init_net])
if str(param) not in shapes:
# Type/shape inference is not available for this param; fall back
# to Shape/Slice logic
shape = param_init_net.Shape(param, str(param) + "_shape")
num_rows = param_init_net.Slice(
[shape], str(shape) + "_numrows", starts=[0], ends=[1]
)
param_squared_sum = param_init_net.ConstantFill(
num_rows,
str(param) + "_avg_squared_sum",
input_as_shape=1,
value=0.0,
)
else:
param_squared_sum = param_init_net.ConstantFill(
[],
str(param) + "_avg_squared_sum",
shape=[shapes[str(param)][0]],
value=0.0,
)
else:
logger.debug(
"Using engine {} for regular Adagrad to train param {}".format(
self.engine, param
)
)
if self.engine in FP16_ENGINES:
assert (
self.weight_decay == 0
), "weight decay is not tested for engine: {}".format(self.engine)
shapes, types = workspace.InferShapesAndTypes([param_init_net])
assert str(param) in shapes, shapes
shape = shapes[str(param)]
param_squared_sum = param_init_net.Float16ConstantFill(
[], str(param) + "_squared_sum", value=0.0, shape=shape
)
else:
param_squared_sum = param_init_net.ConstantFill(
[param], str(param) + "_squared_sum", value=0.0
)
if self.use_mask is True:
assert (
self.weight_decay == 0
), "weight decay is not implemented for use_mask yet"
if self.mask_tensor is not None:
if not isinstance(grad, core.GradientSlice):
mask_blob = param_init_net.GivenTensorFill(
[],
[str(param) + "_mask"],
values=self.mask_tensor,
shape=self.mask_tensor.shape,
)
else:
self.mask_tensor = self.mask_tensor.astype(np.uint8)
mask_blob = param_init_net.GivenTensorBoolFill(
[],
[str(param) + "_mask"],
values=self.mask_tensor,
shape=self.mask_tensor.shape,
)
mask_blob = param_init_net.Cast(mask_blob, to=core.DataType.UINT8)
mask_changed_blob = param_init_net.ConstantFill(
[],
[str(param) + "_mask_changed_blob"],
value=False,
dtype=core.DataType.BOOL,
shape=[1],
)
elif (
self.mask_db_path is not None or self.mask_db_type is not None
): # mask is provided through a db file
# if mask_blob_name is not given use the param name to derive mask name
self.mask_blob_name = self.mask_blob_name or str(param) + "_mask"
mask_blob = param_init_net.Load(
[],
self.mask_blob_name,
db=self.mask_db_path,
db_type=self.mask_db_type,
absolute_path=True,
)
if isinstance(grad, core.GradientSlice):
mask_changed_blob = param_init_net.ConstantFill(
[],
[str(param) + "_mask_changed_blob"],
value=False,
dtype=core.DataType.BOOL,
shape=[1],
)
elif self.prune_delays:
last_mask_updated_iter = param_init_net.ConstantFill(
[],
[str(param) + "_last_mask_updated_iter"],
value=-1,
dtype=core.DataType.INT64,
shape=[1],
)
if isinstance(grad, core.GradientSlice):
raise AssertionError(
"Prune Delays and Prune Ratios are currently not supported "
"for sparse operators"
)
else:
mask_blob = param_init_net.GivenTensorFill(
[],
[str(param) + "_empty_mask"],
values=[],
dtype=core.DataType.FLOAT,
shape=[0],
)
else:
raise NotImplementedError(
"If mask is used, a numpy array, a db file, or "
"prune delays must be provided"
)
self._aux_params.local.append(param_squared_sum)
if self.counter_halflife > 0:
shapes, types = workspace.InferShapesAndTypes([param_init_net])
if str(param) not in shapes:
shape = param_init_net.Shape(param, str(param) + "_shape")
num_rows = param_init_net.Slice(
[shape], str(shape) + "_numrows", starts=[0], ends=[1]
)
update_counter = param_init_net.ConstantFill(
num_rows,
str(param) + "_update_counter",
input_as_shape=1,
value=0.0,
dtype=core.DataType.DOUBLE,
)
prev_update_iter = param_init_net.ConstantFill(
num_rows,
str(param) + "_prev_update_iter",
input_as_shape=1,
value=0,
dtype=core.DataType.INT64,
)
else:
update_counter = param_init_net.ConstantFill(
[],
str(param) + "_update_counter",
shape=[shapes[str(param)][0]],
value=0.0,
dtype=core.DataType.DOUBLE,
)
prev_update_iter = param_init_net.ConstantFill(
[],
str(param) + "_prev_update_iter",
shape=[shapes[str(param)][0]],
value=0,
dtype=core.DataType.INT64,
)
self._aux_params.local.append(update_counter)
self._aux_params.local.append(prev_update_iter)
if self.rowWise:
assert isinstance(grad, core.GradientSlice), (
"If SparseAdagrad is used with rowWise=True, the gradient must be "
"a GradientSlice. Please ensure that rowWise is not enabled "
"for the dense Adagrad optimizer, as it is not supported."
)
shapes, _ = workspace.InferShapesAndTypes([param_init_net])
param_shape = shapes[str(param)]
weight_decay = 0.0
if isinstance(grad, core.GradientSlice):
if len(param_shape) == 1:
weight_decay = 0.0
logger.warning(
"SKIPPING weight decay on 1d sparse param: {}.shape is {}".format(
str(param), param_shape
)
)
else:
weight_decay = self.weight_decay
else:
# Skip weight decay for 1d parameters
if len(param_shape) == 1:
weight_decay = 0.0
logger.warning(
"SKIPPING weight decay on 1d dense param: {}.shape is {}".format(
str(param), param_shape
)
)
else:
weight_decay = self.weight_decay
logger.debug(
"weight_decay for {} (shape:{}): {}".format(
str(param), param_shape, weight_decay
)
)
if isinstance(grad, core.GradientSlice):
assert (
self.decay == 1.0
), "Decay is not implemented for SparseAdagrad and must be set to 1"
grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
input_args = [param, param_squared_sum, grad.indices, grad.values, lr]
output_args = [param, param_squared_sum]
if self.rowWise:
if self.use_mask is True:
op = "MaskedRowWiseSparseAdagrad"
assert (
weight_decay == 0
), "weight decay is not implemented for {} yet".format(op)
input_args += [mask_blob, mask_changed_blob]
else:
if self.counter_halflife > 0:
input_args += [update_counter]
op = "RowWiseSparseAdagrad"
else:
if self.use_mask is True:
op = "MaskedSparseAdagrad"
assert (
weight_decay == 0
), "weight decay is not implemented for {} yet".format(op)
input_args += [mask_blob, mask_changed_blob]
else:
op = "SparseAdagrad"
logger.debug("using {} for {}".format(op, str(param)))
if self.prune_delays:
input_args += [iteration, last_mask_updated_iter]
output_args += [mask_blob, last_mask_updated_iter]
if weight_decay > 0 and self.counter_halflife == -1:
net.__getattr__(op)(
input_args,
output_args,
epsilon=self.epsilon,
weight_decay=weight_decay,
engine=self.engine,
)
elif weight_decay > 0 and self.counter_halflife != -1:
net.__getattr__(op)(
input_args,
output_args,
epsilon=self.epsilon,
weight_decay=weight_decay,
engine=self.engine,
counter_halflife=self.counter_halflife,
)
else:
net.__getattr__(op)(
input_args, output_args, epsilon=self.epsilon, engine=self.engine
)
if self.counter_halflife > 0:
net.RowWiseCounter(
[prev_update_iter, update_counter, grad.indices, iteration],
[prev_update_iter, update_counter],
counter_halflife=self.counter_halflife,
)
else:
input_args = [param, param_squared_sum, grad, lr]
output_args = [param, param_squared_sum]
if self.output_effective_lr_and_update:
assert (
self.use_mask is False
), "MaskedAdagrad doesn't support outputting effective_lr_and_update"
output_args.append(str(param) + "_effective_lr")
output_args.append(str(param) + "_update")
elif self.output_effective_lr:
assert (
self.use_mask is False
), "MaskedAdagrad doesn't support outputting effective_lr"
output_args.append(str(param) + "_effective_lr")
if self.use_mask is True:
input_args += [mask_blob]
if self.prune_delays:
input_args += [iteration, last_mask_updated_iter]
output_args += [mask_blob, last_mask_updated_iter]
if self.use_mask:
assert (
weight_decay == 0
), "weight decay is not implemented for use_mask yet"
net.MaskedAdagrad(
input_args,
output_args,
epsilon=self.epsilon,
decay=float(self.decay),
block_size=self.prune_block_size,
delays=self.prune_delays,
prune_ratios=self.prune_ratios,
engine=self.engine,
)
else:
if weight_decay > 0:
net.Adagrad(
input_args,
output_args,
epsilon=self.epsilon,
decay=float(self.decay),
weight_decay=weight_decay,
engine=self.engine,
)
else:
net.Adagrad(
input_args,
output_args,
epsilon=self.epsilon,
decay=float(self.decay),
engine=self.engine,
)
if self.swa_enabled:
param_swa = str(param) + "_swa"
if not param_init_net.BlobIsDefined(param_swa):
param_init_net.ConstantFill([param], param_swa, value=0.0)
self._aux_params.local.append(param_swa)
net.SWA(
[param, param_swa, iteration],
[param, param_swa],
avg_start=self.swa_avg_start_it,
avg_end=self.swa_avg_end_it,
feedback_start=self.swa_feedback_start_it,
feedback_step=self.swa_feedback_step,
feedback_end=self.swa_feedback_end_it,
)
if self.ema_enabled:
param_ema = str(param) + "_ema"
if not param_init_net.BlobIsDefined(param_ema):
param_init_net.ConstantFill([param], param_ema, value=0.0)
self._aux_params.local.append(param_ema)
net.EMA(
[param, param_ema, iteration],
[param, param_ema],
ema_start=self.ema_start,
ema_end=self.ema_end,
ema_step=self.ema_step,
ema_alpha=self.ema_alpha,
)
if self.weight_scale:
net.WeightScale(
[param, iteration],
[param],
stepsize=self.weight_scale.stepsize,
upper_bound_iter=self.weight_scale.upper_bound_iter,
scale=float(self.weight_scale.scale),
)
if self.weight_scale.to_aux:
net.WeightScale(
[param_squared_sum, iteration],
[param_squared_sum],
stepsize=self.weight_scale.stepsize,
upper_bound_iter=self.weight_scale.upper_bound_iter,
scale=float(self.weight_scale.scale),
)
def scale_learning_rate(self, scale):
self.alpha *= scale
return
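The dense and row-wise Adagrad paths above keep differently shaped accumulators: dense Adagrad fills a full-shape `_squared_sum`, while row-wise Adagrad keeps one scalar per row in `_avg_squared_sum`. The NumPy sketch below illustrates both textbook updates, assuming a positive learning rate that is subtracted, `decay` multiplying the old accumulator, and the row-wise accumulator holding the mean of the squared gradient entries of each row. It is not the Caffe2 kernel, and the function names are hypothetical.

```python
import numpy as np


def adagrad_dense(param, sq_sum, grad, lr, epsilon=1e-8, decay=1.0):
    """Dense Adagrad sketch: one accumulator entry per parameter entry."""
    sq_sum = decay * sq_sum + grad * grad
    param = param - lr * grad / (np.sqrt(sq_sum) + epsilon)
    return param, sq_sum


def adagrad_rowwise(param, row_sq_avg, indices, grad_rows, lr, epsilon=1e-8):
    """Row-wise sparse Adagrad sketch: one scalar accumulator per row.

    Assumes `indices` are already deduplicated, as the optimizer does
    via self.dedup before calling the sparse operator.
    """
    param = param.copy()
    row_sq_avg = row_sq_avg.copy()
    for idx, g in zip(indices, grad_rows):
        row_sq_avg[idx] += np.mean(g * g)  # average squared gradient of the row
        param[idx] -= lr * g / (np.sqrt(row_sq_avg[idx]) + epsilon)
    return param, row_sq_avg
```

The row-wise variant trades per-coordinate adaptivity for memory: an embedding table with `N` rows needs only `N` extra floats instead of a second full-size table.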
class WngradOptimizer(Optimizer):
def __init__(
self,
alpha=1.0,
epsilon=1e-9,
policy="fixed",
sparse_dedup_aggregator=None,
engine="",
moment_init=100.0,
lars=None,
output_effective_lr=False,
output_effective_lr_and_update=False,
**kwargs
):
super(WngradOptimizer, self).__init__()
self.alpha = alpha
self.epsilon = epsilon
self.policy = policy
self.sparse_dedup_aggregator = sparse_dedup_aggregator
self.engine = engine
self.moment_init = moment_init
self.lars = lars
self.output_effective_lr = output_effective_lr
self.output_effective_lr_and_update = output_effective_lr_and_update
self.init_kwargs = kwargs
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
if self.alpha <= 0:
return
self._clear_local_lr_multiplier()
if self.lars is not None and not isinstance(grad, core.GradientSlice):
assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
self.lars
)
wd, trust, lr_max = self.create_lars_inputs(
param_init_net, 0.0, 1.0, np.finfo(np.float32).max
)
lr_lars_multiplier = net.Lars(
[param, grad, wd, trust, lr_max],
self.make_unique_blob_name(str(param) + "_lars"),
offset=self.lars,
lr_min=0.0,
)
current_scope = scope.CurrentDeviceScope()
self._add_local_lr_multiplier(
lr_lars_multiplier,
is_gpu_blob=(
current_scope is not None
and core.IsGPUDeviceType(current_scope.device_type)
),
)
lr, _ = self.build_lr(
net,
param_init_net,
base_learning_rate=self.alpha,
policy=self.policy,
**(self.init_kwargs)
)
moment = param_init_net.ConstantFill(
[], str(param) + "_moment", shape=[1], value=self.moment_init
)
self._aux_params.local.append(moment)
if isinstance(grad, core.GradientSlice):
grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
net.SparseWngrad(
[param, moment, grad.indices, grad.values, lr],
[param, moment],
epsilon=self.epsilon,
engine=self.engine,
)
else:
output_args = [param, moment]
if self.output_effective_lr_and_update:
output_args.append(str(param) + "_effective_lr")
output_args.append(str(param) + "_update")
elif self.output_effective_lr:
output_args.append(str(param) + "_effective_lr")
net.Wngrad(
[param, moment, grad, lr],
output_args,
epsilon=self.epsilon,
engine=self.engine,
)
def scale_learning_rate(self, scale):
self.alpha *= scale
return
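`WngradOptimizer` keeps a single scalar `_moment` per parameter, initialized to `moment_init`. A minimal NumPy sketch of the WNGrad rule as I understand it (Wu, Ward and Bottou): the step is scaled by `1 / b`, and `b` then grows by the squared gradient norm divided by `b`. The exact ordering and epsilon placement inside the Caffe2 `Wngrad` operator are assumptions here, and `wngrad_step` is a hypothetical name.

```python
import numpy as np


def wngrad_step(param, b, grad, lr, epsilon=1e-9):
    """WNGrad sketch with a scalar accumulator b (positive lr, subtracted)."""
    new_param = param - lr * grad / (b + epsilon)
    # b grows by ||g||^2 / b, so the effective step shrinks over time.
    new_b = b + float(np.sum(grad * grad)) / (b + epsilon)
    return new_param, new_b
```

Because `b` starts large (`moment_init=100.0` by default), early steps are small and the accumulator grows slowly.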
class StormOptimizer(Optimizer):
def __init__(
self,
lr=0.1,
momentum=10.0,
beta=0.1,
grad_sq_init=0.01,
policy="fixed",
sparse_dedup_aggregator=None,
lars=None,
**kwargs
):
"""Constructor function to add STORM Optimizer
Args:
lr: learning rate scaling (called k in the original paper)
momentum: momentum scaling (called c in the original paper)
beta: initial value of denominator in adaptive learning rate (
called w in the original paper)
grad_sq_init: initial value of gradient squared accumulator.
policy: specifies how learning rate should be applied, options are
'fixed', 'step', 'exp', etc.
sparse_dedup_aggregator: specifies deduplication strategy for
gradient slices. Only used with sparse gradients. Options
include 'mean' and 'sum'.
lars: lars offset.
"""
super(StormOptimizer, self).__init__()
self.lr = lr
self.momentum = momentum
self.beta = beta
self.grad_sq_init = grad_sq_init
self.policy = policy
self.sparse_dedup_aggregator = sparse_dedup_aggregator
self.lars = lars
self.init_kwargs = kwargs
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
if self.lr <= 0:
return
self._clear_local_lr_multiplier()
if self.lars is not None and not isinstance(grad, core.GradientSlice):
assert self.lars >= 0, "Lars offset must be nonnegative, got {}".format(
self.lars
)
wd, trust, lr_max = self.create_lars_inputs(
param_init_net, 0.0, 1.0, np.finfo(np.float32).max
)
lr_lars_multiplier = net.Lars(
[param, grad, wd, trust, lr_max],
self.make_unique_blob_name(str(param) + "_lars"),
offset=self.lars,
lr_min=0.0,
)
current_scope = scope.CurrentDeviceScope()
self._add_local_lr_multiplier(
lr_lars_multiplier,
is_gpu_blob=(
current_scope is not None
and core.IsGPUDeviceType(current_scope.device_type)
),
)
lr, _ = self.build_lr(
net,
param_init_net,
base_learning_rate=self.lr,
policy=self.policy,
**(self.init_kwargs)
)
moment = param_init_net.ConstantFill(param, str(param) + "_moment", value=0.0)
self._aux_params.local.append(moment)
grad_sq_sum = param_init_net.ConstantFill(
[], str(param) + "_grad_sq_sum", shape=[1], value=self.grad_sq_init
)
self._aux_params.local.append(grad_sq_sum)
if isinstance(grad, core.GradientSlice):
grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
net.SparseStorm(
[param, moment, grad_sq_sum, grad.values, grad.indices, lr],
[param, moment, grad_sq_sum],
momentum=self.momentum,
beta=self.beta,
)
else:
net.Storm(
[param, moment, grad_sq_sum, grad, lr],
[param, moment, grad_sq_sum],
momentum=self.momentum,
beta=self.beta,
)
def scale_learning_rate(self, scale):
self.lr *= scale
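The STORM constructor exposes `lr`, `momentum` and `beta`, which play the roles of `k`, `c` and `w` below. This NumPy sketch follows the recursion from Cutkosky and Orabona's STORM paper as I understand it: the step size is `k / (w + sum ||g||^2) ** (1/3)` and the momentum coefficient is `c * eta ** 2`. The paper's estimator evaluates the gradient at both the new and the previous iterate on the same sample; the single-gradient `Storm` operator above cannot do that, so it presumably uses an approximation this sketch does not reproduce. All function names are hypothetical.

```python
import numpy as np


def storm_step_size(k, w, grad_sq_sum):
    """STORM step size: eta_t = k / (w + sum_i ||g_i||^2) ** (1/3)."""
    return k / (w + grad_sq_sum) ** (1.0 / 3.0)


def storm_step(x, d, grad_sq_sum, grad_fn, sample, k=0.1, c=10.0, w=0.1):
    """One STORM iteration following the paper's recursion.

    x: current iterate; d: current gradient estimator (momentum);
    grad_sq_sum: running sum of squared gradient norms, including the
    gradient that produced d; grad_fn(x, sample): stochastic gradient.
    """
    eta = storm_step_size(k, w, grad_sq_sum)
    x_next = x - eta * d
    a = c * eta ** 2  # momentum coefficient a_{t+1} = c * eta_t^2
    g_new = grad_fn(x_next, sample)
    g_old = grad_fn(x, sample)  # same sample at the previous iterate
    d_next = g_new + (1.0 - a) * (d - g_old)
    grad_sq_sum = grad_sq_sum + float(np.sum(g_new ** 2))
    return x_next, d_next, grad_sq_sum
```

With noise-free gradients the correction term `d - g_old` vanishes, so `d` stays equal to the true gradient and STORM reduces to gradient descent with an adaptive step.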
class AdadeltaOptimizer(Optimizer):
def __init__(
self,
alpha=0.01,
epsilon=1e-4,
decay=0.95,
policy="fixed",
sparse_dedup_aggregator=None,
engine="",
**kwargs
):
"""Constructor function to add Adadelta Optimizer
Args:
alpha: learning rate
epsilon: attribute of Adadelta to avoid numerical issues
decay: attribute of Adadelta to decay the squared gradient sum
policy: specifies how learning rate should be applied, options are
"fixed", "step", "exp", etc.
sparse_dedup_aggregator: specifies deduplication strategy for
gradient slices. Only used with sparse gradients. Options
include "mean" and "sum".
engine: the engine used, options include "", "CUDNN", etc.
"""
super(AdadeltaOptimizer, self).__init__()
self.alpha = alpha
self.epsilon = epsilon
self.decay = decay
self.policy = policy
self.sparse_dedup_aggregator = sparse_dedup_aggregator
self.engine = engine
self.init_kwargs = kwargs
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
if self.alpha <= 0:
return
lr, _ = self.build_lr(
net,
param_init_net,
base_learning_rate=self.alpha,
policy=self.policy,
**(self.init_kwargs)
)
moment = param_init_net.ConstantFill(
[param], str(param) + "_squared_moment", value=0.0
)
moment_update = param_init_net.ConstantFill(
[param], str(param) + "_squared_moment_update", value=0.0
)
self._aux_params.local.append(moment)
self._aux_params.local.append(moment_update)
if isinstance(grad, core.GradientSlice):
grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
net.SparseAdadelta(
[param, moment, moment_update, grad.indices, grad.values, lr],
[param, moment, moment_update],
epsilon=self.epsilon,
decay=self.decay,
engine=self.engine,
)
else:
net.Adadelta(
[param, moment, moment_update, grad, lr],
[param, moment, moment_update],
epsilon=self.epsilon,
decay=self.decay,
engine=self.engine,
)
def scale_learning_rate(self, scale):
self.alpha *= scale
return
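`AdadeltaOptimizer` allocates two full-shape accumulators: `_squared_moment` (running average of squared gradients) and `_squared_moment_update` (running average of squared updates). A NumPy sketch of the textbook Adadelta step (Zeiler, 2012) using those two buffers, assuming the learning rate simply scales the Adadelta update and is subtracted; this is an illustration, not the Caffe2 kernel, and `adadelta_step` is a hypothetical name.

```python
import numpy as np


def adadelta_step(param, sq_grad_avg, sq_update_avg, grad, lr,
                  decay=0.95, epsilon=1e-4):
    """Adadelta sketch: ratio of RMS(update) to RMS(grad) sets the step."""
    sq_grad_avg = decay * sq_grad_avg + (1.0 - decay) * grad ** 2
    update = np.sqrt(sq_update_avg + epsilon) / np.sqrt(sq_grad_avg + epsilon) * grad
    param = param - lr * update
    sq_update_avg = decay * sq_update_avg + (1.0 - decay) * update ** 2
    return param, sq_grad_avg, sq_update_avg
```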
class FtrlOptimizer(Optimizer):
def __init__(
self,
alpha=0.01,
beta=1e-4,
lambda1=0,
lambda2=0,
sparse_dedup_aggregator=None,
engine="",
):
super(FtrlOptimizer, self).__init__()
self.alpha = alpha
self.beta = beta
self.lambda1 = lambda1
self.lambda2 = lambda2
self.sparse_dedup_aggregator = sparse_dedup_aggregator
self.engine = engine
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
if self.alpha <= 0:
return
nz = param_init_net.ConstantFill(
[param], str(param) + "_ftrl_nz", extra_shape=[2], value=0.0
)
self._aux_params.local.append(nz)
if isinstance(grad, core.GradientSlice):
grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
net.SparseFtrl(
[param, nz, grad.indices, grad.values],
[param, nz],
engine=self.engine,
alpha=self.alpha,
beta=self.beta,
lambda1=self.lambda1,
lambda2=self.lambda2,
)
else:
net.Ftrl(
[param, nz, grad],
[param, nz],
engine=self.engine,
alpha=self.alpha,
beta=self.beta,
lambda1=self.lambda1,
lambda2=self.lambda2,
)
def scale_learning_rate(self, scale):
self.alpha *= scale
return
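`FtrlOptimizer` stores two values per coordinate in `_ftrl_nz` (via `extra_shape=[2]`). A NumPy sketch of the per-coordinate FTRL-Proximal update from McMahan et al., assuming the last axis holds `(n, z)` in that order; the layout and the function name are assumptions, not taken from the Caffe2 kernel.

```python
import numpy as np


def ftrl_step(w, nz, g, alpha=0.01, beta=1e-4, lambda1=0.0, lambda2=0.0):
    """FTRL-Proximal sketch; nz[..., 0] = n (sum of g^2), nz[..., 1] = z."""
    n, z = nz[..., 0], nz[..., 1]
    new_n = n + g * g
    sigma = (np.sqrt(new_n) - np.sqrt(n)) / alpha
    new_z = z + g - sigma * w
    denom = (beta + np.sqrt(new_n)) / alpha + lambda2
    # L1 (lambda1) produces exact zeros for coordinates with small |z|.
    new_w = np.where(np.abs(new_z) > lambda1,
                     (np.sign(new_z) * lambda1 - new_z) / denom,
                     0.0)
    return new_w, np.stack([new_n, new_z], axis=-1)
```

The closed-form weight means `param` is recomputed from `(n, z)` on every step rather than being nudged by a gradient step, which is why the sparse and dense operators only need `param`, `nz` and the gradient.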
class GFtrlOptimizer(Optimizer):
"""Group Lasso FTRL Optimizer."""
def __init__(
self,
alpha=0.01,
beta=1e-4,
lambda1=0,
lambda2=0,
sparse_dedup_aggregator=None,
engine="",
):
super(GFtrlOptimizer, self).__init__()
self.alpha = alpha
self.beta = beta
self.lambda1 = lambda1
self.lambda2 = lambda2
self.sparse_dedup_aggregator = sparse_dedup_aggregator
self.engine = engine
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
if self.alpha <= 0:
return
nz = param_init_net.ConstantFill(
[param], str(param) + "_gftrl_nz", extra_shape=[2], value=0.0
)
self._aux_params.local.append(nz)
net.GFtrl(
[param, nz, grad],
[param, nz],
engine=self.engine,
alpha=self.alpha,
beta=self.beta,
lambda1=self.lambda1,
lambda2=self.lambda2,
)
def scale_learning_rate(self, scale):
self.alpha *= scale
return
class AdamOptimizer(Optimizer):
def __init__(
self,
alpha=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
policy="fixed",
use_lr_adaption=False,
lr_alpha=0.01,
normalized_lr_adaption=True,
sparse_dedup_aggregator=None,
rowWise=False,
engine="",
enableRAdam=False,
use_smart_decay=False, # See https://fburl.com/2jdiwrhy for context.
**kwargs
):
super(AdamOptimizer, self).__init__()
self.alpha = alpha
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.policy = policy
self.use_lr_adaption = use_lr_adaption
self.lr_alpha = lr_alpha
self.normalized_lr_adaption = normalized_lr_adaption
self.sparse_dedup_aggregator = sparse_dedup_aggregator
self.rowWise = rowWise
self.engine = engine
self.enableRAdam = enableRAdam
if use_smart_decay:
if rowWise:
raise NotImplementedError(('Smart decay is not implemented for rowWise Adam. '
'Set rowWise or use_smart_decay to False.'))
if enableRAdam:
raise NotImplementedError(('Smart decay is not implemented for RAdam. '
'Set enableRAdam or use_smart_decay to False.'))
if use_lr_adaption:
raise NotImplementedError(('Smart decay is not implemented with lr_adaption. '
'Set use_lr_adaption or use_smart_decay to False.'))
self.use_smart_decay = use_smart_decay
self.init_kwargs = kwargs
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
if self.alpha <= 0:
return
lr, iteration = self.build_lr(
net,
param_init_net,
base_learning_rate=self.alpha,
policy=self.policy,
**(self.init_kwargs)
)
m1 = param_init_net.ConstantFill([param], param + "_first_moment", value=0.0)
if self.rowWise:
shapes, types = workspace.InferShapesAndTypes([param_init_net])
m2 = param_init_net.ConstantFill(
[], param + "_avg_second_moment", shape=[shapes[param][0]], value=0.0
)
else:
m2 = param_init_net.ConstantFill(
[param], param + "_second_moment", value=0.0
)
# Initialize "minibatch in which this parameter was last seen" for smart decay.
if self.use_smart_decay:
shapes, _ = workspace.InferShapesAndTypes([param_init_net])
last_seen = param_init_net.ConstantFill(
[], param + "_last_seen", shape=[shapes[param][0]], value=0, dtype=core.DataType.INT64
)
self._aux_params.local.append(last_seen)
self._aux_params.shared.append(iteration)
self._aux_params.local.append(m1)
self._aux_params.local.append(m2)
if self.rowWise:
assert isinstance(grad, core.GradientSlice), (
"If SparseAdam is used with rowWise=True, the gradient must be "
"a GradientSlice. Please ensure that rowWise is not enabled "
"for the dense Adam optimizer, as it is not supported."
)
output_blobs = [param, m1, m2]
if self.use_smart_decay:
output_blobs.append(last_seen)
if self.use_lr_adaption:
effective_grad = str(param) + "_effective_grad"
output_blobs.append(effective_grad)
if isinstance(grad, core.GradientSlice):
grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
if self.rowWise:
op = "RowWiseSparseAdam"
elif self.use_smart_decay:
op = "SmartDecaySparseAdam"
else:
op = "SparseAdam"
# Currently, only SparseAdam supports RAdam; other Adam ops may support it later
if op == "SparseAdam":
net.__getattr__(op)(
[param, m1, m2, grad.indices, grad.values, lr, iteration],
output_blobs,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
enableRAdam=self.enableRAdam,
)
elif op == "SmartDecaySparseAdam":
net.__getattr__(op)(
[param, m1, m2, last_seen, grad.indices, grad.values, lr, iteration],
output_blobs,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
)
else:
assert (
not self.enableRAdam
), "RAdam is not currently supported by RowWiseSparseAdam"
net.__getattr__(op)(
[param, m1, m2, grad.indices, grad.values, lr, iteration],
output_blobs,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
)
if self.use_lr_adaption:
net.LearningRateAdaption(
[lr, grad.values, effective_grad],
[lr],
lr_alpha=self.lr_alpha,
normalized_lr_adaption=self.normalized_lr_adaption,
)
else:
net.Adam(
[param, m1, m2, grad, lr, iteration],
output_blobs,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
)
if self.use_lr_adaption:
net.LearningRateAdaption(
[lr, grad, effective_grad],
[lr],
lr_alpha=self.lr_alpha,
normalized_lr_adaption=self.normalized_lr_adaption,
)
def scale_learning_rate(self, scale):
self.alpha *= scale
return
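The dense `Adam` path passes `param`, the two moment buffers, the gradient, `lr` and `iteration`. A NumPy sketch of the standard Adam step with bias correction folded into the step size, assuming a zero-based `iteration` (so the first step uses `beta ** 1`) and a positive learning rate that is subtracted. It omits RAdam, row-wise and smart-decay variants, and `adam_step` is a hypothetical name.

```python
import numpy as np


def adam_step(param, m1, m2, grad, lr, t, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Adam sketch; t is the zero-based iteration count."""
    m1 = beta1 * m1 + (1.0 - beta1) * grad          # first moment
    m2 = beta2 * m2 + (1.0 - beta2) * grad ** 2     # second moment
    correction = np.sqrt(1.0 - beta2 ** (t + 1)) / (1.0 - beta1 ** (t + 1))
    param = param - lr * correction * m1 / (np.sqrt(m2) + epsilon)
    return param, m1, m2
```

On the very first step the bias correction makes the update approximately `lr * sign(grad)`, regardless of gradient magnitude.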
class DecayAdagradOptimizer(Optimizer):
def __init__(
self,
alpha=0.01,
beta1=0.0,
beta2=0.999,
epsilon=0.1,
weight_decay=0.0,
ema_options=None,
bias_correction_first=True,
policy="fixed",
engine="",
**kwargs
):
super(DecayAdagradOptimizer, self).__init__()
self.alpha = alpha
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.weight_decay = weight_decay
self.bias_correction_first = bias_correction_first
self.policy = policy
self.engine = engine
self.init_kwargs = kwargs
self._process_ema_options(ema_options)
def _process_ema_options(self, ema_options):
self.ema_enabled = True if ema_options else False
if self.ema_enabled:
self.ema_start = ema_options.get("ema_start", None)
self.ema_end = ema_options.get("ema_end", None)
self.ema_step = ema_options.get("ema_step", None)
self.ema_alpha = ema_options.get("ema_alpha", None)
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
if self.alpha <= 0:
return
lr, iteration = self.build_lr(
net,
param_init_net,
base_learning_rate=self.alpha,
policy=self.policy,
**(self.init_kwargs)
)
if isinstance(grad, core.GradientSlice):
# Hack for position-weighted features: sparse gradients fall back to plain SparseAdagrad.
param_squared_sum = param_init_net.ConstantFill([param], param + "_squared_sum", value=0.0)
self._aux_params.local.append(param_squared_sum)
output_blobs = [param, param_squared_sum]
net.SparseAdagrad(
[param, param_squared_sum, grad.indices, grad.values, lr],
output_blobs,
epsilon=self.epsilon,
)
else:
m1 = param_init_net.ConstantFill([param], param + "_first_moment", value=0.0)
m2 = param_init_net.ConstantFill([param], param + "_second_moment", value=0.0)
self._aux_params.shared.append(iteration)
self._aux_params.local.append(m1)
self._aux_params.local.append(m2)
output_blobs = [param, m1, m2]
net.DecayAdagrad(
[param, m1, m2, grad, lr, iteration],
output_blobs,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
weight_decay=self.weight_decay,
bias_correction_first=self.bias_correction_first,
)
if self.ema_enabled:
param_ema = str(param) + "_ema"
if not param_init_net.BlobIsDefined(param_ema):
param_init_net.ConstantFill([param], param_ema, value=0.0)
self._aux_params.local.append(param_ema)
net.EMA(
[param, param_ema, iteration],
[param, param_ema],
ema_start=self.ema_start,
ema_end=self.ema_end,
ema_step=self.ema_step,
ema_alpha=self.ema_alpha,
)
def scale_learning_rate(self, scale):
self.alpha *= scale
return
class YellowFinOptimizer(Optimizer):
"""YellowFin: An automatic tuner for momentum SGD
See https://arxiv.org/abs/1706.03471 for more details. This implementation
has a separate learning rate and momentum for each parameter."""
def __init__(
self,
alpha=0.1,
mu=0.0,
beta=0.999,
curv_win_width=20,
zero_debias=True,
epsilon=0.1 ** 6,
policy="fixed",
sparse_dedup_aggregator=None,
**kwargs
):
super(YellowFinOptimizer, self).__init__()
self.alpha = alpha
self.mu = mu
self.beta = beta
self.curv_win_width = curv_win_width
self.zero_debias = zero_debias
self.epsilon = epsilon
self.policy = policy
self.sparse_dedup_aggregator = sparse_dedup_aggregator
self.init_kwargs = kwargs
def _run(self, net, param_init_net, param_info):
# Note: This is the number of persistent scalars used by the YellowFin
# optimizer. It must always equal the number of scalars actually used,
# and the YellowFin operator must use the same number.
SCALARS_MEMORY_SIZE = 5
param = param_info.blob
grad = param_info.grad
moment = param_init_net.ConstantFill([param], param + "_moment", value=0.0)
curv_win = param_init_net.ConstantFill(
[], param + "_curv_win", shape=[self.curv_win_width], value=0.0
)
g_avg = param_init_net.ConstantFill([param], param + "_g_avg", value=0.0)
g2_avg = param_init_net.ConstantFill([param], param + "_g2_avg", value=0.0)
lr_avg = param_init_net.ConstantFill(
[], param + "_lr_avg", shape=[1], value=self.alpha
)
mu_avg = param_init_net.ConstantFill(
[], param + "_mu_avg", shape=[1], value=self.mu
)
scalars_memory = param_init_net.ConstantFill(
[], param + "_scalars_memory", shape=[SCALARS_MEMORY_SIZE], value=0.0
)
assert self.alpha > 0
assert not isinstance(
grad, core.GradientSlice
), "YellowFin does not support sparse gradients"
iteration = utils.BuildUniqueMutexIter(param_init_net, net, iter_val=0)
self._aux_params.shared.append(iteration)
self._aux_params.local.append(moment)
self._aux_params.local.append(lr_avg)
self._aux_params.local.append(mu_avg)
self._aux_params.local.append(curv_win)
self._aux_params.local.append(g_avg)
self._aux_params.local.append(g2_avg)
self._aux_params.local.append(scalars_memory)
yf_in_out_args = [
param,
moment,
lr_avg,
mu_avg,
curv_win,
g_avg,
g2_avg,
scalars_memory,
]
net.YellowFin(
yf_in_out_args + [grad, iteration],
yf_in_out_args,
beta=self.beta,
epsilon=self.epsilon,
curv_win_width=self.curv_win_width,
zero_debias=self.zero_debias,
)
def scale_learning_rate(self, scale):
self.alpha *= scale
return
class RmsPropOptimizer(Optimizer):
def __init__(
self,
alpha=0.01,
decay=0.9,
momentum=0.0,
epsilon=1e-5,
policy="fixed",
engine="",
**kwargs
):
super(RmsPropOptimizer, self).__init__()
self.alpha = alpha
self.decay = decay
self.momentum = momentum
self.epsilon = epsilon
self.policy = policy
self.engine = engine
self.init_kwargs = kwargs
def _run(self, net, param_init_net, param_info):
param = param_info.blob
grad = param_info.grad
assert self.alpha > 0
assert not isinstance(
grad, core.GradientSlice
), "RmsPropOptimizer doesn't support sparse gradients"
dev = scope.CurrentDeviceScope()
if dev is None:
dev = core.DeviceOption(caffe2_pb2.CPU)
ONE = param_init_net.ConstantFill(
[], "ONE_{}_{}".format(dev.device_type, dev.device_id), shape=[1], value=1.0
)
lr, _ = self.build_lr(
net,
param_init_net,
base_learning_rate=-self.alpha,
policy=self.policy,
**(self.init_kwargs)
)
grad_o = param_init_net.ConstantFill(
[param], str(param) + "_grad_o", value=0.0
)
ms = param_init_net.ConstantFill(
[param], str(param) + "_mean_squares", value=0.0
)
mom = param_init_net.ConstantFill([param], str(param) + "_momentum", value=0.0)
self._aux_params.local.append(ms)
self._aux_params.local.append(mom)
net.RmsProp(
[grad, ms, mom, ONE],
[grad_o, ms, mom],
decay=self.decay,
momentum=self.momentum,
epsilon=self.epsilon,
engine=self.engine,
)
net.MomentumSGDUpdate([grad_o, mom, lr, param], [grad_o, mom, param])
def scale_learning_rate(self, scale):
self.alpha *= scale
return
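`RmsPropOptimizer` splits the update across two operators: `RmsProp`, called with the constant `ONE` as its learning rate, produces a normalized step, and `MomentumSGDUpdate` then scales it by `lr` and applies it to `param`. The NumPy sketch below collapses this into the textbook RMSProp-with-momentum rule, assuming the mean-square recurrence `ms = decay * ms + (1 - decay) * g ** 2`; it illustrates the math, not the exact behavior of either operator, and `rmsprop_step` is a hypothetical name.

```python
import numpy as np


def rmsprop_step(param, ms, mom, grad, lr, decay=0.9, momentum=0.0, epsilon=1e-5):
    """RMSProp-with-momentum sketch (positive lr, subtracted)."""
    ms = ms + (1.0 - decay) * (grad ** 2 - ms)  # same as decay*ms + (1-decay)*g^2
    mom = momentum * mom + grad / np.sqrt(ms + epsilon)
    param = param - lr * mom
    return param, ms, mom
```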
def _get_param_to_device(model):
# Infer blob devices by going through the net and param_init_net
# ops and observing the device used to create or use the blob.
param_to_device = core.InferBlobDevices(model.net)
param_to_device.update(core.InferBlobDevices(model.param_init_net))
return param_to_device
def get_param_device(param_name, grad, param_to_device=None, default_device=None):
device = default_device
param_to_device = param_to_device or {}
# We first check if the parameter's device has been inferred. If not,
# we check the gradient. This can happen if the parameter is not the
# output of any op in the net but was placed in the workspace directly.
if param_name in param_to_device:
device = param_to_device[param_name]
else:
if isinstance(grad, core.GradientSlice):
if str(grad.values) in param_to_device:
device = param_to_device[str(grad.values)]
elif str(grad.indices) in param_to_device:
device = param_to_device[str(grad.indices)]
else:
grad_name = str(grad)
if grad_name in param_to_device:
device = param_to_device[grad_name]
assert device is not None, "Cannot infer device for {}: no op creates it".format(
param_name
)
return device
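`get_param_device` resolves a device in a fixed order: the parameter's own name, then (for a sparse gradient) the gradient values and then the indices, or (for a dense gradient) the gradient name, and finally `default_device`. The plain-Python sketch below mirrors that lookup order; `SparseGrad` is a hypothetical stand-in for `core.GradientSlice`, and device values are plain strings for illustration.

```python
from collections import namedtuple

# Hypothetical stand-in for caffe2's core.GradientSlice(indices, values).
SparseGrad = namedtuple("SparseGrad", ["indices", "values"])


def resolve_device(param_name, grad, param_to_device=None, default_device=None):
    """Mirror of the lookup order used by get_param_device."""
    param_to_device = param_to_device or {}
    if param_name in param_to_device:
        return param_to_device[param_name]
    if isinstance(grad, SparseGrad):
        # Values are checked before indices.
        for blob in (grad.values, grad.indices):
            if str(blob) in param_to_device:
                return param_to_device[str(blob)]
    elif str(grad) in param_to_device:
        return param_to_device[str(grad)]
    if default_device is None:
        raise AssertionError(
            "Cannot infer device for {}: no op creates it".format(param_name)
        )
    return default_device
```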
def get_lr_injection():
"""
Gets current value for lr_injection, a multiplier for all base
learning rates.
Must set allow_lr_injection=True when building optimizer, as it
relies on synchronization over CPU.
"""
return workspace.FetchBlob(_LEARNING_RATE_INJECTION)
def set_lr_injection(lr_injection_value):
"""
Sets lr_injection, a multiplier for all base learning rates.
Must set allow_lr_injection=True when building optimizer, as it
relies on synchronization over CPU.
"""
workspace.FeedBlob(
_LEARNING_RATE_INJECTION,
np.array([float(lr_injection_value)], dtype=np.float32),
)
def _calc_norm_ratio(model, params, name_scope, param_to_device, max_gradient_norm):
with core.NameScope(name_scope):
grad_squared_sums = []
for i, param in enumerate(params):
device = get_param_device(str(param.blob), param.grad, param_to_device)
with core.DeviceScope(device):
grad = (
param.grad
if not isinstance(param.grad, core.GradientSlice)
else param.grad.values
)
grad_squared_sum_name = "grad_{}_squared_sum".format(i)
grad_squared_sum = model.net.SumSqrElements(grad, grad_squared_sum_name)
grad_squared_sum_cpu = model.net.EnsureCPUOutput(grad_squared_sum)
grad_squared_sums.append(grad_squared_sum_cpu)
with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
grad_squared_full_sum = model.net.Sum(
grad_squared_sums, "grad_squared_full_sum"
)
global_norm = model.net.Pow(
grad_squared_full_sum, "global_norm", exponent=0.5
)
clip_norm = model.param_init_net.ConstantFill(
[], "clip_norm", shape=[], value=float(max_gradient_norm)
)
max_norm = model.net.Max([global_norm, clip_norm], "max_norm")
norm_ratio = model.net.Div([clip_norm, max_norm], "norm_ratio")
return norm_ratio
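`_calc_norm_ratio` computes the global L2 norm over all gradients (using only `values` for sparse gradients) and returns `clip_norm / max(global_norm, clip_norm)`, which `_build` uses as a learning-rate multiplier, optionally multiplied by the lr-injection blob. A NumPy sketch of that arithmetic with hypothetical function names:

```python
import numpy as np


def norm_ratio(grads, max_gradient_norm):
    """clip / max(global_norm, clip): 1.0 when the norm is already small."""
    squared_sum = sum(float(np.sum(np.square(g))) for g in grads)
    global_norm = np.sqrt(squared_sum)
    return max_gradient_norm / max(global_norm, max_gradient_norm)


def lr_multiplier(grads, max_gradient_norm, lr_injection=1.0):
    """Combine clipping with lr injection, as _build multiplies the two."""
    return norm_ratio(grads, max_gradient_norm) * lr_injection
```

For plain SGD, scaling the learning rate by this ratio is equivalent to clipping the global gradient norm; for adaptive optimizers such as Adagrad or Adam it only shrinks the step and does not change the gradients that feed the accumulators.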


def _build(
    model,
    optimizer,
    weights_only=False,
    use_param_info_optim=True,
    max_gradient_norm=None,
    allow_lr_injection=False,
):
    param_to_device = _get_param_to_device(model)

    # Validate there are no duplicate params
    model.Validate()

    params = []
    for param_info in model.GetOptimizationParamInfo():
        if weights_only and param_info.blob not in model.weights:
            continue
        params.append(param_info)

    lr_multiplier = None
    if max_gradient_norm is not None:
        lr_multiplier = _calc_norm_ratio(
            model,
            params,
            "norm_clipped_grad_update",
            param_to_device,
            max_gradient_norm,
        )

    if allow_lr_injection:
        if not model.net.BlobIsDefined(_LEARNING_RATE_INJECTION):
            lr_injection = model.param_init_net.ConstantFill(
                [], _LEARNING_RATE_INJECTION, shape=[1], value=1.0
            )
        else:
            lr_injection = _LEARNING_RATE_INJECTION

        if lr_multiplier is None:
            lr_multiplier = lr_injection
        else:
            lr_multiplier = model.net.Mul(
                [lr_multiplier, lr_injection], "lr_multiplier", broadcast=1
            )
    optimizer.add_lr_multiplier(lr_multiplier)

    for param_info in params:
        param_name = str(param_info.blob)
        device = get_param_device(param_name, param_info.grad, param_to_device)
        with core.DeviceScope(device):
            if param_info.optimizer and use_param_info_optim:
                param_info.optimizer(model.net, model.param_init_net, param_info)
            else:
                optimizer(model.net, model.param_init_net, param_info)
    return optimizer
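

# Illustrative sketch (not part of the Caffe2 API): how _build combines the two
# optional multipliers into the single value passed to
# optimizer.add_lr_multiplier. Assumption: the optimizer then scales its
# learning rate by this multiplier. `_example_lr_multiplier` is a hypothetical
# helper name; plain floats stand in for the Caffe2 blobs.
def _example_lr_multiplier(norm_ratio=None, lr_injection=None):
    # Neither clipping nor injection: no multiplier (None), as in _build.
    if norm_ratio is None:
        return lr_injection
    # Clipping only.
    if lr_injection is None:
        return norm_ratio
    # Both present: multiplied together (the Mul op with broadcast=1).
    return norm_ratio * lr_injection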


def add_weight_decay(model, weight_decay):
    """Adds a decay to weights in the model.

    This is a form of L2 regularization.

    Args:
        weight_decay: strength of the regularization
    """
    _build(
        model,
        WeightDecayBuilder(weight_decay=weight_decay),
        weights_only=True,
        use_param_info_optim=False,
    )
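

# Illustrative sketch (not part of the Caffe2 API): the L2 weight-decay rule,
# assuming WeightDecayBuilder adds weight_decay * param to each weight's
# gradient so that the subsequent optimizer step also shrinks the weights.
# `_example_weight_decay_grad` is a hypothetical helper name.
import numpy as np


def _example_weight_decay_grad(grad, param, weight_decay):
    # d/dparam of 0.5 * weight_decay * ||param||^2 is weight_decay * param.
    return np.asarray(grad, dtype=float) + weight_decay * np.asarray(param, dtype=float)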


def build_sgd(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    sgd_optimizer = SgdOptimizer(base_learning_rate, **kwargs)
    return _build(
        model,
        sgd_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_multi_precision_sgd(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    multi_prec_sgd_optimizer = MultiPrecisionSgdOptimizer(base_learning_rate, **kwargs)
    return _build(
        model,
        multi_prec_sgd_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_fp16_sgd(model, base_learning_rate, **kwargs):
    fp16_sgd_optimizer = FP16SgdOptimizer(base_learning_rate, **kwargs)
    return _build(model, fp16_sgd_optimizer)


def build_ftrl(model, engine="SIMD", **kwargs):
    if engine == "SIMD":
        assert core.IsOperator("Ftrl_ENGINE_SIMD")
        assert core.IsOperator("SparseFtrl_ENGINE_SIMD")
    ftrl_optimizer = FtrlOptimizer(engine=engine, **kwargs)
    return _build(model, ftrl_optimizer)


def build_gftrl(model, engine="", **kwargs):
    if engine == "SIMD":
        assert core.IsOperator("GFtrl_ENGINE_SIMD")
    gftrl_optimizer = GFtrlOptimizer(engine=engine, **kwargs)
    return _build(model, gftrl_optimizer)


def build_adagrad(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    adagrad_optimizer = AdagradOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        adagrad_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_wngrad(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    wngrad_optimizer = WngradOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        wngrad_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_storm(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    storm_optimizer = StormOptimizer(lr=base_learning_rate, **kwargs)
    return _build(
        model,
        storm_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_adadelta(
    model,
    base_learning_rate,
    parameters=None,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    adadelta_optimizer = AdadeltaOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        adadelta_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_adam(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    adam_optimizer = AdamOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        adam_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_decay_adagrad(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    decay_adagrad_optimizer = DecayAdagradOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        decay_adagrad_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )


def build_yellowfin(model, base_learning_rate=0.1, **kwargs):
    yellowfin_optimizer = YellowFinOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(model, yellowfin_optimizer)


def build_rms_prop(
    model,
    base_learning_rate,
    max_gradient_norm=None,
    allow_lr_injection=False,
    **kwargs
):
    rms_prop_optimizer = RmsPropOptimizer(alpha=base_learning_rate, **kwargs)
    return _build(
        model,
        rms_prop_optimizer,
        max_gradient_norm=max_gradient_norm,
        allow_lr_injection=allow_lr_injection,
    )