torch.func Whirlwind Tour
=========================

What is torch.func?
-------------------

.. currentmodule:: torch.func

torch.func, previously known as functorch, is a library for
`JAX <https://github.com/google/jax>`_-like composable function transforms in
PyTorch.

- A "function transform" is a higher-order function that accepts a numerical
  function and returns a new function that computes a different quantity.
- torch.func has auto-differentiation transforms (``grad(f)`` returns a function
  that computes the gradient of ``f``), a vectorization/batching transform
  (``vmap(f)`` returns a function that computes ``f`` over batches of inputs),
  and others.
- These function transforms can compose with each other arbitrarily. For
  example, composing ``vmap(grad(f))`` computes a quantity called
  per-sample-gradients that stock PyTorch cannot efficiently compute today;
  a short sketch of such a composition appears below.
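
For instance, because ``grad(torch.sin)`` evaluates the derivative of ``sin`` at
a single scalar input, wrapping it in ``vmap`` applies that derivative
elementwise over a batch. A minimal sketch (shapes chosen for illustration):

.. code-block:: python

    import torch
    from torch.func import grad, vmap

    x = torch.randn(3)                    # a batch of scalar inputs
    dsin_dx = vmap(grad(torch.sin))(x)    # evaluates cos at each element
    assert torch.allclose(dsin_dx, torch.cos(x))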

Why composable function transforms?
-----------------------------------

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, :func:`vjp`, and :func:`jvp` transforms
allows us to express the above without designing a separate subsystem for each.

What are the transforms?
------------------------

:func:`grad` (gradient computation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``grad(func)`` is our gradient computation transform. It returns a new function
that computes the gradients of ``func``. It assumes ``func`` returns a
single-element Tensor and, by default, computes the gradients of the output of
``func`` with respect to the first input.

.. code-block:: python

    import torch
    from torch.func import grad

    x = torch.randn([])
    cos_x = grad(lambda x: torch.sin(x))(x)
    assert torch.allclose(cos_x, x.cos())

    # Second-order gradients
    neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
    assert torch.allclose(neg_sin_x, -x.sin())
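
:func:`grad` differentiates with respect to the first argument by default; the
``argnums`` argument selects a different one. A small sketch with an
illustrative two-argument function:

.. code-block:: python

    # g is an illustrative function, not part of torch.func
    def g(x, y):
        return (x * y).sum()

    x, y = torch.randn(3), torch.randn(3)
    grad_y = grad(g, argnums=1)(x, y)
    assert torch.allclose(grad_y, x)  # d/dy of sum(x * y) is x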

:func:`vmap` (auto-vectorization)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Note: :func:`vmap` imposes restrictions on the code that it can be used on. For
more details, please see :ref:`ux-limitations`.

``vmap(func)(*inputs)`` is a transform that adds a dimension to all Tensor
operations in ``func``. ``vmap(func)`` returns a new function that maps ``func``
over some dimension (default: 0) of each Tensor in ``inputs``.

vmap is useful for hiding batch dimensions: one can write a function ``func``
that runs on individual examples and then lift it into a function that takes
batches of examples with ``vmap(func)``, leading to a simpler modeling
experience:

.. code-block:: python

    import torch
    from torch.func import vmap

    batch_size, feature_size = 3, 5
    weights = torch.randn(feature_size, requires_grad=True)

    def model(feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    examples = torch.randn(batch_size, feature_size)
    result = vmap(model)(examples)
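
Because ``model`` returns a single scalar per example, ``result`` has shape
``(batch_size,)``.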
When composed with :func:`grad`, :func:`vmap` can be used to compute per-sample-gradients:

.. code-block:: python

    from torch.func import vmap, grad

    batch_size, feature_size = 3, 5

    def model(weights, feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    def compute_loss(weights, example, target):
        y = model(weights, example)
        return ((y - target) ** 2).mean()  # MSELoss

    weights = torch.randn(feature_size, requires_grad=True)
    examples = torch.randn(batch_size, feature_size)
    targets = torch.randn(batch_size)
    inputs = (weights, examples, targets)
    grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
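
``grad(compute_loss)`` produces a gradient with the same shape as ``weights``,
and ``vmap`` stacks one such gradient per example, so ``grad_weight_per_example``
has shape ``(batch_size, feature_size)``.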

:func:`vjp` (vector-Jacobian product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`vjp` transform applies ``func`` to ``inputs`` and returns a new
function that computes the vector-Jacobian product (vjp) given some
``cotangents`` Tensors.

.. code-block:: python

    from torch.func import vjp

    inputs = torch.randn(3)
    func = torch.sin
    cotangents = (torch.randn(3),)

    outputs, vjp_fn = vjp(func, inputs)
    vjps = vjp_fn(*cotangents)
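
Since the Jacobian of ``torch.sin`` is diagonal with entries ``cos(inputs)``, the
vjp here reduces to an elementwise product; a quick sanity check of the snippet
above:

.. code-block:: python

    assert torch.allclose(vjps[0], cotangents[0] * inputs.cos())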

:func:`jvp` (Jacobian-vector product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jvp` transform computes Jacobian-vector products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order
function: instead of returning a new function, it returns the outputs of
``func(inputs)`` together with the jvps.

.. code-block:: python

    from torch.func import jvp

    x = torch.randn(5)
    y = torch.randn(5)
    f = lambda x, y: (x * y)
    _, out_tangent = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
    assert torch.allclose(out_tangent, x + y)
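
For ``f(x, y) = x * y``, the jvp in the tangent directions ``(1, 1)`` is
``y * 1 + x * 1``, which is why ``out_tangent`` equals ``x + y``.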

:func:`jacrev`, :func:`jacfwd`, and :func:`hessian`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jacrev` transform, ``jacrev(func)``, returns a new function that
takes in ``x`` and returns the Jacobian of ``func`` with respect to ``x``,
computed using reverse-mode AD.

.. code-block:: python

    from torch.func import jacrev

    x = torch.randn(5)
    jacobian = jacrev(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

:func:`jacrev` can be composed with :func:`vmap` to produce batched Jacobians:

.. code-block:: python

    x = torch.randn(64, 5)
    jacobian = vmap(jacrev(torch.sin))(x)
    assert jacobian.shape == (64, 5, 5)

:func:`jacfwd` is a drop-in replacement for :func:`jacrev` that computes
Jacobians using forward-mode AD:

.. code-block:: python

    from torch.func import jacfwd

    x = torch.randn(5)
    jacobian = jacfwd(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)
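
As a rule of thumb, forward-mode AD builds the Jacobian one column (input
direction) at a time, while reverse-mode AD builds it one row (output) at a
time, so :func:`jacfwd` tends to be preferable for "tall" Jacobians (many more
outputs than inputs) and :func:`jacrev` for "wide" ones.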

Composing :func:`jacrev` with itself or with :func:`jacfwd` can produce Hessians:

.. code-block:: python

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hessian0 = jacrev(jacrev(f))(x)
    hessian1 = jacfwd(jacrev(f))(x)
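
Both compositions compute the same Hessian; for this ``f`` it is the diagonal
matrix ``torch.diag(-x.sin())``.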

:func:`hessian` is a convenience function that combines jacfwd and jacrev:

.. code-block:: python

    from torch.func import hessian

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hess = hessian(f)(x)
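
Like the other transforms, :func:`hessian` composes with :func:`vmap`. A minimal
sketch of batched Hessians (the batch size here is chosen for illustration):

.. code-block:: python

    from torch.func import vmap, hessian

    def f(x):
        return x.sin().sum()

    x = torch.randn(64, 5)
    batched_hess = vmap(hessian(f))(x)
    assert batched_hess.shape == (64, 5, 5)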