File: func.api.md
# torch.func API Reference

```{eval-rst}
.. currentmodule:: torch.func
```

```{eval-rst}
.. automodule:: torch.func
```

## Function Transforms
```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    linearize
    jacrev
    jacfwd
    hessian
    functionalize
```
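The transforms listed above are composable. The following is a minimal sketch, not taken from the API reference itself: the toy functions (`torch.sin` and a sum of cubes) are illustrative choices. It shows `grad` on a scalar function, `vmap` lifting that gradient over a batch, and `hessian` of a scalar-valued function.

```python
import torch
from torch.func import grad, hessian, vmap

x = torch.randn(5)

# grad(torch.sin) returns a function that computes d/dx sin(x) for a 0-dim input.
dsin = grad(torch.sin)

# vmap maps dsin over the leading dimension of x, so each call sees a 0-dim tensor.
per_element = vmap(dsin)(x)  # elementwise cos(x)

# For the scalar-valued f(v) = sum(v**3), the Hessian is diag(6 * v).
def f(v):
    return (v ** 3).sum()

H = hessian(f)(x)  # shape (5, 5)
```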

## Utilities for working with torch.nn.Modules

In general, you can transform over a function that calls a ``torch.nn.Module``.
For example, the following computes the Jacobian of a function that takes a
3-element tensor and returns a 3-element tensor:

```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
```
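As a sanity check on the preceding example: `nn.Linear` computes `y = W x + b`, so its Jacobian with respect to `x` is the weight matrix `W` for any input. The sketch below (an illustrative addition, not part of the API reference) confirms that reverse-mode `jacrev` and forward-mode `jacfwd` both recover `model.weight`.

```python
import torch
from torch.func import jacfwd, jacrev

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)

# For an affine map y = W x + b, dy/dx = W regardless of x.
jac_rev = jacrev(f)(x)  # reverse-mode AD
jac_fwd = jacfwd(f)(x)  # forward-mode AD
```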

However, if you want to do something like compute a Jacobian with respect to the parameters of the model, you need a way to construct a function whose inputs are the parameters. That's what {func}`functional_call` is for: it accepts an `nn.Module`, a dictionary of replacement parameters (and buffers), and the inputs to the Module's forward pass. It returns the result of running the Module's forward pass with the replaced parameters.

Here's how to compute the Jacobian with respect to the parameters:

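```python
import torch
from torch.func import jacrev

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
# jacrev differentiates with respect to the first argument (params) by default,
# so the result is a dict holding one Jacobian per parameter.
jacobian = jacrev(f)(dict(model.named_parameters()), x)
assert jacobian["weight"].shape == (3, 3, 3)
assert jacobian["bias"].shape == (3, 3)
```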
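To make the "replaced parameters" behavior concrete, here is a small sketch; the all-zero replacement dictionary is a hypothetical choice made only so the expected output is obvious. `functional_call` runs the forward pass with the supplied tensors while the module's own parameters are left untouched.

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 3)
x = torch.randn(3)
weight_before = model.weight.detach().clone()

# Hypothetical replacement parameters: all zeros, so W x + b must be zero.
zero_params = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

out = functional_call(model, zero_params, (x,))
```

The module itself is not mutated, which is what lets transforms such as `jacrev` or `vmap` treat the parameters as ordinary function inputs.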

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_
```
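A common use of {func}`stack_module_state` is model ensembling: stack the parameters of several identically structured models, then `vmap` a `functional_call` over the stacked dimension. The sketch below assumes an ensemble of three `nn.Linear(5, 3)` models sharing one input batch; the sizes are illustrative. A copy of one model on the `meta` device provides only the module structure, since its tensors are replaced by `functional_call`.

```python
import copy

import torch
from torch.func import functional_call, stack_module_state, vmap

# An ensemble of three independently initialized models with the same architecture.
models = [torch.nn.Linear(5, 3) for _ in range(3)]
data = torch.randn(64, 5)

# Each entry in params/buffers gains a new leading dimension of size 3.
params, buffers = stack_module_state(models)

# A "stateless" copy on the meta device supplies structure only;
# its (meta) tensors are never used because functional_call replaces them.
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

# Map over the stacked model dimension and share the same data across models.
outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)  # (3, 64, 3)
```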

If you're looking for information on fixing Batch Norm modules, see the following
guidance:

```{eval-rst}
.. toctree::
   :maxdepth: 1

   func.batch_norm
```
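One of the options for Batch Norm is {func}`replace_all_batch_norm_modules_`, which updates a module in place so that every Batch Norm layer stops tracking running statistics. The sketch below assumes a small hypothetical `Sequential` model and per-example batches of 16 rows; after patching, the layer normalizes with the statistics of whatever batch it receives, and the model can be `vmap`ped.

```python
import torch
from torch.func import replace_all_batch_norm_modules_, vmap

# Hypothetical model containing a BatchNorm layer.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.BatchNorm1d(8),
)

# In place: running_mean/running_var become None and track_running_stats False.
replace_all_batch_norm_modules_(model)
bn = model[1]

# Each vmapped call sees one (16, 4) batch; BatchNorm uses that batch's statistics.
x = torch.randn(10, 16, 4)
out = vmap(model)(x)
```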

## Debug utilities

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    debug_unwrap
```