# -*- coding: utf-8 -*-
"""
Multi-lib backend for POT

The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
Jax, Cupy, or Tensorflow, POT code should work nonetheless.
To achieve that, POT provides backend classes which implement functions in their respective backend,
imitating the Numpy API. As a convention, we use nx instead of np to refer to the backend.

Examples
--------

>>> from ot.utils import list_to_array
>>> from ot.backend import get_backend
>>> def f(a, b):  # the function does not know which backend to use
...     a, b = list_to_array(a, b)  # if a list is given, make it an array
...     nx = get_backend(a, b)  # infer the backend from the arguments
...     c = nx.dot(a, b)  # now use the backend to do any calculation
...     return c
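
As a self-contained sketch of the same convention, the toy code below writes the
function once against an ``nx`` object and picks that object from the type of the
inputs. ``ToyNumpyBackend`` and ``toy_get_backend`` are hypothetical stand-ins
written for illustration only, not POT's classes; only Numpy is assumed to be installed.

```python
import numpy as np


class ToyNumpyBackend:
    # Hypothetical stand-in for a POT backend: it exposes a Numpy-like API
    def dot(self, a, b):
        return np.dot(a, b)

    def sum(self, a, axis=None):
        return np.sum(a, axis=axis)


# Hypothetical registry mapping array types to backend instances
_TOY_BACKENDS = {np.ndarray: ToyNumpyBackend()}


def toy_get_backend(*args):
    # Infer the backend from the type of the arguments (all must match)
    types = {type(a) for a in args}
    if len(types) != 1:
        raise ValueError('All arrays should be from the same type/backend')
    return _TOY_BACKENDS[types.pop()]


def f(a, b):
    nx = toy_get_backend(a, b)  # nx plays the role np usually plays
    return nx.sum(nx.dot(a, b))


a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.eye(2)
print(f(a, b))  # sum of all entries of a @ I, i.e. 10.0
```

Swapping in another array library would only require another entry in the registry;
``f`` itself stays unchanged, which is the point of the ``nx`` convention.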

.. warning::
    The Tensorflow backend only works with the Numpy API behavior enabled. To activate it, run the following:

    .. code-block::

        from tensorflow.python.ops.numpy_ops import np_config
        np_config.enable_numpy_behavior()

Performance
-----------

- CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
- GPU: Tesla V100-SXM2-32GB
- Date of the benchmark: December 8th, 2021
- Commit of benchmark: PR #316, https://github.com/PythonOT/POT/pull/316

.. raw:: html

    <style>
    #perftable {
        width: 100%;
        margin-bottom: 1em;
    }

    #perftable table{
        border-collapse: collapse;
        table-layout: fixed;
        width: 100%;
    }

    #perftable th, #perftable td {
        border: 1px solid #ddd;
        padding: 8px;
        font-size: smaller;
    }
    </style>

    <div id="perftable">
    <table>
    <tr><th align="center" colspan="8">Sinkhorn Knopp - Averaged on 100 runs</th></tr>
    <tr><th align="center">Bitsize</th><th align="center" colspan="7">32 bits</th></tr>
    <tr><th align="center">Device</th><th align="center" colspan="3">CPU</th><th align="center" colspan="4">GPU</th></tr>
    <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
    <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0022</td><td align="center">0.0151</td><td align="center">0.0095</td><td align="center">0.0193</td><td align="center">0.0051</td><td align="center">0.0293</td></tr>
    <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0097</td><td align="center">0.0057</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0173</td></tr>
    <tr><td align="center">500</td><td align="center">0.0009</td><td align="center">0.0016</td><td align="center">0.0110</td><td align="center">0.0058</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0166</td></tr>
    <tr><td align="center">1000</td><td align="center">0.0021</td><td align="center">0.0021</td><td align="center">0.0145</td><td align="center">0.0056</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
    <tr><td align="center">2000</td><td align="center">0.0069</td><td align="center">0.0043</td><td align="center">0.0278</td><td align="center">0.0059</td><td align="center">0.0118</td><td align="center">0.0030</td><td align="center">0.0165</td></tr>
    <tr><td align="center">5000</td><td align="center">0.0707</td><td align="center">0.0314</td><td align="center">0.1395</td><td align="center">0.0074</td><td align="center">0.0125</td><td align="center">0.0035</td><td align="center">0.0198</td></tr>
    <tr><td colspan="8"> </td></tr>
    <tr><th align="center">Bitsize</th><th align="center" colspan="7">64 bits</th></tr>
    <tr><th align="center">Device</th><th align="center" colspan="3">CPU</th><th align="center" colspan="4">GPU</th></tr>
    <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr>
    <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0020</td><td align="center">0.0154</td><td align="center">0.0093</td><td align="center">0.0191</td><td align="center">0.0051</td><td align="center">0.0328</td></tr>
    <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0094</td><td align="center">0.0056</td><td align="center">0.0114</td><td align="center">0.0029</td><td align="center">0.0169</td></tr>
    <tr><td align="center">500</td><td align="center">0.0013</td><td align="center">0.0017</td><td align="center">0.0120</td><td align="center">0.0059</td><td align="center">0.0116</td><td align="center">0.0029</td><td align="center">0.0168</td></tr>
    <tr><td align="center">1000</td><td align="center">0.0034</td><td align="center">0.0027</td><td align="center">0.0177</td><td align="center">0.0058</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0167</td></tr>
    <tr><td align="center">2000</td><td align="center">0.0146</td><td align="center">0.0075</td><td align="center">0.0436</td><td align="center">0.0059</td><td align="center">0.0120</td><td align="center">0.0029</td><td align="center">0.0165</td></tr>
    <tr><td align="center">5000</td><td align="center">0.1467</td><td align="center">0.0568</td><td align="center">0.2468</td><td align="center">0.0077</td><td align="center">0.0146</td><td align="center">0.0045</td><td align="center">0.0204</td></tr>
    </table>
    </div>
"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#         Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

import numpy as np
import scipy
import scipy.linalg
import scipy.special as special
from scipy.sparse import issparse, coo_matrix, csr_matrix
import warnings
import time

try:
    import torch
    torch_type = torch.Tensor
except ImportError:
    torch = False
    torch_type = float

try:
    import jax
    import jax.numpy as jnp
    import jax.scipy.special as jspecial
    from jax.lib import xla_bridge
    jax_type = jax.numpy.ndarray
except ImportError:
    jax = False
    jax_type = float

try:
    import cupy as cp
    import cupyx
    cp_type = cp.ndarray
except ImportError:
    cp = False
    cp_type = float

try:
    import tensorflow as tf
    import tensorflow.experimental.numpy as tnp
    tf_type = tf.Tensor
except ImportError:
    tf = False
    tf_type = float


str_type_error = "All arrays should be from the same type/backend. Current types are : {}"


def get_backend_list():
    """Returns the list of available backends"""
    lst = [NumpyBackend(), ]

    if torch:
        lst.append(TorchBackend())

    if jax:
        lst.append(JaxBackend())

    if cp:  # pragma: no cover
        lst.append(CupyBackend())

    if tf:
        lst.append(TensorflowBackend())

    return lst


def get_backend(*args):
    """Returns the proper backend for a list of input arrays

    Also raises ValueError if all arrays are not from the same backend
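
Note that the type check performed by this function compares exact types with
``type(a)``, so inputs that Numpy could happily combine can still be rejected.
The sketch below reproduces only that check in isolation; ``same_backend_types``
is a hypothetical helper written for illustration, not part of POT.

```python
import numpy as np


def same_backend_types(*args):
    # Hypothetical helper mirroring the check: exact type equality of all inputs
    return len(set(type(a) for a in args)) == 1


x = np.ones(3)
print(same_backend_types(x, np.zeros(2)))  # True: both are np.ndarray
print(same_backend_types(x, [0.0, 1.0]))  # False: list vs np.ndarray
print(same_backend_types(x, x.sum()))  # False: np.float64 vs np.ndarray
# Converting lists to arrays first restores a common type
print(same_backend_types(x, np.asarray([0.0, 1.0])))  # True
```

This is why the module-level example converts lists to arrays before calling
``get_backend``.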
"""
    # check that some arrays given
    if not len(args) > 0:
        raise ValueError("The function takes at least one parameter")
    # check all same type
    if not len(set(type(a) for a in args)) == 1:
        raise ValueError(str_type_error.format([type(a) for a in args]))

    if isinstance(args[0], np.ndarray):
        return NumpyBackend()
    elif isinstance(args[0], torch_type):
        return TorchBackend()
    elif isinstance(args[0], jax_type):
        return JaxBackend()
    elif isinstance(args[0], cp_type):  # pragma: no cover
        return CupyBackend()
    elif isinstance(args[0], tf_type):
        return TensorflowBackend()
    else:
        raise ValueError("Unknown type or non-implemented backend.")


def to_numpy(*args):
    """Returns numpy arrays from any compatible backend"""

    if len(args) == 1:
        return get_backend(args[0]).to_numpy(args[0])
    else:
        return [get_backend(a).to_numpy(a) for a in args]


class Backend():
    """
    Backend abstract class.
    Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
    :py:class:`CupyBackend`, :py:class:`TensorflowBackend`

    - The `__name__` class attribute refers to the name of the backend.
    - The `__type__` class attribute refers to the data structure used by the backend.
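
As an illustration of how this abstract class is meant to be specialized, here is a
reduced, self-contained copy of the pattern: the public ``to_numpy`` handles one or
several inputs and delegates to ``_to_numpy``, which each concrete backend overrides.
``MiniBackend`` and ``ListBackend`` are hypothetical names used only for this sketch;
the real backends wrap Numpy, PyTorch, Jax, Cupy, or Tensorflow arrays.

```python
import numpy as np


class MiniBackend:
    # Hypothetical reduced copy of the abstract pattern used by Backend
    __name__ = None
    __type__ = None

    def __str__(self):
        return self.__name__

    def to_numpy(self, *arrays):
        # One array in, one array out; several arrays in, a list out
        if len(arrays) == 1:
            return self._to_numpy(arrays[0])
        return [self._to_numpy(a) for a in arrays]

    def _to_numpy(self, a):
        raise NotImplementedError()


class ListBackend(MiniBackend):
    # Hypothetical backend whose tensors are plain Python lists
    __name__ = 'list'
    __type__ = list

    def _to_numpy(self, a):
        return np.asarray(a, dtype=float)


nx = ListBackend()
print(str(nx))  # list
print(nx.to_numpy([1, 2]))  # [1. 2.]
print(len(nx.to_numpy([1], [2, 3])))  # 2
```

Keeping the batching logic in the base class means a new backend only has to
implement the single-array conversion.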
"""

    __name__ = None
    __type__ = None
    __type_list__ = None

    rng_ = None

    def __str__(self):
        return self.__name__

    # convert batch of tensors to numpy
    def to_numpy(self, *arrays):
        """Returns the numpy version of tensors"""
        if len(arrays) == 1:
            return self._to_numpy(arrays[0])
        else:
            return [self._to_numpy(array) for array in arrays]

    # convert a tensor to numpy
    def _to_numpy(self, a):
        """Returns the numpy version of a tensor"""
        raise NotImplementedError()

    # convert batch of arrays from numpy
    def from_numpy(self, *arrays, type_as=None):
        """Creates tensors cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
        if len(arrays) == 1:
            return self._from_numpy(arrays[0], type_as=type_as)
        else:
            return [self._from_numpy(array, type_as=type_as) for array in arrays]

    # convert an array from numpy
    def _from_numpy(self, a, type_as=None):
        """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
        raise NotImplementedError()

    def set_gradients(self, val, inputs, grads):
        """Define the gradients for the value val wrt the inputs"""
        raise NotImplementedError()

    def zeros(self, shape, type_as=None):
        r"""
        Creates a tensor full of zeros.

        This function follows the api from :any:`numpy.zeros`

        See: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html
        """
        raise NotImplementedError()

    def ones(self, shape, type_as=None):
        r"""
        Creates a tensor full of ones.

        This function follows the api from :any:`numpy.ones`

        See: https://numpy.org/doc/stable/reference/generated/numpy.ones.html
        """
        raise NotImplementedError()

    def arange(self, stop, start=0, step=1, type_as=None):
        r"""
        Returns evenly spaced values within a given interval.

        This function follows the api from :any:`numpy.arange`

        See: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
        """
        raise NotImplementedError()

    def full(self, shape, fill_value, type_as=None):
        r"""
        Creates a tensor with given shape, filled with given value.

        This function follows the api from :any:`numpy.full`

        See: https://numpy.org/doc/stable/reference/generated/numpy.full.html
        """
        raise NotImplementedError()

    def eye(self, N, M=None, type_as=None):
        r"""
        Creates the identity matrix of given size.

        This function follows the api from :any:`numpy.eye`

        See: https://numpy.org/doc/stable/reference/generated/numpy.eye.html
        """
        raise NotImplementedError()

    def sum(self, a, axis=None, keepdims=False):
        r"""
        Sums tensor elements over given dimensions.

        This function follows the api from :any:`numpy.sum`

        See: https://numpy.org/doc/stable/reference/generated/numpy.sum.html
        """
        raise NotImplementedError()

    def cumsum(self, a, axis=None):
        r"""
        Returns the cumulative sum of tensor elements over given dimensions.

        This function follows the api from :any:`numpy.cumsum`

        See: https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html
        """
        raise NotImplementedError()

    def max(self, a, axis=None, keepdims=False):
        r"""
        Returns the maximum of an array or maximum along given dimensions.

        This function follows the api from :any:`numpy.amax`

        See: https://numpy.org/doc/stable/reference/generated/numpy.amax.html
        """
        raise NotImplementedError()

    def min(self, a, axis=None, keepdims=False):
        r"""
        Returns the minimum of an array or minimum along given dimensions.

        This function follows the api from :any:`numpy.amin`

        See: https://numpy.org/doc/stable/reference/generated/numpy.amin.html
        """
        raise NotImplementedError()

    def maximum(self, a, b):
        r"""
        Returns element-wise maximum of array elements.

        This function follows the api from :any:`numpy.maximum`

        See: https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
        """
        raise NotImplementedError()

    def minimum(self, a, b):
        r"""
        Returns element-wise minimum of array elements.

        This function follows the api from :any:`numpy.minimum`

        See: https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
        """
        raise NotImplementedError()

    def dot(self, a, b):
        r"""
        Returns the dot product of two tensors.

        This function follows the api from :any:`numpy.dot`

        See: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
        """
        raise NotImplementedError()

    def abs(self, a):
        r"""
        Computes the absolute value element-wise.

        This function follows the api from :any:`numpy.absolute`

        See: https://numpy.org/doc/stable/reference/generated/numpy.absolute.html
        """
        raise NotImplementedError()

    def exp(self, a):
        r"""
        Computes the exponential value element-wise.

        This function follows the api from :any:`numpy.exp`

        See: https://numpy.org/doc/stable/reference/generated/numpy.exp.html
        """
        raise NotImplementedError()

    def log(self, a):
        r"""
        Computes the natural logarithm, element-wise.

        This function follows the api from :any:`numpy.log`

        See: https://numpy.org/doc/stable/reference/generated/numpy.log.html
        """
        raise NotImplementedError()

    def sqrt(self, a):
        r"""
        Returns the non-negative square root of a tensor, element-wise.

        This function follows the api from :any:`numpy.sqrt`

        See: https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html
        """
        raise NotImplementedError()

    def power(self, a, exponents):
        r"""
        First tensor elements raised to powers from second tensor, element-wise.

        This function follows the api from :any:`numpy.power`

        See: https://numpy.org/doc/stable/reference/generated/numpy.power.html
        """
        raise NotImplementedError()

    def norm(self, a):
        r"""
        Computes the matrix Frobenius norm.

        This function follows the api from :any:`numpy.linalg.norm`

        See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html
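
For a 2-D tensor the Frobenius norm is the square root of the sum of the squared
entries, which is also what :any:`numpy.linalg.norm` returns by default for a matrix.
A quick numerical check with Numpy:

```python
import numpy as np

A = np.array([[1.0, 2.0], [2.0, 4.0]])
# Frobenius norm: square root of the sum of squared entries (1 + 4 + 4 + 16 = 25)
fro = np.sqrt(np.sum(A ** 2))
print(fro)  # 5.0
print(np.isclose(fro, np.linalg.norm(A)))  # True
```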
"""
        raise NotImplementedError()

    def any(self, a):
        r"""
        Tests whether any tensor element along given dimensions evaluates to True.

        This function follows the api from :any:`numpy.any`

        See: https://numpy.org/doc/stable/reference/generated/numpy.any.html
        """
        raise NotImplementedError()

    def isnan(self, a):
        r"""
        Tests element-wise for NaN and returns result as a boolean tensor.

        This function follows the api from :any:`numpy.isnan`

        See: https://numpy.org/doc/stable/reference/generated/numpy.isnan.html
        """
        raise NotImplementedError()

    def isinf(self, a):
        r"""
        Tests element-wise for positive or negative infinity and returns result as a boolean tensor.

        This function follows the api from :any:`numpy.isinf`

        See: https://numpy.org/doc/stable/reference/generated/numpy.isinf.html
        """
        raise NotImplementedError()

    def einsum(self, subscripts, *operands):
        r"""
        Evaluates the Einstein summation convention on the operands.

        This function follows the api from :any:`numpy.einsum`

        See: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
        """
        raise NotImplementedError()

    def sort(self, a, axis=-1):
        r"""
        Returns a sorted copy of a tensor.

        This function follows the api from :any:`numpy.sort`

        See: https://numpy.org/doc/stable/reference/generated/numpy.sort.html
        """
        raise NotImplementedError()

    def argsort(self, a, axis=None):
        r"""
        Returns the indices that would sort a tensor.

        This function follows the api from :any:`numpy.argsort`

        See: https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
        """
        raise NotImplementedError()

    def searchsorted(self, a, v, side='left'):
        r"""
        Finds indices where elements should be inserted to maintain order in given tensor.

        This function follows the api from :any:`numpy.searchsorted`

        See: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
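
The ``side`` argument only matters when a value of ``v`` equals existing entries of
``a``: ``'left'`` returns the first suitable index and ``'right'`` the last one. A
quick check with Numpy, whose behavior the backends are meant to follow:

```python
import numpy as np

a = np.array([1.0, 2.0, 2.0, 3.0])  # must already be sorted in ascending order
# 'left' gives the first valid insertion index, 'right' the last one
print(np.searchsorted(a, 2.0, side='left'))  # 1
print(np.searchsorted(a, 2.0, side='right'))  # 3
print(np.searchsorted(a, [0.5, 4.0]))  # [0 4]
```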
"""
        raise NotImplementedError()

    def flip(self, a, axis=None):
        r"""
        Reverses the order of elements in a tensor along given dimensions.

        This function follows the api from :any:`numpy.flip`

        See: https://numpy.org/doc/stable/reference/generated/numpy.flip.html
        """
        raise NotImplementedError()

    def clip(self, a, a_min, a_max):
        """
        Limits the values in a tensor.

        This function follows the api from :any:`numpy.clip`

        See: https://numpy.org/doc/stable/reference/generated/numpy.clip.html
        """
        raise NotImplementedError()

    def repeat(self, a, repeats, axis=None):
        r"""
        Repeats elements of a tensor.

        This function follows the api from :any:`numpy.repeat`

        See: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html
        """
        raise NotImplementedError()

    def take_along_axis(self, arr, indices, axis):
        r"""
        Gathers elements of a tensor along given dimensions.

        This function follows the api from :any:`numpy.take_along_axis`

        See: https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
        """
        raise NotImplementedError()

    def concatenate(self, arrays, axis=0):
        r"""
        Joins a sequence of tensors along an existing dimension.

        This function follows the api from :any:`numpy.concatenate`

        See: https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html
        """
        raise NotImplementedError()

    def zero_pad(self, a, pad_width):
        r"""
        Pads a tensor.

        This function follows the api from :any:`numpy.pad`

        See: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
        """
        raise NotImplementedError()

    def argmax(self, a, axis=None):
        r"""
        Returns the indices of the maximum values of a tensor along given dimensions.

        This function follows the api from :any:`numpy.argmax`

        See: https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
        """
        raise NotImplementedError()

    def argmin(self, a, axis=None):
        r"""
        Returns the indices of the minimum values of a tensor along given dimensions.

        This function follows the api from :any:`numpy.argmin`

        See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html
        """
        raise NotImplementedError()

    def mean(self, a, axis=None):
        r"""
        Computes the arithmetic mean of a tensor along given dimensions.

        This function follows the api from :any:`numpy.mean`

        See: https://numpy.org/doc/stable/reference/generated/numpy.mean.html
        """
        raise NotImplementedError()

    def std(self, a, axis=None):
        r"""
        Computes the standard deviation of a tensor along given dimensions.

        This function follows the api from :any:`numpy.std`

        See: https://numpy.org/doc/stable/reference/generated/numpy.std.html
        """
        raise NotImplementedError()

    def linspace(self, start, stop, num):
        r"""
        Returns a specified number of evenly spaced values over a given interval.

        This function follows the api from :any:`numpy.linspace`

        See: https://numpy.org/doc/stable/reference/generated/numpy.linspace.html
        """
        raise NotImplementedError()

    def meshgrid(self, a, b):
        r"""
        Returns coordinate matrices from coordinate vectors (Numpy convention).

        This function follows the api from :any:`numpy.meshgrid`

        See: https://numpy.org/doc/stable/reference/generated/numpy.meshgrid.html
        """
        raise NotImplementedError()

    def diag(self, a, k=0):
        r"""
        Extracts or constructs a diagonal tensor.

        This function follows the api from :any:`numpy.diag`

        See: https://numpy.org/doc/stable/reference/generated/numpy.diag.html
        """
        raise NotImplementedError()

    def unique(self, a):
        r"""
        Finds unique elements of given tensor.

        This function follows the api from :any:`numpy.unique`

        See: https://numpy.org/doc/stable/reference/generated/numpy.unique.html
        """
        raise NotImplementedError()

    def logsumexp(self, a, axis=None):
        r"""
        Computes the log of the sum of exponentials of input elements.

        This function follows the api from :any:`scipy.special.logsumexp`

        See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.logsumexp.html
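
The quantity is log(sum(exp(a))) along ``axis``. Evaluating it literally overflows
for large entries; the usual remedy, sketched below with Numpy as an assumption about
a reasonable implementation rather than a description of any specific backend,
subtracts the maximum before exponentiating and adds it back afterwards.
``stable_logsumexp`` is a hypothetical name used only for this sketch.

```python
import numpy as np
from scipy.special import logsumexp


def stable_logsumexp(a, axis=None):
    # Shift by the maximum so that the largest exponent is exp(0) = 1,
    # which avoids overflow for large inputs
    a_max = np.max(a, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis, keepdims=True)) + a_max
    # Drop the kept dimension(s) to match the reduced shape
    return np.squeeze(out, axis=axis) if axis is not None else out.item()


a = np.array([[1000.0, 1000.0], [0.0, 0.0]])
print(stable_logsumexp(a, axis=1))  # approximately [1000 + log 2, log 2]
print(np.allclose(stable_logsumexp(a, axis=1), logsumexp(a, axis=1)))  # True
```

A naive ``np.log(np.sum(np.exp(a), axis=1))`` would return ``inf`` for the first row.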
"""
raise NotImplementedError()
def stack(self, arrays, axis=0):
r"""
Joins a sequence of tensors along a new dimension.
This function follows the api from :any:`numpy.stack`
See: https://numpy.org/doc/stable/reference/generated/numpy.stack.html
"""
raise NotImplementedError()
def outer(self, a, b):
r"""
Computes the outer product between two vectors.
This function follows the api from :any:`numpy.outer`
See: https://numpy.org/doc/stable/reference/generated/numpy.outer.html
"""
raise NotImplementedError()
def reshape(self, a, shape):
r"""
Gives a new shape to a tensor without changing its data.
This function follows the api from :any:`numpy.reshape`
See: https://numpy.org/doc/stable/reference/generated/numpy.reshape.html
"""
raise NotImplementedError()
def seed(self, seed=None):
r"""
Sets the seed for the random generator.
This function follows the api from :any:`numpy.random.seed`
See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html
"""
raise NotImplementedError()
def rand(self, *size, type_as=None):
r"""
Generates uniform random numbers in [0, 1).
This function follows the api from :any:`numpy.random.rand`
See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
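Below is a short sketch of the seeding contract, using a `numpy.random.RandomState` instance as the NumPy backend does. Re-seeding with the same value reproduces the same draws, and the shape is passed as separate integers (`*size`).

```python
import numpy as np

rng = np.random.RandomState()

rng.seed(0)
x1 = rng.rand(2, 3)  # shape passed as separate integers, like *size
rng.seed(0)
x2 = rng.rand(2, 3)  # same seed, same draws

print(x1.shape, np.array_equal(x1, x2))  # (2, 3) True
print(bool(np.all((x1 >= 0.0) & (x1 < 1.0))))  # True: samples lie in [0, 1)
```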
"""
raise NotImplementedError()
def randn(self, *size, type_as=None):
r"""
Generates standard normal (Gaussian) random numbers.
This function follows the api from :any:`numpy.random.randn`
See: https://numpy.org/doc/stable/reference/generated/numpy.random.randn.html
"""
raise NotImplementedError()
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
r"""
Creates a sparse tensor in COOrdinate format.
This function follows the api from :any:`scipy.sparse.coo_matrix`
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
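This small SciPy sketch shows the (data, rows, cols) triplet layout: entry k is stored at position (rows[k], cols[k]) with value data[k]. Duplicate coordinates are summed when the matrix is converted to dense form, as the last two triplets show.

```python
import numpy as np
from scipy.sparse import coo_matrix

data = np.array([1.0, 2.0, 3.0, 4.0])
rows = np.array([0, 1, 2, 2])
cols = np.array([2, 0, 1, 1])  # position (2, 1) appears twice

M = coo_matrix((data, (rows, cols)), shape=(3, 3))
dense = M.toarray()  # duplicates are summed: 3.0 + 4.0 at (2, 1)
print(dense)
# [[0. 0. 1.]
#  [2. 0. 0.]
#  [0. 7. 0.]]
```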
"""
raise NotImplementedError()
def issparse(self, a):
r"""
Checks whether or not the input tensor is a sparse tensor.
This function follows the api from :any:`scipy.sparse.issparse`
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html
"""
raise NotImplementedError()
def tocsr(self, a):
r"""
Converts the given tensor to Compressed Sparse Row (CSR) format.
This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr`
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html
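The NumPy backend has two paths: sparse inputs are converted with `.tocsr()`, and dense arrays are wrapped with the `csr_matrix` constructor. A minimal SciPy illustration of both paths:

```python
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix, issparse

A = coo_matrix((np.array([1.0, 2.0]), (np.array([0, 1]), np.array([1, 0]))), shape=(2, 2))
B = A.tocsr()  # sparse input: converted to CSR

# Dense input: wrapped with the csr_matrix constructor instead
C = csr_matrix(np.array([[0.0, 3.0], [4.0, 0.0]]))

print(B.format, C.format, issparse(B), issparse(C))  # csr csr True True
```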
"""
raise NotImplementedError()
def eliminate_zeros(self, a, threshold=0.):
r"""
Removes entries whose absolute value is less than or equal to the given threshold from the sparse tensor (for dense tensors, these entries are set to zero).
This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros`
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html
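This minimal SciPy sketch illustrates the threshold semantics, assuming a CSR matrix. `eliminate_small` is a hypothetical helper mirroring the NumPy backend: stored values whose magnitude is at most `threshold` are set to zero, then SciPy's `eliminate_zeros` drops those explicit zeros from storage.

```python
import numpy as np
from scipy.sparse import csr_matrix


def eliminate_small(a, threshold=0.0):
    # Hypothetical helper (not part of SciPy): modifies a.data in place
    if threshold > 0:
        a.data[np.abs(a.data) <= threshold] = 0
    a.eliminate_zeros()  # remove explicitly stored zeros
    return a


A = csr_matrix(np.array([[1.0, 1e-9], [0.0, -2.0]]))
print(A.nnz)  # 3 stored entries
A = eliminate_small(A, threshold=1e-6)
print(A.nnz)  # 2: the 1e-9 entry is gone
```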
"""
raise NotImplementedError()
def todense(self, a):
r"""
Converts a sparse tensor to a dense tensor.
This function follows the api from :any:`scipy.sparse.csr_matrix.toarray`
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html
"""
raise NotImplementedError()
def where(self, condition, x=None, y=None):
r"""
Returns elements chosen from x or y depending on condition. If only condition is given, returns the indices where condition is True.
This function follows the api from :any:`numpy.where`
See: https://numpy.org/doc/stable/reference/generated/numpy.where.html
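Here is a NumPy illustration of the three-argument selection. It also shows the condition-only form, which the backend implementations accept as well and which returns the indices of the True entries.

```python
import numpy as np

cond = np.array([True, False, True])
x = np.array([1, 2, 3])
y = np.array([10, 20, 30])

print(np.where(cond, x, y))  # [ 1 20  3]: x where cond is True, else y
# With only the condition, the indices of the True entries are returned
print(np.where(cond))  # (array([0, 2]),)
```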
"""
raise NotImplementedError()
def copy(self, a):
r"""
Returns a copy of the given tensor.
This function follows the api from :any:`numpy.copy`
See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html
"""
raise NotImplementedError()
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
r"""
Returns True if two arrays are element-wise equal within a tolerance.
This function follows the api from :any:`numpy.allclose`
See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
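NumPy documents the element-wise test as |a - b| <= atol + rtol * |b|, which is asymmetric in a and b. The sketch below checks that formula against `np.allclose` with the default tolerances.

```python
import numpy as np

a = np.array([1.0, 100.0])
b = np.array([1.0 + 1e-6, 100.0 + 1e-4])
rtol, atol = 1e-05, 1e-08

# Element-wise criterion documented for numpy.allclose (no NaNs involved here)
manual = bool(np.all(np.abs(a - b) <= atol + rtol * np.abs(b)))
print(manual, np.allclose(a, b))  # True True
print(np.allclose([1.0], [1.1]))  # False: difference far above the tolerance
```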
"""
raise NotImplementedError()
def dtype_device(self, a):
r"""
Returns the dtype and the device of the given tensor.
"""
raise NotImplementedError()
def assert_same_dtype_device(self, a, b):
r"""
Checks whether or not the two given inputs have the same dtype as well as the same device.
"""
raise NotImplementedError()
def squeeze(self, a, axis=None):
r"""
Removes axes of length one from `a`.
This function follows the api from :any:`numpy.squeeze`.
See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html
"""
raise NotImplementedError()
def bitsize(self, type_as):
r"""
Gives the number of bits used by the data type of the given tensor.
"""
raise NotImplementedError()
def device_type(self, type_as):
r"""
Returns "CPU" or "GPU" depending on the device where the given tensor is located.
"""
raise NotImplementedError()
def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
r"""
Executes a benchmark of the given callable with the given arguments.
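The CPU implementations follow a common timing pattern: warm-up calls first, then the mean wall-clock time of `n_runs` calls measured with `time.perf_counter`. This is a minimal, backend-agnostic sketch of that pattern. `bench_sketch` is a hypothetical name, and GPU backends must additionally synchronize before reading the clock.

```python
import time


def bench_sketch(callable_, *args, n_runs=1, warmup_runs=1):
    # Warm-up runs absorb one-off costs (caching, lazy initialization, JIT)
    for _ in range(warmup_runs):
        callable_(*args)
    t0 = time.perf_counter()
    for _ in range(n_runs):
        callable_(*args)
    t1 = time.perf_counter()
    return (t1 - t0) / n_runs  # mean seconds per timed call


calls = []
mean_time = bench_sketch(lambda v: calls.append(v), 1, n_runs=5, warmup_runs=2)
print(len(calls), mean_time >= 0)  # 7 True
```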
"""
raise NotImplementedError()
def solve(self, a, b):
r"""
Solves a linear matrix equation, or system of linear scalar equations.
This function follows the api from :any:`numpy.linalg.solve`.
See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html
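This worked NumPy example solves a 2x2 system and checks the residual. Solving directly is generally preferable to forming an explicit inverse.

```python
import numpy as np

# 3 x0 + 1 x1 = 9
# 1 x0 + 2 x1 = 8
a = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([9.0, 8.0])

x = np.linalg.solve(a, b)
print(x)  # [2. 3.]
print(np.allclose(a @ x, b))  # True
```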
"""
raise NotImplementedError()
def trace(self, a):
r"""
Returns the sum along diagonals of the array.
This function follows the api from :any:`numpy.trace`.
See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html
"""
raise NotImplementedError()
def inv(self, a):
r"""
Computes the inverse of a matrix.
This function follows the api from :any:`scipy.linalg.inv`.
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html
"""
raise NotImplementedError()
def sqrtm(self, a):
r"""
Computes the matrix square root. Requires the input to be symmetric positive definite.
This function follows the api from :any:`scipy.linalg.sqrtm`.
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html
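For a symmetric positive (semi-)definite matrix, the eigendecomposition A = V diag(L) V^T gives the square root as V diag(sqrt(L)) V^T. The JAX and PyTorch backends compute it this way with `eigh`. A NumPy sketch, checked against `scipy.linalg.sqrtm`:

```python
import numpy as np
from scipy.linalg import sqrtm


def sqrtm_eigh(a):
    # Valid for symmetric positive (semi-)definite input only
    L, V = np.linalg.eigh(a)
    # Scale each eigenvector column by sqrt(eigenvalue), then recombine
    return (V * np.sqrt(L)[None, :]) @ V.T


A = np.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
S = sqrtm_eigh(A)
print(np.allclose(S @ S, A))     # True
print(np.allclose(S, sqrtm(A)))  # True
```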
"""
raise NotImplementedError()
def isfinite(self, a):
r"""
Tests element-wise for finiteness (not infinity and not Not a Number).
This function follows the api from :any:`numpy.isfinite`.
See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html
"""
raise NotImplementedError()
def array_equal(self, a, b):
r"""
Returns True if two arrays have the same shape and elements, False otherwise.
This function follows the api from :any:`numpy.array_equal`.
See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html
"""
raise NotImplementedError()
def is_floating_point(self, a):
r"""
Returns whether or not the input has a floating-point dtype.
"""
raise NotImplementedError()
class NumpyBackend(Backend):
"""
NumPy implementation of the backend
- `__name__` is "numpy"
- `__type__` is np.ndarray
"""
__name__ = 'numpy'
__type__ = np.ndarray
__type_list__ = [np.array(1, dtype=np.float32),
np.array(1, dtype=np.float64)]
rng_ = np.random.RandomState()
def _to_numpy(self, a):
return a
def _from_numpy(self, a, type_as=None):
if type_as is None:
return a
elif isinstance(a, float):
return a
else:
return a.astype(type_as.dtype)
def set_gradients(self, val, inputs, grads):
# No gradients for numpy
return val
def zeros(self, shape, type_as=None):
if type_as is None:
return np.zeros(shape)
else:
return np.zeros(shape, dtype=type_as.dtype)
def ones(self, shape, type_as=None):
if type_as is None:
return np.ones(shape)
else:
return np.ones(shape, dtype=type_as.dtype)
def arange(self, stop, start=0, step=1, type_as=None):
return np.arange(start, stop, step)
def full(self, shape, fill_value, type_as=None):
if type_as is None:
return np.full(shape, fill_value)
else:
return np.full(shape, fill_value, dtype=type_as.dtype)
def eye(self, N, M=None, type_as=None):
if type_as is None:
return np.eye(N, M)
else:
return np.eye(N, M, dtype=type_as.dtype)
def sum(self, a, axis=None, keepdims=False):
return np.sum(a, axis, keepdims=keepdims)
def cumsum(self, a, axis=None):
return np.cumsum(a, axis)
def max(self, a, axis=None, keepdims=False):
return np.max(a, axis, keepdims=keepdims)
def min(self, a, axis=None, keepdims=False):
return np.min(a, axis, keepdims=keepdims)
def maximum(self, a, b):
return np.maximum(a, b)
def minimum(self, a, b):
return np.minimum(a, b)
def dot(self, a, b):
return np.dot(a, b)
def abs(self, a):
return np.abs(a)
def exp(self, a):
return np.exp(a)
def log(self, a):
return np.log(a)
def sqrt(self, a):
return np.sqrt(a)
def power(self, a, exponents):
return np.power(a, exponents)
def norm(self, a):
return np.sqrt(np.sum(np.square(a)))
def any(self, a):
return np.any(a)
def isnan(self, a):
return np.isnan(a)
def isinf(self, a):
return np.isinf(a)
def einsum(self, subscripts, *operands):
return np.einsum(subscripts, *operands)
def sort(self, a, axis=-1):
return np.sort(a, axis)
def argsort(self, a, axis=-1):
return np.argsort(a, axis)
def searchsorted(self, a, v, side='left'):
if a.ndim == 1:
return np.searchsorted(a, v, side)
else:
# this is not a very efficient way to make numpy
# searchsorted work on 2d arrays (one call per row)
ret = np.empty(v.shape, dtype=int)
for i in range(a.shape[0]):
ret[i, :] = np.searchsorted(a[i, :], v[i, :], side)
return ret
def flip(self, a, axis=None):
return np.flip(a, axis)
def outer(self, a, b):
return np.outer(a, b)
def clip(self, a, a_min, a_max):
return np.clip(a, a_min, a_max)
def repeat(self, a, repeats, axis=None):
return np.repeat(a, repeats, axis)
def take_along_axis(self, arr, indices, axis):
return np.take_along_axis(arr, indices, axis)
def concatenate(self, arrays, axis=0):
return np.concatenate(arrays, axis)
def zero_pad(self, a, pad_width):
return np.pad(a, pad_width)
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
def argmin(self, a, axis=None):
return np.argmin(a, axis=axis)
def mean(self, a, axis=None):
return np.mean(a, axis=axis)
def std(self, a, axis=None):
return np.std(a, axis=axis)
def linspace(self, start, stop, num):
return np.linspace(start, stop, num)
def meshgrid(self, a, b):
return np.meshgrid(a, b)
def diag(self, a, k=0):
return np.diag(a, k)
def unique(self, a):
return np.unique(a)
def logsumexp(self, a, axis=None):
return special.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return np.stack(arrays, axis)
def reshape(self, a, shape):
return np.reshape(a, shape)
def seed(self, seed=None):
if seed is not None:
self.rng_.seed(seed)
def rand(self, *size, type_as=None):
return self.rng_.rand(*size)
def randn(self, *size, type_as=None):
return self.rng_.randn(*size)
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
if type_as is None:
return coo_matrix((data, (rows, cols)), shape=shape)
else:
return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype)
def issparse(self, a):
return issparse(a)
def tocsr(self, a):
if self.issparse(a):
return a.tocsr()
else:
return csr_matrix(a)
def eliminate_zeros(self, a, threshold=0.):
if threshold > 0:
if self.issparse(a):
a.data[self.abs(a.data) <= threshold] = 0
else:
a[self.abs(a) <= threshold] = 0
if self.issparse(a):
a.eliminate_zeros()
return a
def todense(self, a):
if self.issparse(a):
return a.toarray()
else:
return a
def where(self, condition, x=None, y=None):
if x is None and y is None:
return np.where(condition)
else:
return np.where(condition, x, y)
def copy(self, a):
return a.copy()
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def dtype_device(self, a):
if hasattr(a, "dtype"):
return a.dtype, "cpu"
else:
return type(a), "cpu"
def assert_same_dtype_device(self, a, b):
# numpy has implicit type conversion, so this check always passes
pass
def squeeze(self, a, axis=None):
return np.squeeze(a, axis=axis)
def bitsize(self, type_as):
return type_as.itemsize * 8
def device_type(self, type_as):
return "CPU"
def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
results = dict()
for type_as in self.__type_list__:
inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
for _ in range(warmup_runs):
callable(*inputs)
t0 = time.perf_counter()
for _ in range(n_runs):
callable(*inputs)
t1 = time.perf_counter()
key = ("Numpy", self.device_type(type_as), self.bitsize(type_as))
results[key] = (t1 - t0) / n_runs
return results
def solve(self, a, b):
return np.linalg.solve(a, b)
def trace(self, a):
return np.trace(a)
def inv(self, a):
return scipy.linalg.inv(a)
def sqrtm(self, a):
return scipy.linalg.sqrtm(a)
def isfinite(self, a):
return np.isfinite(a)
def array_equal(self, a, b):
return np.array_equal(a, b)
def is_floating_point(self, a):
return a.dtype.kind == "f"
class JaxBackend(Backend):
"""
JAX implementation of the backend
- `__name__` is "jax"
- `__type__` is jax.numpy.ndarray
"""
__name__ = 'jax'
__type__ = jax_type
__type_list__ = None
rng_ = None
def __init__(self):
self.rng_ = jax.random.PRNGKey(42)
self.__type_list__ = []
# available_devices = jax.devices("cpu")
available_devices = []
if xla_bridge.get_backend().platform == "gpu":
available_devices += jax.devices("gpu")
for d in available_devices:
self.__type_list__ += [
jax.device_put(jnp.array(1, dtype=jnp.float32), d),
jax.device_put(jnp.array(1, dtype=jnp.float64), d)
]
def _to_numpy(self, a):
return np.array(a)
def _change_device(self, a, type_as):
return jax.device_put(a, type_as.device_buffer.device())
def _from_numpy(self, a, type_as=None):
if isinstance(a, float):
a = np.array(a)
if type_as is None:
return jnp.array(a)
else:
return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
def set_gradients(self, val, inputs, grads):
from jax.flatten_util import ravel_pytree
val, = jax.lax.stop_gradient((val,))
ravelled_inputs, _ = ravel_pytree(inputs)
ravelled_grads, _ = ravel_pytree(grads)
aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
aux = aux - jax.lax.stop_gradient(aux)
val, = jax.tree_map(lambda z: z + aux, (val,))
return val
def zeros(self, shape, type_as=None):
if type_as is None:
return jnp.zeros(shape)
else:
return self._change_device(jnp.zeros(shape, dtype=type_as.dtype), type_as)
def ones(self, shape, type_as=None):
if type_as is None:
return jnp.ones(shape)
else:
return self._change_device(jnp.ones(shape, dtype=type_as.dtype), type_as)
def arange(self, stop, start=0, step=1, type_as=None):
return jnp.arange(start, stop, step)
def full(self, shape, fill_value, type_as=None):
if type_as is None:
return jnp.full(shape, fill_value)
else:
return self._change_device(jnp.full(shape, fill_value, dtype=type_as.dtype), type_as)
def eye(self, N, M=None, type_as=None):
if type_as is None:
return jnp.eye(N, M)
else:
return self._change_device(jnp.eye(N, M, dtype=type_as.dtype), type_as)
def sum(self, a, axis=None, keepdims=False):
return jnp.sum(a, axis, keepdims=keepdims)
def cumsum(self, a, axis=None):
return jnp.cumsum(a, axis)
def max(self, a, axis=None, keepdims=False):
return jnp.max(a, axis, keepdims=keepdims)
def min(self, a, axis=None, keepdims=False):
return jnp.min(a, axis, keepdims=keepdims)
def maximum(self, a, b):
return jnp.maximum(a, b)
def minimum(self, a, b):
return jnp.minimum(a, b)
def dot(self, a, b):
return jnp.dot(a, b)
def abs(self, a):
return jnp.abs(a)
def exp(self, a):
return jnp.exp(a)
def log(self, a):
return jnp.log(a)
def sqrt(self, a):
return jnp.sqrt(a)
def power(self, a, exponents):
return jnp.power(a, exponents)
def norm(self, a):
return jnp.sqrt(jnp.sum(jnp.square(a)))
def any(self, a):
return jnp.any(a)
def isnan(self, a):
return jnp.isnan(a)
def isinf(self, a):
return jnp.isinf(a)
def einsum(self, subscripts, *operands):
return jnp.einsum(subscripts, *operands)
def sort(self, a, axis=-1):
return jnp.sort(a, axis)
def argsort(self, a, axis=-1):
return jnp.argsort(a, axis)
def searchsorted(self, a, v, side='left'):
if a.ndim == 1:
return jnp.searchsorted(a, v, side)
else:
# this is not a very efficient way to make jax numpy
# searchsorted work on 2d arrays (one call per row)
return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])])
def flip(self, a, axis=None):
return jnp.flip(a, axis)
def outer(self, a, b):
return jnp.outer(a, b)
def clip(self, a, a_min, a_max):
return jnp.clip(a, a_min, a_max)
def repeat(self, a, repeats, axis=None):
return jnp.repeat(a, repeats, axis)
def take_along_axis(self, arr, indices, axis):
return jnp.take_along_axis(arr, indices, axis)
def concatenate(self, arrays, axis=0):
return jnp.concatenate(arrays, axis)
def zero_pad(self, a, pad_width):
return jnp.pad(a, pad_width)
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
def argmin(self, a, axis=None):
return jnp.argmin(a, axis=axis)
def mean(self, a, axis=None):
return jnp.mean(a, axis=axis)
def std(self, a, axis=None):
return jnp.std(a, axis=axis)
def linspace(self, start, stop, num):
return jnp.linspace(start, stop, num)
def meshgrid(self, a, b):
return jnp.meshgrid(a, b)
def diag(self, a, k=0):
return jnp.diag(a, k)
def unique(self, a):
return jnp.unique(a)
def logsumexp(self, a, axis=None):
return jspecial.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return jnp.stack(arrays, axis)
def reshape(self, a, shape):
return jnp.reshape(a, shape)
def seed(self, seed=None):
if seed is not None:
self.rng_ = jax.random.PRNGKey(seed)
def rand(self, *size, type_as=None):
self.rng_, subkey = jax.random.split(self.rng_)
if type_as is not None:
return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype)
else:
return jax.random.uniform(subkey, shape=size)
def randn(self, *size, type_as=None):
self.rng_, subkey = jax.random.split(self.rng_)
if type_as is not None:
return jax.random.normal(subkey, shape=size, dtype=type_as.dtype)
else:
return jax.random.normal(subkey, shape=size)
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
# Currently, JAX does not support sparse matrices
data = self.to_numpy(data)
rows = self.to_numpy(rows)
cols = self.to_numpy(cols)
nx = NumpyBackend()
coo_matrix = nx.coo_matrix(data, rows, cols, shape=shape, type_as=type_as)
matrix = nx.todense(coo_matrix)
return self.from_numpy(matrix)
def issparse(self, a):
# Currently, JAX does not support sparse matrices
return False
def tocsr(self, a):
# Currently, JAX does not support sparse matrices
return a
def eliminate_zeros(self, a, threshold=0.):
# Currently, JAX does not support sparse matrices
if threshold > 0:
return self.where(
self.abs(a) <= threshold,
self.zeros((1,), type_as=a),
a
)
return a
def todense(self, a):
# Currently, JAX does not support sparse matrices
return a
def where(self, condition, x=None, y=None):
if x is None and y is None:
return jnp.where(condition)
else:
return jnp.where(condition, x, y)
def copy(self, a):
# No need to copy, JAX arrays are immutable
return a
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
return jnp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def dtype_device(self, a):
return a.dtype, a.device_buffer.device()
def assert_same_dtype_device(self, a, b):
a_dtype, a_device = self.dtype_device(a)
b_dtype, b_device = self.dtype_device(b)
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
def squeeze(self, a, axis=None):
return jnp.squeeze(a, axis=axis)
def bitsize(self, type_as):
return type_as.dtype.itemsize * 8
def device_type(self, type_as):
return self.dtype_device(type_as)[1].platform.upper()
def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
results = dict()
for type_as in self.__type_list__:
inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
for _ in range(warmup_runs):
a = callable(*inputs)
a.block_until_ready()
t0 = time.perf_counter()
for _ in range(n_runs):
a = callable(*inputs)
a.block_until_ready()
t1 = time.perf_counter()
key = ("Jax", self.device_type(type_as), self.bitsize(type_as))
results[key] = (t1 - t0) / n_runs
return results
def solve(self, a, b):
return jnp.linalg.solve(a, b)
def trace(self, a):
return jnp.trace(a)
def inv(self, a):
return jnp.linalg.inv(a)
def sqrtm(self, a):
L, V = jnp.linalg.eigh(a)
return (V * jnp.sqrt(L)[None, :]) @ V.T
def isfinite(self, a):
return jnp.isfinite(a)
def array_equal(self, a, b):
return jnp.array_equal(a, b)
def is_floating_point(self, a):
return a.dtype.kind == "f"
class TorchBackend(Backend):
"""
PyTorch implementation of the backend
- `__name__` is "torch"
- `__type__` is torch.Tensor
"""
__name__ = 'torch'
__type__ = torch_type
__type_list__ = None
rng_ = None
def __init__(self):
self.rng_ = torch.Generator()
self.rng_.seed()
self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
torch.tensor(1, dtype=torch.float64)]
if torch.cuda.is_available():
self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
from torch.autograd import Function
# define a function that takes inputs val and grads
# and returns a val tensor with proper gradients
class ValFunction(Function):
@staticmethod
def forward(ctx, val, grads, *inputs):
ctx.grads = grads
return val
@staticmethod
def backward(ctx, grad_output):
# the gradients are the stored ctx.grads, each scaled by grad_output
return (None, None) + tuple(g * grad_output for g in ctx.grads)
self.ValFunction = ValFunction
def _to_numpy(self, a):
return a.cpu().detach().numpy()
def _from_numpy(self, a, type_as=None):
if isinstance(a, float):
a = np.array(a)
if type_as is None:
return torch.from_numpy(a)
else:
return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device)
def set_gradients(self, val, inputs, grads):
Func = self.ValFunction
res = Func.apply(val, grads, *inputs)
return res
def zeros(self, shape, type_as=None):
if isinstance(shape, int):
shape = (shape,)
if type_as is None:
return torch.zeros(shape)
else:
return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
def ones(self, shape, type_as=None):
if isinstance(shape, int):
shape = (shape,)
if type_as is None:
return torch.ones(shape)
else:
return torch.ones(shape, dtype=type_as.dtype, device=type_as.device)
def arange(self, stop, start=0, step=1, type_as=None):
if type_as is None:
return torch.arange(start, stop, step)
else:
return torch.arange(start, stop, step, device=type_as.device)
def full(self, shape, fill_value, type_as=None):
if isinstance(shape, int):
shape = (shape,)
if type_as is None:
return torch.full(shape, fill_value)
else:
return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device)
def eye(self, N, M=None, type_as=None):
if M is None:
M = N
if type_as is None:
return torch.eye(N, m=M)
else:
return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device)
def sum(self, a, axis=None, keepdims=False):
if axis is None:
return torch.sum(a)
else:
return torch.sum(a, axis, keepdim=keepdims)
def cumsum(self, a, axis=None):
if axis is None:
return torch.cumsum(a.flatten(), 0)
else:
return torch.cumsum(a, axis)
def max(self, a, axis=None, keepdims=False):
if axis is None:
return torch.max(a)
else:
return torch.max(a, axis, keepdim=keepdims)[0]
def min(self, a, axis=None, keepdims=False):
if axis is None:
return torch.min(a)
else:
return torch.min(a, axis, keepdim=keepdims)[0]
def maximum(self, a, b):
if isinstance(a, int) or isinstance(a, float):
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
if isinstance(b, int) or isinstance(b, float):
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
if hasattr(torch, "maximum"):
return torch.maximum(a, b)
else:
return torch.max(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
def minimum(self, a, b):
if isinstance(a, int) or isinstance(a, float):
a = torch.tensor([float(a)], dtype=b.dtype, device=b.device)
if isinstance(b, int) or isinstance(b, float):
b = torch.tensor([float(b)], dtype=a.dtype, device=a.device)
if hasattr(torch, "minimum"):
return torch.minimum(a, b)
else:
return torch.min(torch.stack(torch.broadcast_tensors(a, b)), axis=0)[0]
def dot(self, a, b):
return torch.matmul(a, b)
def abs(self, a):
return torch.abs(a)
def exp(self, a):
return torch.exp(a)
def log(self, a):
return torch.log(a)
def sqrt(self, a):
return torch.sqrt(a)
def power(self, a, exponents):
return torch.pow(a, exponents)
def norm(self, a):
return torch.sqrt(torch.sum(torch.square(a)))
def any(self, a):
return torch.any(a)
def isnan(self, a):
return torch.isnan(a)
def isinf(self, a):
return torch.isinf(a)
def einsum(self, subscripts, *operands):
return torch.einsum(subscripts, *operands)
def sort(self, a, axis=-1):
sorted0, indices = torch.sort(a, dim=axis)
return sorted0
def argsort(self, a, axis=-1):
sorted, indices = torch.sort(a, dim=axis)
return indices
def searchsorted(self, a, v, side='left'):
right = (side != 'left')
return torch.searchsorted(a, v, right=right)
def flip(self, a, axis=None):
if axis is None:
return torch.flip(a, tuple(i for i in range(len(a.shape))))
if isinstance(axis, int):
return torch.flip(a, (axis,))
else:
return torch.flip(a, dims=axis)
def outer(self, a, b):
return torch.outer(a, b)
def clip(self, a, a_min, a_max):
return torch.clamp(a, a_min, a_max)
def repeat(self, a, repeats, axis=None):
return torch.repeat_interleave(a, repeats, dim=axis)
def take_along_axis(self, arr, indices, axis):
return torch.gather(arr, axis, indices)
def concatenate(self, arrays, axis=0):
return torch.cat(arrays, dim=axis)
def zero_pad(self, a, pad_width):
from torch.nn.functional import pad
# pad_width is an array of ndim tuples indicating how many 0 before and after
# we need to add. We first need to make it compliant with torch syntax, that
# starts with the last dim, then second last, etc.
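# For example, a 2-D pad_width ((1, 2), (3, 4)) (1 row before, 2 after;
# 3 columns before, 4 after) becomes how_pad = (3, 4, 1, 2).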
how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl)
return pad(a, how_pad)
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
def argmin(self, a, axis=None):
return torch.argmin(a, dim=axis)
def mean(self, a, axis=None):
if axis is not None:
return torch.mean(a, dim=axis)
else:
return torch.mean(a)
def std(self, a, axis=None):
if axis is not None:
return torch.std(a, dim=axis, unbiased=False)
else:
return torch.std(a, unbiased=False)
def linspace(self, start, stop, num):
return torch.linspace(start, stop, num, dtype=torch.float64)
def meshgrid(self, a, b):
try:
return torch.meshgrid(a, b, indexing="xy")
except TypeError:
X, Y = torch.meshgrid(a, b)
return X.T, Y.T
def diag(self, a, k=0):
return torch.diag(a, diagonal=k)
def unique(self, a):
return torch.unique(a)
def logsumexp(self, a, axis=None):
if axis is not None:
return torch.logsumexp(a, dim=axis)
else:
return torch.logsumexp(a, dim=tuple(range(len(a.shape))))
def stack(self, arrays, axis=0):
return torch.stack(arrays, dim=axis)
def reshape(self, a, shape):
return torch.reshape(a, shape)
def seed(self, seed=None):
if isinstance(seed, int):
self.rng_.manual_seed(seed)
elif isinstance(seed, torch.Generator):
self.rng_ = seed
else:
raise ValueError("Non compatible seed : {}".format(seed))
def rand(self, *size, type_as=None):
if type_as is not None:
return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
else:
return torch.rand(size=size, generator=self.rng_)
def randn(self, *size, type_as=None):
if type_as is not None:
return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
else:
return torch.randn(size=size, generator=self.rng_)
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
if type_as is None:
return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)
else:
return torch.sparse_coo_tensor(
torch.stack([rows, cols]), data, size=shape,
dtype=type_as.dtype, device=type_as.device
)
def issparse(self, a):
return getattr(a, "is_sparse", False) or getattr(a, "is_sparse_csr", False)
def tocsr(self, a):
# Versions older than 1.9 do not support CSR tensors, and PyTorch 1.9 and 1.10 offer only very limited support, so a dense tensor is returned instead
return self.todense(a)
def eliminate_zeros(self, a, threshold=0.):
if self.issparse(a):
if threshold > 0:
mask = self.abs(a) <= threshold
mask = ~mask
mask = mask.nonzero()
else:
mask = a._values().nonzero()
nv = a._values().index_select(0, mask.view(-1))
ni = a._indices().index_select(1, mask.view(-1))
return self.coo_matrix(nv, ni[0], ni[1], shape=a.shape, type_as=a)
else:
if threshold > 0:
a[self.abs(a) <= threshold] = 0
return a
def todense(self, a):
if self.issparse(a):
return a.to_dense()
else:
return a
def where(self, condition, x=None, y=None):
if x is None and y is None:
return torch.where(condition)
else:
return torch.where(condition, x, y)
def copy(self, a):
return torch.clone(a)
def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def dtype_device(self, a):
return a.dtype, a.device
def assert_same_dtype_device(self, a, b):
a_dtype, a_device = self.dtype_device(a)
b_dtype, b_device = self.dtype_device(b)
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
def squeeze(self, a, axis=None):
if axis is None:
return torch.squeeze(a)
else:
return torch.squeeze(a, dim=axis)
def bitsize(self, type_as):
return torch.finfo(type_as.dtype).bits
def device_type(self, type_as):
return type_as.device.type.replace("cuda", "gpu").upper()
def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
results = dict()
for type_as in self.__type_list__:
inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
for _ in range(warmup_runs):
callable(*inputs)
if self.device_type(type_as) == "GPU": # pragma: no cover
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
else:
start = time.perf_counter()
for _ in range(n_runs):
callable(*inputs)
if self.device_type(type_as) == "GPU": # pragma: no cover
end.record()
torch.cuda.synchronize()
duration = start.elapsed_time(end) / 1000.
else:
end = time.perf_counter()
duration = end - start
key = ("Pytorch", self.device_type(type_as), self.bitsize(type_as))
results[key] = duration / n_runs
if torch.cuda.is_available():
torch.cuda.empty_cache()
return results
def solve(self, a, b):
return torch.linalg.solve(a, b)
def trace(self, a):
return torch.trace(a)
def inv(self, a):
return torch.linalg.inv(a)
def sqrtm(self, a):
L, V = torch.linalg.eigh(a)
return (V * torch.sqrt(L)[None, :]) @ V.T
def isfinite(self, a):
return torch.isfinite(a)
def array_equal(self, a, b):
return torch.equal(a, b)
def is_floating_point(self, a):
return a.dtype.is_floating_point
class CupyBackend(Backend): # pragma: no cover
"""
CuPy implementation of the backend
- `__name__` is "cupy"
- `__type__` is cp.ndarray
"""
__name__ = 'cupy'
__type__ = cp_type
__type_list__ = None
rng_ = None
def __init__(self):
self.rng_ = cp.random.RandomState()
self.__type_list__ = [
cp.array(1, dtype=cp.float32),
cp.array(1, dtype=cp.float64)
]
def _to_numpy(self, a):
return cp.asnumpy(a)
def _from_numpy(self, a, type_as=None):
if isinstance(a, float):
a = np.array(a)
if type_as is None:
return cp.asarray(a)
else:
with cp.cuda.Device(type_as.device):
return cp.asarray(a, dtype=type_as.dtype)
def set_gradients(self, val, inputs, grads):
# No gradients for cupy
return val
def zeros(self, shape, type_as=None):
if isinstance(shape, (list, tuple)):
shape = tuple(int(i) for i in shape)
if type_as is None:
return cp.zeros(shape)
else:
with cp.cuda.Device(type_as.device):
return cp.zeros(shape, dtype=type_as.dtype)
def ones(self, shape, type_as=None):
if isinstance(shape, (list, tuple)):
shape = tuple(int(i) for i in shape)
if type_as is None:
return cp.ones(shape)
else:
with cp.cuda.Device(type_as.device):
return cp.ones(shape, dtype=type_as.dtype)
def arange(self, stop, start=0, step=1, type_as=None):
return cp.arange(start, stop, step)
def full(self, shape, fill_value, type_as=None):
if isinstance(shape, (list, tuple)):
shape = tuple(int(i) for i in shape)
if type_as is None:
return cp.full(shape, fill_value)
else:
with cp.cuda.Device(type_as.device):
return cp.full(shape, fill_value, dtype=type_as.dtype)
def eye(self, N, M=None, type_as=None):
if type_as is None:
return cp.eye(N, M)
else:
with cp.cuda.Device(type_as.device):
return cp.eye(N, M, dtype=type_as.dtype)
def sum(self, a, axis=None, keepdims=False):
return cp.sum(a, axis, keepdims=keepdims)
def cumsum(self, a, axis=None):
return cp.cumsum(a, axis)
def max(self, a, axis=None, keepdims=False):
return cp.max(a, axis, keepdims=keepdims)
def min(self, a, axis=None, keepdims=False):
return cp.min(a, axis, keepdims=keepdims)
def maximum(self, a, b):
return cp.maximum(a, b)
def minimum(self, a, b):
return cp.minimum(a, b)
def abs(self, a):
return cp.abs(a)
def exp(self, a):
return cp.exp(a)
def log(self, a):
return cp.log(a)
def sqrt(self, a):
return cp.sqrt(a)
def power(self, a, exponents):
return cp.power(a, exponents)
def dot(self, a, b):
return cp.dot(a, b)
def norm(self, a):
return cp.sqrt(cp.sum(cp.square(a)))
def any(self, a):
return cp.any(a)
def isnan(self, a):
return cp.isnan(a)
def isinf(self, a):
return cp.isinf(a)
def einsum(self, subscripts, *operands):
return cp.einsum(subscripts, *operands)
def sort(self, a, axis=-1):
return cp.sort(a, axis)
def argsort(self, a, axis=-1):
return cp.argsort(a, axis)

    def searchsorted(self, a, v, side='left'):
        if a.ndim == 1:
            return cp.searchsorted(a, v, side)
        else:
            # cp.searchsorted expects a 1D sorted array, so 2D inputs are
            # handled row by row; this Python loop is not efficient
            ret = cp.empty(v.shape, dtype=int)
            for i in range(a.shape[0]):
                ret[i, :] = cp.searchsorted(a[i, :], v[i, :], side)
            return ret

    def flip(self, a, axis=None):
        return cp.flip(a, axis)

    def outer(self, a, b):
        return cp.outer(a, b)

    def clip(self, a, a_min, a_max):
        return cp.clip(a, a_min, a_max)

    def repeat(self, a, repeats, axis=None):
        return cp.repeat(a, repeats, axis)

    def take_along_axis(self, arr, indices, axis):
        return cp.take_along_axis(arr, indices, axis)

    def concatenate(self, arrays, axis=0):
        return cp.concatenate(arrays, axis)

    def zero_pad(self, a, pad_width):
        return cp.pad(a, pad_width)

    def argmax(self, a, axis=None):
        return cp.argmax(a, axis=axis)

    def argmin(self, a, axis=None):
        return cp.argmin(a, axis=axis)

    def mean(self, a, axis=None):
        return cp.mean(a, axis=axis)

    def std(self, a, axis=None):
        return cp.std(a, axis=axis)

    def linspace(self, start, stop, num):
        return cp.linspace(start, stop, num)

    def meshgrid(self, a, b):
        return cp.meshgrid(a, b)

    def diag(self, a, k=0):
        return cp.diag(a, k)

    def unique(self, a):
        return cp.unique(a)

    def logsumexp(self, a, axis=None):
        # Taken from
        # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
        a_max = cp.amax(a, axis=axis, keepdims=True)

        if a_max.ndim > 0:
            a_max[~cp.isfinite(a_max)] = 0
        elif not cp.isfinite(a_max):
            a_max = 0

        tmp = cp.exp(a - a_max)
        s = cp.sum(tmp, axis=axis)
        out = cp.log(s)
        a_max = cp.squeeze(a_max, axis=axis)
        out += a_max
        return out

    def stack(self, arrays, axis=0):
        return cp.stack(arrays, axis)

    def reshape(self, a, shape):
        return cp.reshape(a, shape)

    def seed(self, seed=None):
        if seed is not None:
            self.rng_.seed(seed)

    def rand(self, *size, type_as=None):
        if type_as is None:
            return self.rng_.rand(*size)
        else:
            with cp.cuda.Device(type_as.device):
                return self.rng_.rand(*size, dtype=type_as.dtype)

    def randn(self, *size, type_as=None):
        if type_as is None:
            return self.rng_.randn(*size)
        else:
            with cp.cuda.Device(type_as.device):
                return self.rng_.randn(*size, dtype=type_as.dtype)

    def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
        data = self.from_numpy(data)
        rows = self.from_numpy(rows)
        cols = self.from_numpy(cols)
        if type_as is None:
            return cupyx.scipy.sparse.coo_matrix(
                (data, (rows, cols)), shape=shape
            )
        else:
            with cp.cuda.Device(type_as.device):
                return cupyx.scipy.sparse.coo_matrix(
                    (data, (rows, cols)), shape=shape, dtype=type_as.dtype
                )

    def issparse(self, a):
        return cupyx.scipy.sparse.issparse(a)

    def tocsr(self, a):
        if self.issparse(a):
            return a.tocsr()
        else:
            return cupyx.scipy.sparse.csr_matrix(a)

    def eliminate_zeros(self, a, threshold=0.):
        if threshold > 0:
            if self.issparse(a):
                a.data[self.abs(a.data) <= threshold] = 0
            else:
                a[self.abs(a) <= threshold] = 0
        if self.issparse(a):
            a.eliminate_zeros()
        return a

    def todense(self, a):
        if self.issparse(a):
            return a.toarray()
        else:
            return a

    def where(self, condition, x=None, y=None):
        if x is None and y is None:
            return cp.where(condition)
        else:
            return cp.where(condition, x, y)

    def copy(self, a):
        return a.copy()

    def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)

    def dtype_device(self, a):
        return a.dtype, a.device

    def assert_same_dtype_device(self, a, b):
        a_dtype, a_device = self.dtype_device(a)
        b_dtype, b_device = self.dtype_device(b)

        # CuPy performs implicit type promotion, so dtypes are not
        # compared here; only the devices must match
        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"

    def squeeze(self, a, axis=None):
        return cp.squeeze(a, axis=axis)

    def bitsize(self, type_as):
        return type_as.itemsize * 8

    def device_type(self, type_as):
        return "GPU"

    def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
        mempool = cp.get_default_memory_pool()
        pinned_mempool = cp.get_default_pinned_memory_pool()

        results = dict()
        for type_as in self.__type_list__:
            inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
            start_gpu = cp.cuda.Event()
            end_gpu = cp.cuda.Event()
            for _ in range(warmup_runs):
                callable(*inputs)
            # Wait for the warm-up runs to finish before starting the timer
            cp.cuda.get_current_stream().synchronize()
            start_gpu.record()
            for _ in range(n_runs):
                callable(*inputs)
            end_gpu.record()
            end_gpu.synchronize()
            key = ("Cupy", self.device_type(type_as), self.bitsize(type_as))
            # get_elapsed_time returns milliseconds; convert to seconds
            t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000.
            results[key] = t_gpu / n_runs
        mempool.free_all_blocks()
        pinned_mempool.free_all_blocks()
        return results

    def solve(self, a, b):
        return cp.linalg.solve(a, b)

    def trace(self, a):
        return cp.trace(a)

    def inv(self, a):
        return cp.linalg.inv(a)

    def sqrtm(self, a):
        # Square root of a symmetric matrix through its eigendecomposition:
        # a = V diag(L) V^T, hence sqrtm(a) = V diag(sqrt(L)) V^T
        L, V = cp.linalg.eigh(a)
        return (V * self.sqrt(L)[None, :]) @ V.T

    def isfinite(self, a):
        return cp.isfinite(a)

    def array_equal(self, a, b):
        return cp.array_equal(a, b)

    def is_floating_point(self, a):
        return a.dtype.kind == "f"


class TensorflowBackend(Backend):

    __name__ = "tf"
    __type__ = tf_type
    __type_list__ = None

    rng_ = None

    def __init__(self):
        self.seed(None)

        self.__type_list__ = [
            tf.convert_to_tensor([1], dtype=tf.float32),
            tf.convert_to_tensor([1], dtype=tf.float64)
        ]

        tmp = self.randn(15, 10)
        try:
            tmp.reshape((150, 1))
        except AttributeError:
            warnings.warn(
                "To use TensorflowBackend, you need to activate the tensorflow "
                "numpy API. You can activate it by running: \n"
                "from tensorflow.python.ops.numpy_ops import np_config\n"
                "np_config.enable_numpy_behavior()",
                stacklevel=2
            )

    def _to_numpy(self, a):
        return a.numpy()

    def _from_numpy(self, a, type_as=None):
        if isinstance(a, float):
            a = np.array(a)
        if not isinstance(a, self.__type__):
            if type_as is None:
                return tf.convert_to_tensor(a)
            else:
                return tf.convert_to_tensor(a, dtype=type_as.dtype)
        else:
            if type_as is None:
                return a
            else:
                return tf.cast(a, dtype=type_as.dtype)

    def set_gradients(self, val, inputs, grads):
        @tf.custom_gradient
        def tmp(input):
            def grad(upstream):
                return grads
            return val, grad
        return tmp(inputs)

    def zeros(self, shape, type_as=None):
        if type_as is None:
            return tnp.zeros(shape)
        else:
            return tnp.zeros(shape, dtype=type_as.dtype)

    def ones(self, shape, type_as=None):
        if type_as is None:
            return tnp.ones(shape)
        else:
            return tnp.ones(shape, dtype=type_as.dtype)

    def arange(self, stop, start=0, step=1, type_as=None):
        return tnp.arange(start, stop, step)

    def full(self, shape, fill_value, type_as=None):
        if type_as is None:
            return tnp.full(shape, fill_value)
        else:
            return tnp.full(shape, fill_value, dtype=type_as.dtype)

    def eye(self, N, M=None, type_as=None):
        if type_as is None:
            return tnp.eye(N, M)
        else:
            return tnp.eye(N, M, dtype=type_as.dtype)

    def sum(self, a, axis=None, keepdims=False):
        return tnp.sum(a, axis, keepdims=keepdims)

    def cumsum(self, a, axis=None):
        return tnp.cumsum(a, axis)

    def max(self, a, axis=None, keepdims=False):
        return tnp.max(a, axis, keepdims=keepdims)

    def min(self, a, axis=None, keepdims=False):
        return tnp.min(a, axis, keepdims=keepdims)

    def maximum(self, a, b):
        return tnp.maximum(a, b)

    def minimum(self, a, b):
        return tnp.minimum(a, b)

    def dot(self, a, b):
        if len(b.shape) == 1:
            if len(a.shape) == 1:
                # inner product
                return tf.reduce_sum(tf.multiply(a, b))
            else:
                # matrix-vector product
                return tf.linalg.matvec(a, b)
        else:
            if len(a.shape) == 1:
                # vector-matrix product: a @ b is computed as b^T a
                return tf.linalg.matvec(b.T, a.T).T
            else:
                # matrix-matrix product
                return tf.matmul(a, b)

    def abs(self, a):
        return tnp.abs(a)

    def exp(self, a):
        return tnp.exp(a)

    def log(self, a):
        return tnp.log(a)

    def sqrt(self, a):
        return tnp.sqrt(a)

    def power(self, a, exponents):
        return tnp.power(a, exponents)

    def norm(self, a):
        return tf.math.reduce_euclidean_norm(a)

    def any(self, a):
        return tnp.any(a)

    def isnan(self, a):
        return tnp.isnan(a)

    def isinf(self, a):
        return tnp.isinf(a)

    def einsum(self, subscripts, *operands):
        return tnp.einsum(subscripts, *operands)

    def sort(self, a, axis=-1):
        return tnp.sort(a, axis)

    def argsort(self, a, axis=-1):
        return tnp.argsort(a, axis)

    def searchsorted(self, a, v, side='left'):
        return tf.searchsorted(a, v, side=side)

    def flip(self, a, axis=None):
        return tnp.flip(a, axis)

    def outer(self, a, b):
        return tnp.outer(a, b)

    def clip(self, a, a_min, a_max):
        return tnp.clip(a, a_min, a_max)

    def repeat(self, a, repeats, axis=None):
        return tnp.repeat(a, repeats, axis)

    def take_along_axis(self, arr, indices, axis):
        return tnp.take_along_axis(arr, indices, axis)

    def concatenate(self, arrays, axis=0):
        return tnp.concatenate(arrays, axis)

    def zero_pad(self, a, pad_width):
        return tnp.pad(a, pad_width, mode="constant")

    def argmax(self, a, axis=None):
        return tnp.argmax(a, axis=axis)

    def argmin(self, a, axis=None):
        return tnp.argmin(a, axis=axis)

    def mean(self, a, axis=None):
        return tnp.mean(a, axis=axis)

    def std(self, a, axis=None):
        return tnp.std(a, axis=axis)

    def linspace(self, start, stop, num):
        return tnp.linspace(start, stop, num)

    def meshgrid(self, a, b):
        return tnp.meshgrid(a, b)

    def diag(self, a, k=0):
        return tnp.diag(a, k)

    def unique(self, a):
        return tf.sort(tf.unique(tf.reshape(a, [-1]))[0])

    def logsumexp(self, a, axis=None):
        return tf.math.reduce_logsumexp(a, axis=axis)

    def stack(self, arrays, axis=0):
        return tnp.stack(arrays, axis)

    def reshape(self, a, shape):
        return tnp.reshape(a, shape)

    def seed(self, seed=None):
        if isinstance(seed, int):
            self.rng_ = tf.random.Generator.from_seed(seed)
        elif isinstance(seed, tf.random.Generator):
            self.rng_ = seed
        elif seed is None:
            self.rng_ = tf.random.Generator.from_non_deterministic_state()
        else:
            raise ValueError("Incompatible seed: {}".format(seed))

    def rand(self, *size, type_as=None):
        if type_as is None:
            return self.rng_.uniform(size, minval=0., maxval=1.)
        else:
            return self.rng_.uniform(
                size, minval=0., maxval=1., dtype=type_as.dtype
            )

    def randn(self, *size, type_as=None):
        if type_as is None:
            return self.rng_.normal(size)
        else:
            return self.rng_.normal(size, dtype=type_as.dtype)

    def _convert_to_index_for_coo(self, tensor):
        if isinstance(tensor, self.__type__):
            return int(self.max(tensor)) + 1
        else:
            return int(np.max(tensor)) + 1

    def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
        if shape is None:
            shape = (
                self._convert_to_index_for_coo(rows),
                self._convert_to_index_for_coo(cols)
            )
        if type_as is not None:
            data = self.from_numpy(data, type_as=type_as)

        sparse_tensor = tf.sparse.SparseTensor(
            indices=tnp.stack([rows, cols]).T,
            values=data,
            dense_shape=shape
        )
        # SparseTensor objects are not subscriptable, so a dense tensor
        # is returned instead
        return self.todense(sparse_tensor)

    def issparse(self, a):
        return isinstance(a, tf.sparse.SparseTensor)

    def tocsr(self, a):
        return a

    def eliminate_zeros(self, a, threshold=0.):
        if self.issparse(a):
            values = a.values
            if threshold > 0:
                mask = self.abs(values) <= threshold
            else:
                mask = values == 0
            return tf.sparse.retain(a, ~mask)
        else:
            if threshold > 0:
                a = tnp.where(self.abs(a) > threshold, a, 0.)
            return a

    def todense(self, a):
        if self.issparse(a):
            return tf.sparse.to_dense(tf.sparse.reorder(a))
        else:
            return a

    def where(self, condition, x=None, y=None):
        if x is None and y is None:
            return tnp.where(condition)
        else:
            return tnp.where(condition, x, y)

    def copy(self, a):
        return tf.identity(a)

    def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        return tnp.allclose(
            a, b, rtol=rtol, atol=atol, equal_nan=equal_nan
        )

    def dtype_device(self, a):
        return a.dtype, a.device.split("device:")[1]

    def assert_same_dtype_device(self, a, b):
        a_dtype, a_device = self.dtype_device(a)
        b_dtype, b_device = self.dtype_device(b)

        assert a_dtype == b_dtype, "Dtype discrepancy"
        assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"

    def squeeze(self, a, axis=None):
        return tnp.squeeze(a, axis=axis)

    def bitsize(self, type_as):
        return type_as.dtype.size * 8

    def device_type(self, type_as):
        return self.dtype_device(type_as)[1].split(":")[0]

    def _bench(self, callable, *args, n_runs=1, warmup_runs=1):
        results = dict()
        device_contexts = [tf.device("/CPU:0")]
        if len(tf.config.list_physical_devices('GPU')) > 0:  # pragma: no cover
            device_contexts.append(tf.device("/GPU:0"))

        for device_context in device_contexts:
            with device_context:
                for type_as in self.__type_list__:
                    inputs = [self.from_numpy(arg, type_as=type_as) for arg in args]
                    for _ in range(warmup_runs):
                        callable(*inputs)
                    t0 = time.perf_counter()
                    for _ in range(n_runs):
                        res = callable(*inputs)
                    # Converting to NumPy blocks until the computation has finished
                    _ = res.numpy()
                    t1 = time.perf_counter()
                    key = (
                        "Tensorflow",
                        self.device_type(inputs[0]),
                        self.bitsize(type_as)
                    )
                    results[key] = (t1 - t0) / n_runs
        return results

    def solve(self, a, b):
        return tf.linalg.solve(a, b)

    def trace(self, a):
        return tf.linalg.trace(a)

    def inv(self, a):
        return tf.linalg.inv(a)

    def sqrtm(self, a):
        return tf.linalg.sqrtm(a)

    def isfinite(self, a):
        return tnp.isfinite(a)

    def array_equal(self, a, b):
        return tnp.array_equal(a, b)

    def is_floating_point(self, a):
        return a.dtype.is_floating