File: func.whirlwind_tour.rst

torch.func Whirlwind Tour
=========================

What is torch.func?
-------------------

.. currentmodule:: torch.func

torch.func, previously known as functorch, is a library for
`JAX <https://github.com/google/jax>`_-like composable function transforms in
PyTorch.

- A "function transform" is a higher-order function that accepts a numerical
  function and returns a new function that computes a different quantity.
- torch.func has auto-differentiation transforms (``grad(f)`` returns a function
  that computes the gradient of ``f``), a vectorization/batching transform
  (``vmap(f)`` returns a function that computes ``f`` over batches of inputs),
  and others.
- These function transforms can compose with each other arbitrarily. For
  example, composing ``vmap(grad(f))`` computes a quantity called
  per-sample-gradients that stock PyTorch cannot efficiently compute today.

Why composable function transforms?
-----------------------------------
There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine (see the sketch below)
- efficiently batching together tasks in the inner loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, :func:`vjp`, and :func:`jvp` transforms
allows us to express the above without designing a separate subsystem for each.
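
As an illustration of the ensembling use case above, here is a minimal sketch
that evaluates several copies of a small model with a single :func:`vmap` call.
The ``SimpleMLP`` module, the number of models, and the shapes are illustrative
assumptions; the pattern combines :func:`stack_module_state`,
:func:`functional_call`, and :func:`vmap`.

.. code-block:: python

    import copy

    import torch
    import torch.nn as nn
    from torch.func import functional_call, stack_module_state, vmap

    # Illustrative model and sizes (not part of the tour's running example)
    class SimpleMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 3)

        def forward(self, x):
            return self.fc(x).relu()

    num_models, batch_size = 4, 8
    models = [SimpleMLP() for _ in range(num_models)]

    # Stack each parameter/buffer of the models along a new leading dimension
    params, buffers = stack_module_state(models)

    # A "stateless" base copy, used only for its structure
    base_model = copy.deepcopy(models[0]).to("meta")

    def call_model(params, buffers, x):
        return functional_call(base_model, (params, buffers), (x,))

    x = torch.randn(batch_size, 5)
    # Map over the stacked parameters/buffers while sharing the same input batch
    predictions = vmap(call_model, in_dims=(0, 0, None))(params, buffers, x)
    assert predictions.shape == (num_models, batch_size, 3)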

What are the transforms?
------------------------

:func:`grad` (gradient computation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``grad(func)`` is our gradient computation transform. It returns a new function
that computes the gradients of ``func``. It assumes ``func`` returns a
single-element Tensor, and by default it computes the gradients of the output
of ``func`` with respect to the first input.

.. code-block:: python

    import torch
    from torch.func import grad
    x = torch.randn([])
    cos_x = grad(lambda x: torch.sin(x))(x)
    assert torch.allclose(cos_x, x.cos())

    # Second-order gradients
    neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
    assert torch.allclose(neg_sin_x, -x.sin())
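
The argument being differentiated can be changed with the ``argnums`` keyword
(an int or a tuple of ints). A small sketch, using an illustrative two-argument
function:

.. code-block:: python

    import torch
    from torch.func import grad

    x, y = torch.randn([]), torch.randn([])

    # Differentiate with respect to the second argument instead of the first
    dy = grad(lambda x, y: x * y, argnums=1)(x, y)
    assert torch.allclose(dy, x)

    # A tuple of argnums returns a tuple of gradients
    dx, dy = grad(lambda x, y: x * y, argnums=(0, 1))(x, y)
    assert torch.allclose(dx, y) and torch.allclose(dy, x)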

:func:`vmap` (auto-vectorization)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Note: :func:`vmap` imposes restrictions on the code that it can be used on. For more
details, please see :ref:`ux-limitations`.

``vmap(func)(*inputs)`` is a transform that adds a dimension to all Tensor
operations in ``func``. ``vmap(func)`` returns a new function that maps ``func``
over some dimension (default: 0) of each Tensor in ``inputs``.

:func:`vmap` is useful for hiding batch dimensions: one can write a function
``func`` that runs on individual examples and then lift it to a function that
takes batches of examples with ``vmap(func)``, leading to a simpler modeling
experience:

.. code-block:: python

    import torch
    from torch.func import vmap
    batch_size, feature_size = 3, 5
    weights = torch.randn(feature_size, requires_grad=True)

    def model(feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    examples = torch.randn(batch_size, feature_size)
    result = vmap(model)(examples)
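
The batch dimension does not have to be the leading one: the ``in_dims``
argument selects which dimension of each input to map over. A small sketch with
an illustrative layout where the batch runs along dimension 1:

.. code-block:: python

    import torch
    from torch.func import vmap

    # 3 examples stored column-wise (feature_size x batch_size)
    examples_t = torch.randn(5, 3)
    sums = vmap(torch.sum, in_dims=1)(examples_t)
    assert sums.shape == (3,)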

When composed with :func:`grad`, :func:`vmap` can be used to compute per-sample-gradients:

.. code-block:: python

    from torch.func import grad, vmap
    batch_size, feature_size = 3, 5

    def model(weights, feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    def compute_loss(weights, example, target):
        y = model(weights, example)
        return ((y - target) ** 2).mean()  # MSELoss

    weights = torch.randn(feature_size, requires_grad=True)
    examples = torch.randn(batch_size, feature_size)
    targets = torch.randn(batch_size)
    inputs = (weights, examples, targets)

    # Map over examples/targets (dim 0) but share the same weights (None)
    grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

:func:`vjp` (vector-Jacobian product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`vjp` transform applies ``func`` to ``inputs`` and returns a new function
that computes the vector-Jacobian product (vjp) given some ``cotangents`` Tensors.

.. code-block:: python

    from torch.func import vjp

    inputs = torch.randn(3)
    func = torch.sin
    cotangents = (torch.randn(3),)

    outputs, vjp_fn = vjp(func, inputs)
    vjps = vjp_fn(*cotangents)
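
For a function with a single scalar output, calling the returned ``vjp_fn``
with a cotangent of ones recovers the same gradient that :func:`grad` computes.
A small sketch:

.. code-block:: python

    import torch
    from torch.func import grad, vjp

    x = torch.randn(3)
    f = lambda x: x.sin().sum()

    out, vjp_fn = vjp(f, x)
    # vjp_fn returns one cotangent per input of ``f``
    (g,) = vjp_fn(torch.ones_like(out))
    assert torch.allclose(g, grad(f)(x))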

:func:`jvp` (Jacobian-vector product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jvp` transform computes Jacobian-vector products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order
function: it returns the outputs of ``func(inputs)`` as well as the jvps.

.. code-block:: python

    from torch.func import jvp
    x = torch.randn(5)
    y = torch.randn(5)
    f = lambda x, y: (x * y)
    _, out_tangent = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
    assert torch.allclose(out_tangent, x + y)
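
A unit tangent extracts a single column of the Jacobian, which is how
forward-mode AD builds full Jacobians one column at a time (see :func:`jacfwd`
below). A small sketch:

.. code-block:: python

    import torch
    from torch.func import jvp

    x = torch.randn(3)
    # The tangent e_0 picks out the first column of the Jacobian of sin at x
    e0 = torch.zeros(3)
    e0[0] = 1.0
    _, col0 = jvp(torch.sin, (x,), (e0,))
    assert torch.allclose(col0, torch.diag(torch.cos(x))[:, 0])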

:func:`jacrev`, :func:`jacfwd`, and :func:`hessian`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jacrev` transform returns a new function that takes in ``x`` and returns
the Jacobian of the function with respect to ``x`` using reverse-mode AD.

.. code-block:: python

    from torch.func import jacrev
    x = torch.randn(5)
    jacobian = jacrev(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

:func:`jacrev` can be composed with :func:`vmap` to produce batched Jacobians:

.. code-block:: python

    x = torch.randn(64, 5)
    jacobian = vmap(jacrev(torch.sin))(x)
    assert jacobian.shape == (64, 5, 5)

:func:`jacfwd` is a drop-in replacement for :func:`jacrev` that computes
Jacobians using forward-mode AD:

.. code-block:: python

    from torch.func import jacfwd
    x = torch.randn(5)
    jacobian = jacfwd(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

Composing :func:`jacrev` with itself or with :func:`jacfwd` can produce Hessians:

.. code-block:: python

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hessian0 = jacrev(jacrev(f))(x)
    hessian1 = jacfwd(jacrev(f))(x)

:func:`hessian` is a convenience function that combines :func:`jacfwd` and
:func:`jacrev`:

.. code-block:: python

    from torch.func import hessian

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hess = hessian(f)(x)
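
Like the Jacobian transforms, :func:`hessian` composes with :func:`vmap`. A
small sketch of batched Hessians (the batch size is illustrative):

.. code-block:: python

    import torch
    from torch.func import hessian, vmap

    def f(x):
        return x.sin().sum()

    x = torch.randn(64, 5)
    batched_hess = vmap(hessian(f))(x)
    assert batched_hess.shape == (64, 5, 5)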