File: index.rst

:github_url: https://github.com/pytorch/functorch

functorch
===================================

.. currentmodule:: functorch

functorch is a library of `JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

.. warning::

   We've integrated functorch into PyTorch. As the final step of the
   integration, the functorch APIs are deprecated as of PyTorch 2.0.
   Please use the torch.func APIs instead and see the
   `migration guide <https://pytorch.org/docs/main/func.migrating.html>`_
   and `docs <https://pytorch.org/docs/main/func.html>`_
   for more details.
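
For readers updating older code, here is a minimal sketch of the change
(assuming PyTorch 2.0 or later); the function ``f`` below is only an
illustrative placeholder:

.. code-block:: python

   # Old (deprecated):  from functorch import grad, vmap
   # New:
   import torch
   from torch.func import grad, vmap

   def f(x):
       return torch.sin(x).sum()

   x = torch.randn(3)
   print(grad(f)(x))           # same behaviour as functorch.grad(f)(x)
   print(vmap(torch.sin)(x))   # same behaviour as functorch.vmap(torch.sin)(x)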

What are composable function transforms?
----------------------------------------

- A "function transform" is a higher-order function that accepts a numerical function
  and returns a new function that computes a different quantity.

- functorch has auto-differentiation transforms (``grad(f)`` returns a function that
  computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)``
  returns a function that computes ``f`` over batches of inputs), and others.

- These function transforms can be composed with each other arbitrarily. For example,
  the composition ``vmap(grad(f))`` computes per-sample gradients, a quantity that
  stock PyTorch cannot compute efficiently today (see the sketch below).
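
As a minimal sketch (the loss function and data below are made-up
placeholders), per-sample gradients fall out of composing the two
transforms:

.. code-block:: python

   import torch
   from functorch import grad, vmap  # deprecated; torch.func exposes the same API

   weights = torch.randn(3)

   def loss(weights, sample):
       # toy per-sample loss: squared output of a linear model
       return (sample @ weights) ** 2

   samples = torch.randn(8, 3)   # a batch of 8 samples
   # grad(loss) differentiates with respect to `weights`;
   # vmap maps that gradient function over the batch of samples.
   per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(weights, samples)
   print(per_sample_grads.shape)   # torch.Size([8, 3])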

Why composable function transforms?
-----------------------------------

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, and :func:`vjp` transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the `JAX framework <https://github.com/google/jax>`_.
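
As one illustration (a sketch only; ``f`` below is an arbitrary example),
a reverse-mode Jacobian can be written as ``vmap`` composed with ``vjp``,
which is essentially how :func:`jacrev` is built:

.. code-block:: python

   import torch
   from functorch import vmap, vjp, jacrev  # torch.func exposes the same API

   def f(x):
       return x.sin()

   x = torch.randn(4)
   _, vjp_fn = vjp(f, x)
   basis = torch.eye(4)               # one-hot cotangents, one per output element
   (jacobian,) = vmap(vjp_fn)(basis)  # each row is one row of the Jacobian
   print(torch.allclose(jacobian, jacrev(f)(x)))   # True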

Read More
---------

Check out our `whirlwind tour <whirlwind_tour>`_ or the tutorials listed below.


.. toctree::
   :maxdepth: 2
   :caption: functorch: Getting Started

   install
   notebooks/whirlwind_tour.ipynb
   ux_limitations

.. toctree::
   :maxdepth: 2
   :caption: functorch API Reference and Notes

   functorch
   experimental
   aot_autograd

.. toctree::
   :maxdepth: 1
   :caption: functorch Tutorials

   notebooks/jacobians_hessians.ipynb
   notebooks/ensembling.ipynb
   notebooks/per_sample_grads.ipynb
   notebooks/neural_tangent_kernels.ipynb
   notebooks/aot_autograd_optimizations.ipynb
   notebooks/minifier.ipynb