# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Union, Tuple, List, Any, Optional
import torch
from functools import partial, wraps
import contextlib
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
from .pytree_hacks import tree_map_, treespec_pprint
import torch.autograd.forward_ad as fwAD
from .vmap import vmap, doesnt_support_saved_tensors_hooks
from torch._decomp import decomposition_table
from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    _grad_increment_nesting,
    _grad_decrement_nesting,
    _jvp_increment_nesting,
    _jvp_decrement_nesting,
    set_fwd_grad_enabled,
    get_fwd_grad_enabled,
    _wrap_functional_tensor,
    _unwrap_functional_tensor,
    _func_decrement_nesting,
    _func_increment_nesting,
    _assert_wrapped_functional,
    _propagate_functional_input_mutation,
    set_inplace_requires_grad_allowed,
    get_inplace_requires_grad_allowed
)
argnums_t = Union[int, Tuple[int, ...]]
@contextlib.contextmanager
def enable_inplace_requires_grad(enabled=True):
    prev_state = get_inplace_requires_grad_allowed()
    set_inplace_requires_grad_allowed(enabled)
    try:
        yield
    finally:
        set_inplace_requires_grad_allowed(prev_state)


def _create_differentiable(inps, level=None):
    def create_differentiable(x):
        if isinstance(x, torch.Tensor):
            with enable_inplace_requires_grad():
                return x.requires_grad_()
        raise ValueError(f'Thing passed to transform API must be Tensor, '
                         f'got {type(x)}')
    return tree_map(create_differentiable, inps)


def _undo_create_differentiable(inps, level=None):
    def unwrap_tensors(x):
        if isinstance(x, torch.Tensor):
            return _unwrap_for_grad(x, level)
        # TODO: Remove the following hack for namedtuples
        if isinstance(x, tuple):
            return tree_map(unwrap_tensors, tuple(x))

        raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")

    return tree_map(unwrap_tensors, inps)


def _is_differentiable(maybe_tensor):
    if not isinstance(maybe_tensor, torch.Tensor):
        return False
    return maybe_tensor.requires_grad
def _any_differentiable(tensor_or_tuple_of_tensors):
    # tree_flatten (not tree_unflatten) turns the pytree into a flat list of leaves.
    flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
    return any(tuple(map(_is_differentiable, flat_args)))
def _wrap_tensor_for_grad(maybe_tensor, level):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    return _wrap_for_grad(maybe_tensor, level)


def _wrap_all_tensors(tensor_pytree, level):
    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)


# Version of autograd.grad that handles outputs that don't depend on inputs
def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
        if len(result) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*result)
    if len(diff_outputs) == 0:
        return tuple(torch.zeros_like(inp) for inp in inputs)
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
                                      retain_graph=retain_graph,
                                      create_graph=create_graph,
                                      allow_unused=True)
    grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
                        for gi, inp in zip(grad_inputs, inputs))
    return grad_inputs
# NOTE [grad and vjp interaction with no_grad]
#
# def f(x):
#   with torch.no_grad():
#     c = x ** 2
#   return x - c
#
# The thing to consider is if enable_grad is on/off before grad gets called.
#
# Case 1: enable_grad is on.
#   grad(f)(x)
#   In this case, `grad` should respect the inner torch.no_grad.
#
# Case 2: enable_grad is off
#   with torch.no_grad():
#     grad(f)(x)
#   In this case, `grad` should respect the inner torch.no_grad, but not the
#   outer one. This is because `grad` is a "function transform": its result
#   should not depend on the result of a context manager outside of `f`.
#
# This gives us the following desired behavior:
# - (nested) grad transforms must obey torch.no_grad inside them
# - (nested) grad transforms should not obey torch.no_grad outside them
#
# To achieve this behavior, upon entering grad/vjp:
# - we save the current ("previous") is_grad_enabled (*)
# - we unconditionally enable grad.
#
# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
# off the stack:
# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
#   active, all subsequent grad transforms must obey it).
# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
#   then we temporarily restore the previous `is_grad_enabled`. This is
#   because we're crossing the boundary from a `grad` outside the
#   no_grad to a `grad` inside the no_grad.
#
# NB: vjp has some interesting behavior because the vjp's callable can be called
# under a different grad_mode than the forward computation...
#
# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but
# it respects c10::AutoFwGradMode. We've implemented the same logic for
# our jvp transform (it will have special handling if FwGradMode is disabled).


# How do we increment and decrement the nesting? I don't think we can.
def vjp(func: Callable, *primals, has_aux: bool = False):
    """
    Standing for the vector-Jacobian product, returns a tuple containing the
    results of :attr:`func` applied to :attr:`primals` and a function that, when
    given ``cotangents``, computes the reverse-mode Jacobian of :attr:`func` with
    respect to :attr:`primals` times ``cotangents``.

    Args:
        func (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        primals (Tensors): Positional arguments to :attr:`func` that must all be
            Tensors. The returned function will also be computing the
            derivative with respect to these arguments
        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            other auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a ``(output, vjp_fn)`` tuple containing the output of :attr:`func`
        applied to :attr:`primals` and a function that computes the vjp of
        :attr:`func` with respect to all :attr:`primals` using the cotangents passed
        to the returned function. If ``has_aux is True``, then instead returns a
        ``(output, vjp_fn, aux)`` tuple.
        The returned ``vjp_fn`` function will return a tuple of each VJP.

    When used in simple cases, :func:`vjp` behaves the same as :func:`grad`

        >>> x = torch.randn([5])
        >>> f = lambda x: x.sin().sum()
        >>> (_, vjpfunc) = functorch.vjp(f, x)
        >>> grad = vjpfunc(torch.tensor(1.))[0]
        >>> assert torch.allclose(grad, functorch.grad(f)(x))

    However, :func:`vjp` can support functions with multiple outputs by
    passing in the cotangents for each of the outputs

        >>> x = torch.randn([5])
        >>> f = lambda x: (x.sin(), x.cos())
        >>> (_, vjpfunc) = functorch.vjp(f, x)
        >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    :func:`vjp` can even support outputs being Python structs

        >>> x = torch.randn([5])
        >>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
        >>> (_, vjpfunc) = functorch.vjp(f, x)
        >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
        >>> vjps = vjpfunc(cotangents)
        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

    The function returned by :func:`vjp` will compute the partials with
    respect to each of the :attr:`primals`

        >>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
        >>> (_, vjpfunc) = functorch.vjp(torch.matmul, x, y)
        >>> cotangents = torch.randn([5, 5])
        >>> vjps = vjpfunc(cotangents)
        >>> assert len(vjps) == 2
        >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
        >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

    :attr:`primals` are the positional arguments for :attr:`f`. All kwargs use their
    default value

        >>> x = torch.randn([5])
        >>> def f(x, scale=4.):
        >>>     return x * scale
        >>>
        >>> (_, vjpfunc) = functorch.vjp(f, x)
        >>> vjps = vjpfunc(torch.ones_like(x))
        >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``vjp``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``vjp(f, x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:

            >>> with torch.no_grad():
            >>>     vjp(f, x)

        In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``vjp`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    return _vjp_with_argnums(func, *primals, has_aux=has_aux)
@doesnt_support_saved_tensors_hooks
def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False):
    # This is the same function as vjp but also accepts an argnums argument
    # All args are the same as vjp except for the added argument
    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
    #         If None, computes the gradients with respect to all inputs (used for vjp). Default: None
    #
    # WARN: Users should NOT call this function directly and should just be calling vjp.
    # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers.
    #
    # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev
    #
    # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs
    # for only the primal elements given by argnums.
    level = _grad_increment_nesting()
    try:
        # See NOTE [grad and vjp interaction with no_grad]
        with torch.enable_grad():
            primals = _wrap_all_tensors(primals, level)
            if argnums is None:
                diff_primals = _create_differentiable(primals, level)
            else:
                diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
                tree_map_(partial(_create_differentiable, level=level), diff_primals)
            primals_out = func(*primals)

            if has_aux:
                if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
                    raise RuntimeError(
                        "vjp(f, *primals): output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                primals_out, aux = primals_out
                aux = _undo_create_differentiable(aux, level)

            flat_primals_out, primals_out_spec = tree_flatten(primals_out)
            assert_non_empty_tensor_output(flat_primals_out, 'vjp(f, *primals)')
            flat_diff_primals, primals_spec = tree_flatten(diff_primals)
            results = _undo_create_differentiable(primals_out, level)

            for primal_out in flat_primals_out:
                assert isinstance(primal_out, torch.Tensor)
                if primal_out.is_floating_point() or primal_out.is_complex():
                    continue
                raise RuntimeError("vjp(f, ...): All outputs of f must be "
                                   "floating-point or complex Tensors, got Tensor "
                                   f"with dtype {primal_out.dtype}")

        def wrapper(cotangents, retain_graph=True, create_graph=None):
            if create_graph is None:
                create_graph = torch.is_grad_enabled()
            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
            if primals_out_spec != cotangents_spec:
                raise RuntimeError(
                    f'Expected pytree structure of cotangents to be the same '
                    f'as pytree structure of outputs to the function. '
                    f'cotangents: {treespec_pprint(cotangents_spec)}, '
                    f'primal output: {treespec_pprint(primals_out_spec)}')
            result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                                    retain_graph=retain_graph, create_graph=create_graph)
            return tree_unflatten(result, primals_spec)

    finally:
        _grad_decrement_nesting()

    if has_aux:
        return results, wrapper, aux
    else:
        return results, wrapper


def _safe_zero_index(x):
    assert len(x) == 1
    return x[0]
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False):
    """
    Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
    :attr:`argnums` using reverse mode autodiff

    Args:
        func (function): A Python function that takes one or more arguments,
            one of which must be a Tensor, and returns one or more Tensors
        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
            saying which arguments to get the Jacobian with respect to.
            Default: 0.
        has_aux (bool): Flag indicating that :attr:`func` returns a
            ``(output, aux)`` tuple where the first element is the output of
            the function to be differentiated and the second element is
            auxiliary objects that will not be differentiated.
            Default: False.

    Returns:
        Returns a function that takes in the same inputs as :attr:`func` and
        returns the Jacobian of :attr:`func` with respect to the arg(s) at
        :attr:`argnums`. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by :attr:`func`.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

        >>> from functorch import jacrev
        >>> x = torch.randn(5)
        >>> jacobian = jacrev(torch.sin)(x)
        >>> expected = torch.diag(torch.cos(x))
        >>> assert torch.allclose(jacobian, expected)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

        >>> from functorch import jacrev
        >>> x = torch.randn(5)
        >>>
        >>> def f(x):
        >>>     return x.sin()
        >>>
        >>> def g(x):
        >>>     result = f(x)
        >>>     return result, result
        >>>
        >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
        >>> assert torch.allclose(f_x, f(x))

    :func:`jacrev` can be composed with vmap to produce batched
    Jacobians:

        >>> from functorch import jacrev, vmap
        >>> x = torch.randn(64, 5)
        >>> jacobian = vmap(jacrev(torch.sin))(x)
        >>> assert jacobian.shape == (64, 5, 5)

    Additionally, :func:`jacrev` can be composed with itself to produce
    Hessians

        >>> from functorch import jacrev
        >>> def f(x):
        >>>     return x.sin().sum()
        >>>
        >>> x = torch.randn(5)
        >>> hessian = jacrev(jacrev(f))(x)
        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacrev` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using :attr:`argnums`:

        >>> from functorch import jacrev
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=1)(x, y)
        >>> expected = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
    with respect to multiple arguments

        >>> from functorch import jacrev
        >>> def f(x, y):
        >>>     return x + y ** 2
        >>>
        >>> x, y = torch.randn(5), torch.randn(5)
        >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
        >>> expectedX = torch.diag(torch.ones_like(x))
        >>> expectedY = torch.diag(2 * y)
        >>> assert torch.allclose(jacobian[0], expectedX)
        >>> assert torch.allclose(jacobian[1], expectedY)

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``jacrev``.
        Case 1: Using ``torch.no_grad`` inside a function:

            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:

            >>> with torch.no_grad():
            >>>     jacrev(f)(x)

        In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``jacrev`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
    """
    @wraps(func)
    def wrapper_fn(*args):
        vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
        if has_aux:
            output, vjp_fn, aux = vjp_out
        else:
            output, vjp_fn = vjp_out

        # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
        flat_output, output_spec = tree_flatten(output)

        # NB: vjp already checks that all outputs are tensors
        # Step 1: Construct grad_outputs by splitting the standard basis
        flat_output_numels = tuple(out.numel() for out in flat_output)
        flat_basis = _construct_standard_basis_for(flat_output, flat_output_numels)
        basis = tree_unflatten(flat_basis, output_spec)

        results = vmap(vjp_fn)(basis)

        primals = _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_results, results_spec = tree_flatten(results)

        # Step 2: The returned jacobian is one big tensor per input. In this step,
        # we split each Tensor by output.
        flat_results = [result.split(flat_output_numels, dim=0) for result in flat_results]
        flat_input_flat_output = [
            tuple(split.view(out.shape + primal.shape)
                  for split, out in zip(splits, flat_output))
            for splits, primal in zip(flat_results, flat_primals)
        ]

        # Step 3: Right now, `jacobian` is a List[List[Tensor]].
        # The outer List corresponds to the number of primals,
        # the inner List corresponds to the number of outputs.
        # We need to:
        # a. Exchange the order of the outer List and inner List
        # b. tree_unflatten the inner Lists (which correspond to the primals)
        # c. handle the argnums=int case
        # d. tree_unflatten the outer List (which corresponds to the outputs)
        flat_output_flat_input = tuple(zip(*flat_input_flat_output))

        flat_output_input = tuple(tree_unflatten(flat_input, primals_spec)
                                  for flat_input in flat_output_flat_input)

        if isinstance(argnums, int):
            flat_output_input = tuple(_safe_zero_index(flat_input)
                                      for flat_input in flat_output_input)
        output_input = tree_unflatten(flat_output_input, output_spec)
        if has_aux:
            return output_input, aux
        return output_input
    return wrapper_fn
# NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
#
# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
# It turns out we can compute the jacobian of this function with a single
# call to autograd.grad by using vmap over the correct grad_outputs.
#
# Firstly, one way to compute the jacobian is to concatenate x**2 and x.sum()
# into a vector with 4 elements. E.g., use
# g(x) = torch.cat([x**2, x.sum().reshape(1)])
#
# To get the first row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
# To get the 2nd row of the jacobian, we call
# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
# and so on.
#
# Using vmap, we can vectorize all 4 of these computations into one by
# passing the standard basis for R^4 as the grad_output.
# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
#
# Now, how do we compute the jacobian *without concatenating the outputs*?
# We can just split the standard basis across the outputs. So to
# compute the jacobian of f(x), we'd use
# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
# The grad_outputs looks like the following:
# ( torch.tensor([[1, 0, 0],
#                 [0, 1, 0],
#                 [0, 0, 1],
#                 [0, 0, 0]]),
#   torch.tensor([[0],
#                 [0],
#                 [0],
#                 [1]]) )
#
# But we're not done yet!
# >>> vmap(partial(autograd.grad, f(x), x))(grad_outputs)
# returns a Tensor of shape [4, 3]. We have to remember to split the
# jacobian of shape [4, 3] into two:
# - one of shape [3, 3] for the first output
# - one of shape [3] for the second output
def _construct_standard_basis_for(tensors, tensor_numels):
    # This function:
    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],
    #           [0],             [1, 0],              [0],
    #           [0],             [0, 1],              [0],
    #           [0]])  ,         [0, 0]])  ,          [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
    # for context behind this function.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    total_numel = sum(tensor_numels)
    diag_start_indices = (0, *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind())
    chunks = tuple(tensor.new_zeros(total_numel, tensor_numel)
                   for tensor, tensor_numel in zip(tensors, tensor_numels))
    for chunk, diag_start_idx in zip(chunks, diag_start_indices):
        chunk.diagonal(diag_start_idx).fill_(1)
    chunks = tuple(chunk.view(total_numel, *tensor.shape)
                   for chunk, tensor in zip(chunks, tensors))
    return chunks
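

# A minimal, torch-free sketch of the chunking that `_construct_standard_basis_for`
# performs, assuming plain nested lists stand in for tensors. The name
# `_sketch_standard_basis_chunks` is hypothetical and used only for illustration;
# it is not part of functorch and nothing in this module calls it.
def _sketch_standard_basis_chunks(numels):
    total = sum(numels)
    chunks = []
    col_start = 0
    for n in numels:
        # Take columns [col_start, col_start + n) of the total x total identity.
        chunk = [[1 if row == col_start + col else 0 for col in range(n)]
                 for row in range(total)]
        chunks.append(chunk)
        col_start += n
    return tuple(chunks)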
def _validate_and_wrap_argnum(argnum, num_args):
    if not isinstance(argnum, int):
        raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
    if argnum >= 0 and argnum < num_args:
        return argnum
    if argnum < 0 and argnum >= -num_args:
        return argnum + num_args
    raise RuntimeError(f'Got argnum={argnum}, but only {num_args} positional inputs')


def _check_unique_non_empty(argnums):
    if isinstance(argnums, tuple):
        if len(argnums) == 0:
            raise RuntimeError("argnums must be non-empty")
        if len(set(argnums)) != len(argnums):
            raise RuntimeError(f"argnums elements must be unique, got {argnums}")


def _replace_args(old_args, new_args, argnums):
    if isinstance(argnums, int):
        if len(new_args) != 1:
            raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
        return tuple(new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)))
    if isinstance(argnums, tuple):
        if len(new_args) != len(argnums):
            raise RuntimeError(
                "new_args should have the same size as argnums. "
                f"Argnums size {len(argnums)}, new_args size {len(new_args)}")

        def get_right_elem(i):
            return new_args[argnums.index(i)] if i in argnums else old_args[i]

        return tuple(get_right_elem(i) for i in range(len(old_args)))
    raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')


def _validate_and_wrap_argnums(argnums, num_args):
    if isinstance(argnums, int):
        return _validate_and_wrap_argnum(argnums, num_args)
    if isinstance(argnums, tuple):
        return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums)
    raise AssertionError("Should never get here")


def _slice_argnums(args, argnums, as_tuple=True):
    if not isinstance(argnums, int) and not isinstance(argnums, tuple):
        raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
    argnums = _validate_and_wrap_argnums(argnums, len(args))
    _check_unique_non_empty(argnums)
    if isinstance(argnums, int):
        if as_tuple:
            return (args[argnums],)
        else:
            return args[argnums]
    return tuple(args[i] for i in argnums)
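

# A minimal sketch of how the argnums helpers above normalize and select
# positional arguments, assuming plain Python values stand in for tensors.
# The name `_sketch_select_args` is hypothetical and used only for
# illustration; it condenses the validate/wrap/uniqueness steps into one
# function and is not called anywhere in this module.
def _sketch_select_args(args, argnums):
    def wrap(i):
        # Accept indices in [-len(args), len(args)) and map negatives to positives.
        if not -len(args) <= i < len(args):
            raise RuntimeError(f'Got argnum={i}, but only {len(args)} positional inputs')
        return i % len(args)

    if isinstance(argnums, int):
        return (args[wrap(argnums)],)
    wrapped = tuple(wrap(i) for i in argnums)
    if len(wrapped) == 0 or len(set(wrapped)) != len(wrapped):
        raise RuntimeError("argnums must be non-empty and unique")
    return tuple(args[i] for i in wrapped)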
JVP_NESTING = 0
@contextlib.contextmanager
def noop():
    yield


@contextlib.contextmanager
def enable_fwd_grad(enabled=True):
    prev_state = get_fwd_grad_enabled()
    set_fwd_grad_enabled(enabled)
    try:
        yield
    finally:
        set_fwd_grad_enabled(prev_state)


def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
    if not isinstance(elts, tuple):
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}')
    for elt in elts:
        if isinstance(elt, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected {argname} to be a tuple of Tensors, got '
            f'a tuple with an element of type {type(elt)}')
    if len(elts) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to be a non-empty tuple of Tensors.')


def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
    if output == [None] or len(output) < 1:
        raise RuntimeError(
            f'{api}: Expected f to be a function that has non-empty output (got output = {output})'
        )
    for o in output:
        if not isinstance(o, torch.Tensor):
            raise RuntimeError(
                f'{api}: expected f(*primals) to return only tensors'
                f', got unsupported type {type(o)}'
            )


def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
    if isinstance(output, torch.Tensor):
        return
    if not isinstance(output, tuple):
        raise RuntimeError(
            f'{api}: Expected output of f to be a Tensor or Tensors, got '
            f'{type(output)}')
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected output of f to be a non-empty tuple of Tensors.')
    for out in output:
        if isinstance(out, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected output of f to be a Tensor or Tensors, got '
            f'{type(out)} as an output')


def assert_non_empty_list_of_tensors(output: List[torch.Tensor], api: str, argname: str) -> None:
    if len(output) == 0:
        raise RuntimeError(
            f'{api}: Expected {argname} to contain at least one Tensor.')
    for out in output:
        if isinstance(out, torch.Tensor):
            continue
        raise RuntimeError(
            f'{api}: Expected {argname} to only contain Tensors, got '
            f'{type(out)}')
jvp_str = 'jvp(f, primals, tangents)'
def safe_unpack_dual(dual, strict):
    if not isinstance(dual, torch.Tensor):
        raise RuntimeError(
            f'{jvp_str}: expected f(*args) to return only tensors'
            f', got unsupported type {type(dual)}'
        )
    primal, tangent = fwAD.unpack_dual(dual)
    if tangent is None:
        if strict:
            raise RuntimeError(
                'jvp(f, primals, tangents, strict=True): '
                'The output of f is independent of '
                'the inputs. This is not allowed with strict=True.')
        tangent = torch.zeros_like(primal)
    return primal, tangent
def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
"""
Standing for the Jacobian-vector product, returns a tuple containing
the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at
``primals``" times ``tangents``. This is also known as forward-mode autodiff.
Args:
func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
primals (Tensors): Positional arguments to :attr:`func` that must all be
Tensors. The returned function will also be computing the
derivative with respect to these arguments
tangents (Tensors): The "vector" for which Jacobian-vector-product is
computed. Must be the same structure and sizes as the inputs to
``func``.
strict (bool): If ``True``, raise an error when an output of :attr:`func`
is independent of the inputs. If ``False``, such outputs get a
zero tangent instead. Default: False.
has_aux (bool): Flag indicating that :attr:`func` returns a
``(output, aux)`` tuple where the first element is the output of
the function to be differentiated and the second element is
other auxiliary objects that will not be differentiated.
Default: False.
Returns:
Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
evaluated at ``primals`` and the Jacobian-vector product.
If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple.
.. note::
You may see this API error out with "forward-mode AD not implemented
for operator X". If so, please file a bug report and we will prioritize it.
jvp is useful when you wish to compute gradients of a function R^1 -> R^N
>>> from functorch import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
:func:`jvp` can support functions with multiple inputs by passing in the
tangents for each of the inputs
>>> from functorch import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)
"""
    return _jvp_with_argnums(func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux)
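

# Illustrative aside (not part of functorch): the forward-mode rule that jvp
# relies on can be sketched in pure Python with dual numbers. ``Dual``,
# ``_lift`` and ``toy_jvp`` are hypothetical names invented for this sketch,
# and only scalar ``+`` and ``*`` are supported. The real implementation
# works on tensors through ``fwAD.make_dual`` and ``fwAD.unpack_dual``.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """A (primal, tangent) pair; a hypothetical stand-in for a dual tensor."""
    primal: float
    tangent: float

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: d(u * v) = u * dv + du * v
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def _lift(x):
    """Treat plain numbers as constants, i.e. duals with a zero tangent."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def toy_jvp(f, primals, tangents):
    """Return (f(*primals), J @ tangents) for a scalar-valued f of scalars."""
    duals = [Dual(p, t) for p, t in zip(primals, tangents)]
    out = _lift(f(*duals))
    return out.primal, out.tangent


# Scalar analogue of the two-input docstring example: f(x, y) = x * y.
value, jvp_out = toy_jvp(lambda x, y: x * y, (3.0, 4.0), (1.0, 1.0))
```

# With unit tangents, ``jvp_out`` equals x + y, which is the same identity the
# two-input example in the jvp docstring asserts for tensors.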
@doesnt_support_saved_tensors_hooks
def _jvp_with_argnums(func: Callable, primals: Any, tangents: Any, argnums: Optional[argnums_t], *,
                      strict: bool = False, has_aux: bool):
    # This is the same function as jvp but also accepts an argnums argument
    # Most args are the same as jvp except for the added argument
    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
    #         If None, computes the gradients with respect to all inputs (used for jvp). Default: None
    # Because of this, tangents must be of length argnums and matches up to the corresponding primal whose index is
    # given by argnums
    #
    # WARN: Users should NOT call this function directly and should just be calling jvp.
    # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers.
    #
    # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd
    #
    # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to
    # the primals given by argnums
    if not isinstance(primals, tuple):
        raise RuntimeError(
            f'{jvp_str}: Expected primals to be a tuple. '
            f'E.g. it should be valid to call f(*primals).')
    diff_args = primals if argnums is None else _slice_argnums(primals, argnums)
    flat_primals, primals_spec = tree_flatten(diff_args)
    flat_tangents, tangents_spec = tree_flatten(tangents)
    if primals_spec != tangents_spec:
        raise RuntimeError(
            f'{jvp_str}: Expected primals and tangents to have the same python '
            f'structure. For example, if primals is a tuple of 3 tensors, '
            f'tangents also must be. Got primals with structure {primals_spec} '
            f'and tangents with structure {tangents_spec}')
    assert_non_empty_list_of_tensors(flat_primals, jvp_str, 'primals')
    assert_non_empty_list_of_tensors(flat_tangents, jvp_str, 'tangents')

    level = _jvp_increment_nesting()
    try:
        global JVP_NESTING
        JVP_NESTING += 1
        with enable_fwd_grad():
            ctx = fwAD.dual_level if JVP_NESTING == 1 else noop
            with ctx():
                flat_duals = tuple(fwAD.make_dual(p, t)
                                   for p, t in zip(flat_primals, flat_tangents))
                duals = tree_unflatten(flat_duals, primals_spec)
                if argnums is not None:
                    primals = _wrap_all_tensors(primals, level)
                    duals = _replace_args(primals, duals, argnums)
                result_duals = func(*duals)
                if has_aux:
                    if not (isinstance(result_duals, tuple) and len(result_duals) == 2):
                        raise RuntimeError(
                            f"{jvp_str}: output of function f should be a tuple: (output, aux) "
                            "if has_aux is True"
                        )
                    result_duals, aux = result_duals
                    aux = _undo_create_differentiable(aux, level)

                result_duals, spec = tree_flatten(result_duals)
                assert_non_empty_tensor_output(result_duals, jvp_str)

                primals_out, tangents_out = \
                    zip(*[safe_unpack_dual(dual, strict) for dual in result_duals])
                primals_out = tree_map(
                    partial(_undo_create_differentiable, level=level), primals_out)
                tangents_out = tree_map(
                    partial(_undo_create_differentiable, level=level), tangents_out)

                primals_out_unflatten = tree_unflatten(primals_out, spec)
                tangents_out_unflatten = tree_unflatten(tangents_out, spec)
                if has_aux:
                    return primals_out_unflatten, tangents_out_unflatten, aux

                return primals_out_unflatten, tangents_out_unflatten
    finally:
        _jvp_decrement_nesting()
        JVP_NESTING -= 1
def safe_unflatten(tensor, dim, shape):
    if len(shape) == 0:
        assert tensor.shape[dim] == 1
        return tensor.squeeze(dim)
    return tensor.unflatten(dim, shape)
def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
"""
Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
:attr:`argnums` using forward-mode autodiff.
Args:
func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
argnums (int or Tuple[int]): Optional, integer or tuple of integers,
saying which arguments to get the Jacobian with respect to.
Default: 0.
has_aux (bool): Flag indicating that :attr:`func` returns a
``(output, aux)`` tuple where the first element is the output of
the function to be differentiated and the second element is
auxiliary objects that will not be differentiated.
Default: False.
randomness(str): Flag indicating what type of randomness to use.
See :func:`vmap` for more detail. Allowed: "different", "same", "error".
Default: "error"
Returns:
Returns a function that takes in the same inputs as :attr:`func` and
returns the Jacobian of :attr:`func` with respect to the arg(s) at
:attr:`argnums`. If ``has_aux is True``, then the returned function
instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
is the Jacobian and ``aux`` is auxiliary objects returned by :attr:`func`.
.. note::
You may see this API error out with "forward-mode AD not implemented
for operator X". If so, please file a bug report and we will prioritize it.
An alternative is to use :func:`jacrev`, which has better operator coverage.
A basic usage with a pointwise, unary operation will give a diagonal array
as the Jacobian
>>> from functorch import jacfwd
>>> x = torch.randn(5)
>>> jacobian = jacfwd(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
:func:`jacfwd` can be composed with vmap to produce batched
Jacobians:
>>> from functorch import jacfwd, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacfwd(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
If you would like to compute the output of the function as well as the
jacobian of the function, use the ``has_aux`` flag to return the output
as an auxiliary object:
>>> from functorch import jacfwd
>>> x = torch.randn(5)
>>>
>>> def f(x):
>>>     return x.sin()
>>>
>>> def g(x):
>>>     result = f(x)
>>>     return result, result
>>>
>>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
>>> assert torch.allclose(f_x, f(x))
Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
to produce Hessians:
>>> from functorch import jacfwd, jacrev
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hessian = jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
By default, :func:`jacfwd` computes the Jacobian with respect to the first
input. However, it can compute the Jacobian with respect to a different
argument by using :attr:`argnums`:
>>> from functorch import jacfwd
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to :attr:`argnums` will compute the Jacobian
with respect to multiple arguments
>>> from functorch import jacfwd
>>> def f(x, y):
>>>     return x + y ** 2
>>>
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
"""
    @wraps(func)
    def wrapper_fn(*args):
        primals = args if argnums is None else _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)

        def push_jvp(basis):
            output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out

        results = vmap(push_jvp, randomness=randomness)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)

        jac_outs, spec = tree_flatten(results)
        # Most probably below output check can never raise an error
        # as jvp should test the output before
        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')

        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in
                zip(flat_primals, jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1))
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins)

        if isinstance(argnums, int):
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)
    return wrapper_fn
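

# Illustrative aside (not part of functorch): jacfwd pushes a standard basis
# through the JVP and reassembles the results into a Jacobian. The sketch
# below shows the same idea in pure Python under simplifying assumptions:
# a plain loop stands in for ``vmap``, scalars stand in for tensors, and
# ``Dual`` and ``toy_jacfwd`` are hypothetical names made up for the example.

```python
class Dual:
    """Minimal dual number (primal, tangent); supports + and * only."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __radd__, __rmul__ = __add__, __mul__


def toy_jacfwd(f, x):
    """Jacobian of f: R^n -> R^m at x, built one column per basis vector."""
    n = len(x)
    columns = []
    for i in range(n):
        # Use the i-th standard basis vector as the tangent.
        duals = [Dual(x[j], 1.0 if j == i else 0.0) for j in range(n)]
        outputs = f(duals)
        columns.append([out.tangent for out in outputs])
    # columns[i][k] holds d f_k / d x_i; transpose to (outputs, inputs) layout.
    return [list(row) for row in zip(*columns)]


def f(v):
    x, y = v
    return [x + y * y, x * y]


jacobian = toy_jacfwd(f, [2.0, 3.0])
```

# Each JVP yields one column of the Jacobian, so the cost grows with the
# number of inputs. The final transpose plays the role that
# ``movedim(0, -1)`` plays in ``wrapper_fn``.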
def hessian(func, argnums=0):
"""
Computes the Hessian of :attr:`func` with respect to the arg(s) at index
:attr:`argnums` via a forward-over-reverse strategy.
The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
a good default for performance. It is possible to compute Hessians
through other compositions of :func:`jacfwd` and :func:`jacrev` like
``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.
Args:
func (function): A Python function that takes one or more arguments,
one of which must be a Tensor, and returns one or more Tensors
argnums (int or Tuple[int]): Optional, integer or tuple of integers,
saying which arguments to get the Hessian with respect to.
Default: 0.
Returns:
Returns a function that takes in the same inputs as :attr:`func` and
returns the Hessian of :attr:`func` with respect to the arg(s) at
:attr:`argnums`.
.. note::
You may see this API error out with "forward-mode AD not implemented
for operator X". If so, please file a bug report and we will prioritize it.
An alternative is to use ``jacrev(jacrev(func))``, which has better
operator coverage.
A basic usage with a R^N -> R^1 function gives a N x N Hessian:
>>> from functorch import hessian
>>> def f(x):
>>>     return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))
"""
    return jacfwd(jacrev(func, argnums), argnums)
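

# Illustrative aside (not part of functorch): the hessian docstring mentions
# that other compositions such as ``jacfwd(jacfwd(func))`` also work. The
# pure-Python sketch below shows that forward-over-forward idea with nested
# dual numbers. It is NOT the forward-over-reverse strategy that hessian
# itself uses. ``Dual`` and ``toy_hessian`` are hypothetical names, and the
# sketch has no perturbation tagging, so it is only safe for this simple
# two-level nesting.

```python
class Dual:
    """Dual number whose parts may themselves be Duals, so it can be nested."""

    def __init__(self, primal, tangent=0.0):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __radd__, __rmul__ = __add__, __mul__


def toy_hessian(f, x):
    """Hessian of a scalar-valued f at x via two nested forward passes."""
    n = len(x)
    hess = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # Outer level perturbs input i, inner level perturbs input j.
            args = [
                Dual(Dual(x[k], 1.0 if k == j else 0.0),
                     Dual(1.0 if k == i else 0.0, 0.0))
                for k in range(n)
            ]
            out = f(*args)
            # The tangent of the tangent is the mixed second derivative.
            hess[i][j] = out.tangent.tangent
    return hess


hess = toy_hessian(lambda x, y: x * x * y, [2.0, 3.0])
```

# For f(x, y) = x^2 * y the exact Hessian is [[2y, 2x], [2x, 0]], which is
# what the nested tangents recover at (2, 3).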
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""
Returns a function to compute a tuple of the gradient and primal, or
forward, computation.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If specified :attr:`has_aux`
equals ``True``, function can return a tuple of single-element
Tensor and other auxiliary objects: ``(output, aux)``.
argnums (int or Tuple[int]): Specifies arguments to compute gradients
with respect to. :attr:`argnums` can be single integer or tuple of
integers. Default: 0.
has_aux (bool): Flag indicating that :attr:`func` returns a tensor and
other auxiliary objects: ``(output, aux)``. Default: False.
Returns:
Function to compute a tuple of gradients with respect to its inputs
and the forward computation. By default, the output of the function is
a tuple of the gradient tensor(s) with respect to the first argument
and the primal computation. If specified :attr:`has_aux` equals
``True``, tuple of gradients and tuple of the forward computation with
output auxiliary objects is returned. If :attr:`argnums` is a tuple of
integers, a tuple of a tuple of the output gradients with respect to
each :attr:`argnums` value and the forward computation is returned.
See :func:`grad` for examples
"""
    @doesnt_support_saved_tensors_hooks
    @wraps(func)
    def wrapper(*args, **kwargs):
        level = _grad_increment_nesting()
        try:
            output, aux, grad_input = None, None, None
            # See NOTE [grad and vjp interaction with no_grad]
            with torch.enable_grad():
                args = _wrap_all_tensors(args, level)
                kwargs = _wrap_all_tensors(kwargs, level)
                diff_args = _slice_argnums(args, argnums, as_tuple=False)
                tree_map_(partial(_create_differentiable, level=level), diff_args)

                output = func(*args, **kwargs)
                if has_aux:
                    if not (isinstance(output, tuple) and len(output) == 2):
                        raise RuntimeError(
                            "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
                            "if has_aux is True"
                        )
                    output, aux = output

                if not isinstance(output, torch.Tensor):
                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                       f'to return a Tensor, got {type(output)}')
                if output.dim() != 0:
                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                       'to return a scalar Tensor, got tensor with '
                                       f'{output.dim()} dims. Maybe you wanted to '
                                       'use the vjp or jacrev APIs instead?')

                flat_diff_args, spec = tree_flatten(diff_args)

                # NB: need create_graph so that backward pass isn't run in no_grad mode
                flat_outputs = _as_tuple(output)
                flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
                grad_input = tree_unflatten(flat_grad_input, spec)

                grad_input = _undo_create_differentiable(grad_input, level)
                output = _undo_create_differentiable(output, level)
                if aux is not None:
                    aux = _undo_create_differentiable(aux, level)

            if has_aux:
                return grad_input, (output, aux)
            return grad_input, output
        finally:
            _grad_decrement_nesting()
    return wrapper
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""The ``grad`` operator computes gradients of :attr:`func` with respect to the
input(s) specified by :attr:`argnums`. This operator can be nested to
compute higher-order gradients.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If specified :attr:`has_aux` equals ``True``,
function can return a tuple of single-element Tensor and other auxiliary objects:
``(output, aux)``.
argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
:attr:`argnums` can be single integer or tuple of integers. Default: 0.
has_aux (bool): Flag indicating that :attr:`func` returns a tensor and other
auxiliary objects: ``(output, aux)``. Default: False.
Returns:
Function to compute gradients with respect to its inputs. By default, the output of
the function is the gradient tensor(s) with respect to the first argument.
If specified :attr:`has_aux` equals ``True``, tuple of gradients and output auxiliary objects
is returned. If :attr:`argnums` is a tuple of integers, a tuple of output gradients with
respect to each :attr:`argnums` value is returned.
Example of using ``grad``:
>>> from functorch import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
>>> from functorch import grad
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
Example of using ``grad`` with :attr:`has_aux` and :attr:`argnums`:
>>> from functorch import grad
>>> def my_loss_func(y, y_pred):
>>>     loss_per_sample = (0.5 * y_pred - y) ** 2
>>>     loss = loss_per_sample.mean()
>>>     return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
.. note::
Using PyTorch ``torch.no_grad`` together with ``grad``.
Case 1: Using ``torch.no_grad`` inside a function:
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.
Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
>>> with torch.no_grad():
>>>     grad(f)(x)
In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
outer one. This is because ``grad`` is a "function transform": its result
should not depend on the result of a context manager outside of ``f``.
"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
        if has_aux:
            grad, (_, aux) = results
            return grad, aux
        grad, _ = results
        return grad
    return wrapper
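

# Illustrative aside (not part of functorch): grad is built on reverse-mode
# autodiff, and ``argnums`` decides whether one gradient or a tuple of
# gradients comes back. The pure-Python sketch below mimics that contract for
# scalar functions. ``Var``, ``_backward`` and ``toy_grad`` are hypothetical
# names, and only ``+`` and ``*`` between ``Var`` objects are supported.

```python
class Var:
    """Scalar node in a tiny reverse-mode tape (for illustration only)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (parent Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def _backward(output):
    """Accumulate d(output)/d(node) into node.grad for every reachable node."""
    order, seen = [], set()

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):  # reverse topological order
        for parent, local in node.parents:
            parent.grad += node.grad * local


def toy_grad(f, argnums=0):
    """An int argnums gives one gradient; a tuple gives a tuple of gradients."""
    def wrapper(*args):
        variables = [Var(float(a)) for a in args]
        _backward(f(*variables))
        if isinstance(argnums, int):
            return variables[argnums].grad
        return tuple(variables[i].grad for i in argnums)
    return wrapper


def f(x, y):
    return x * y + x


dx = toy_grad(f)(2.0, 3.0)
dx_dy = toy_grad(f, argnums=(0, 1))(2.0, 3.0)
```

# For f(x, y) = x * y + x the partial derivatives are y + 1 and x, so the
# default call returns only the first and the tuple form returns both.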
def _maybe_wrap_functional_tensor(maybe_tensor, level):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    wrapped = _wrap_functional_tensor(maybe_tensor, level)
    _assert_wrapped_functional(maybe_tensor, wrapped)
    return wrapped


def _wrap_all_tensors_to_functional(tensor_pytree, level):
    return tree_map(partial(_maybe_wrap_functional_tensor, level=level), tensor_pytree)


def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    if not torch._is_functional_tensor(maybe_tensor):
        # If it's not a functional tensor, just return it.
        # This can happen if we functionalize a fn that returns a global,
        # which was never wrapped properly.
        return maybe_tensor
    return _unwrap_functional_tensor(maybe_tensor, reapply_views)


def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):
    return tree_map(lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), tensor_pytree)
def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
"""
functionalize is a transform that can be used to remove (intermediate)
mutations and aliasing from a function, while preserving the function's
semantics.
``functionalize(func)`` returns a new function with the same semantics
as ``func``, but with all intermediate mutations removed.
Every inplace operation performed on an intermediate tensor:
``intermediate.foo_()``
gets replaced by its out-of-place equivalent:
``intermediate_updated = intermediate.foo()``.
functionalize is useful for shipping a pytorch program off to
backends or compilers that aren't able to easily represent
mutations or aliasing operators.
Args:
func (Callable): A Python function that takes one or more arguments.
remove (str): An optional string argument, that takes on either
the value 'mutations' or 'mutations_and_views'.
If 'mutations' is passed in then all mutating operators
will be replaced with their non-mutating equivalents.
If 'mutations_and_views' is passed in, then additionally, all aliasing
operators will be replaced with their non-aliasing equivalents.
Default: 'mutations'.
Returns:
Returns a new "functionalized" function. It takes the same inputs as
:attr:`func`, and has the same behavior, but any mutations
(and optionally aliasing) performed on intermediate tensors
in the function will be removed.
functionalize will also remove mutations (and views) that were performed on function inputs.
However to preserve semantics, functionalize will "fix up" the mutations after
the transform has finished running, by detecting if any tensor inputs "should have"
been mutated, and copying the new data back to the inputs if necessary.
Example::
>>> import torch
>>> from functorch import make_fx
>>> from functorch.experimental import functionalize
>>>
>>> # A function that uses mutations and views, but only on intermediate tensors.
>>> def f(a):
...     b = a + 1
...     c = b.view(-1)
...     c.add_(1)
...     return b
...
>>> inpt = torch.randn(2)
>>>
>>> out1 = f(inpt)
>>> out2 = functionalize(f)(inpt)
>>>
>>> # semantics are the same (outputs are equivalent)
>>> print(torch.allclose(out1, out2))
True
>>>
>>> f_traced = make_fx(f)(inpt)
>>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>>
>>> print(f_traced.code)
def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1); a_1 = None
    view = torch.ops.aten.view(add, [-1])
    add_ = torch.ops.aten.add_(view, 1); view = None
    return add
>>> print(f_no_mutations_traced.code)
def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1); a_1 = None
    view = torch.ops.aten.view(add, [-1]); add = None
    add_1 = torch.ops.aten.add(view, 1); view = None
    view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None
    return view_1
>>> print(f_no_mutations_and_views_traced.code)
def forward(self, a_1):
    add = torch.ops.aten.add(a_1, 1); a_1 = None
    view_copy = torch.ops.aten.view_copy(add, [-1]); add = None
    add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None
    return view_copy_1
>>> # A function that mutates its input tensor
>>> def f(a):
...     b = a.view(-1)
...     b.add_(1)
...     return a
...
>>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
>>>
>>> # All mutations and views have been removed,
>>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
>>> # after the function has completed.
>>> print(f_no_mutations_and_views_traced.code)
def forward(self, a_1):
    view_copy = torch.ops.aten.view_copy(a_1, [-1])
    add = torch.ops.aten.add(view_copy, 1); view_copy = None
    view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None
    copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None
    return view_copy_1
There are a few "failure modes" for functionalize that are worth calling out:
(1) Like other functorch transforms, `functionalize()` doesn't work with functions
that directly use `.backward()`. The same is true for torch.autograd.grad.
If you want to use autograd, you can compute gradients directly
with `functionalize(grad(f))`.
(2) Like other functorch transforms, `functionalize()` doesn't work with global state.
If you call `functionalize(f)` on a function that takes views / mutations of
non-local state, functionalization will simply no-op and pass the view/mutation
calls directly to the backend.
One way to work around this is to ensure that any non-local state creation
is wrapped into a larger function, which you then call functionalize on.
(3) `resize_()` has some limitations: functionalize will only work on programs
that use `resize_()` as long as the tensor being resized is not a view.
(4) `as_strided()` has some limitations: functionalize will not work on
`as_strided()` calls that result in tensors with overlapping memory.
Finally, a helpful mental model for understanding functionalization is that
most user pytorch programs are written with the public torch API.
When executed, torch operators are generally decomposed into
our internal C++ "ATen" API.
The logic for functionalization happens entirely at the level of ATen.
Functionalization knows how to take every aliasing operator in ATen,
and map it to its non-aliasing equivalent
(e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
and how to take every mutating operator in ATen,
and map it to its non-mutating equivalent
(e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
while tracking aliases and mutations out-of-line to know when to fix things up.
Information about which ATen operators are aliasing or mutating all comes from
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
"""
    if remove == 'mutations':
        reapply_views = True
    elif remove == 'mutations_and_views':
        reapply_views = False
    else:
        raise RuntimeError(
            f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
            " Valid options are:\n"
            "     remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
            " with their out-of-place equivalents.\n"
            "     remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
            " replaced with their non-aliasing counterparts, {view}_copy.\n"
        )

    @doesnt_support_saved_tensors_hooks
    @wraps(func)
    def wrapped(*args, **kwargs):
        try:
            func_level = _func_increment_nesting(reapply_views)
            func_args = _wrap_all_tensors_to_functional(args, func_level)
            func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)

            flattened_unwrapped_args, _ = tree_flatten(args)
            flattened_wrapped_args, _ = tree_flatten(func_args)
            flattened_unwrapped_kwargs, _ = tree_flatten(kwargs)
            flattened_wrapped_kwargs, _ = tree_flatten(func_kwargs)

            func_outputs = func(*func_args, **func_kwargs)
            outputs = _unwrap_all_tensors_from_functional(func_outputs, reapply_views=reapply_views)
            flat_outputs, func_out_spec = tree_flatten(outputs)

            for a in flattened_wrapped_args + flattened_wrapped_kwargs:
                if isinstance(a, torch.Tensor):
                    # Call sync_() on the inputs, to ensure that any pending mutations have been applied.
                    torch._sync(a)

            # And if any mutations were applied to the inputs, we need to propagate them back to the user.
            for unwrapped, wrapped in zip(flattened_unwrapped_args, flattened_wrapped_args):
                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
                    _propagate_functional_input_mutation(unwrapped, wrapped)
            for unwrapped, wrapped in zip(flattened_unwrapped_kwargs, flattened_wrapped_kwargs):
                if isinstance(unwrapped, torch.Tensor) and isinstance(wrapped, torch.Tensor):
                    _propagate_functional_input_mutation(unwrapped, wrapped)

            return outputs
        finally:
            _func_decrement_nesting()
    return wrapped
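

# Illustrative aside (not part of functorch): the docstring above says that
# functionalize removes input mutations and then "fixes up" the inputs by
# copying the new data back. The toy below shows only that copy-back step,
# with Python lists standing in for tensors. The out-of-place function is
# written by hand here, whereas the real transform derives it at the ATen
# level. All names in the block are hypothetical.

```python
def mutating_f(a):
    """Mutates its input in place, like ``b = a.view(-1); b.add_(1)``."""
    for i in range(len(a)):
        a[i] += 1
    return a


def functional_f(a):
    """Out-of-place counterpart: builds a new list and leaves ``a`` untouched."""
    return [v + 1 for v in a]


def toy_functionalize(pure_fn):
    """Run the out-of-place version, then copy the result back into the input."""
    def wrapped(a):
        updated = pure_fn(a)
        a[:] = updated  # plays the role of the trailing copy_ in the traced graph
        return updated
    return wrapped


inp1, inp2 = [1.0, 2.0], [1.0, 2.0]
out1 = mutating_f(inp1)
out2 = toy_functionalize(functional_f)(inp2)
```

# Both calls return the same values and leave their inputs in the same state,
# which is the sense in which the transform preserves the function's
# semantics for callers that observe input mutation.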
# use an alternate way to register an operator into the decomposition table
# _register_jit_decomposition doesn't work for some operators, e.g. addr,
# because the Tensor types generated cannot be unioned by torchscript
# decomp should be type OpOverload
vmap_decompositions_lib = torch.library.Library("aten", "IMPL", "FuncTorchBatched")
def _register_python_decomposition_vmap(decomp):
    if decomp in decomposition_table:
        vmap_decompositions_lib.impl(decomp, decomposition_table[decomp])
    else:
        raise RuntimeError(f"could not find decomposition for {decomp}")
_register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
_register_python_decomposition_vmap(torch.ops.aten.addr.default)