Migrating from functorch to torch.func
======================================

torch.func, previously known as "functorch", provides
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.
functorch started as an out-of-tree library over at
the `pytorch/functorch <https://github.com/pytorch/functorch>`_ repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.
As the final step of the upstream, we've decided to migrate from being a top level package
(``functorch``) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
``import functorch`` and ask that users migrate to the newest APIs, which we
will maintain going forward. ``import functorch`` will be kept around to maintain
backwards compatibility for a couple of releases.

function transforms
-------------------

The following APIs are drop-in replacements for the corresponding
`functorch APIs <https://pytorch.org/functorch/1.13/functorch.html>`_.
They are fully backwards compatible.

============================== =======================================
functorch API                  PyTorch API (as of PyTorch 2.0)
============================== =======================================
functorch.vmap                 :func:`torch.vmap` or :func:`torch.func.vmap`
functorch.grad                 :func:`torch.func.grad`
functorch.vjp                  :func:`torch.func.vjp`
functorch.jvp                  :func:`torch.func.jvp`
functorch.jacrev               :func:`torch.func.jacrev`
functorch.jacfwd               :func:`torch.func.jacfwd`
functorch.hessian              :func:`torch.func.hessian`
functorch.functionalize        :func:`torch.func.functionalize`
============================== =======================================
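
For instance, swapping functorch.grad for :func:`torch.func.grad` requires no call-site
changes; here is a minimal sketch (the toy loss function below is purely illustrative)::

    import torch
    from torch.func import grad  # previously: from functorch import grad

    def loss(x):
        # Toy scalar-valued loss, purely for illustration.
        return (x ** 2).sum()

    x = torch.randn(3)
    # Same call signature as functorch.grad
    dx = grad(loss)(x)
    assert dx.shape == x.shape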

Furthermore, if you are using the torch.autograd.functional APIs, please try out
the :mod:`torch.func` equivalents instead. :mod:`torch.func` function
transforms are more composable and, in many cases, more performant; see the
sketch after the table below.

=========================================== =======================================
torch.autograd.functional API               torch.func API (as of PyTorch 2.0)
=========================================== =======================================
:func:`torch.autograd.functional.vjp`       :func:`torch.func.grad` or :func:`torch.func.vjp`
:func:`torch.autograd.functional.jvp`       :func:`torch.func.jvp`
:func:`torch.autograd.functional.jacobian`  :func:`torch.func.jacrev` or :func:`torch.func.jacfwd`
:func:`torch.autograd.functional.hessian`   :func:`torch.func.hessian`
=========================================== =======================================
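
As a rough illustration, here is the jacobian of a toy function computed with both APIs
(a minimal sketch; the function and shapes are just for illustration)::

    import torch
    from torch.func import jacrev

    def f(x):
        return x.sin().sum(dim=-1)

    x = torch.randn(5, 3)

    # torch.autograd.functional API
    jac_old = torch.autograd.functional.jacobian(f, x)

    # torch.func API; this version composes with vmap, grad, etc.
    jac_new = jacrev(f)(x)

    assert torch.allclose(jac_old, jac_new)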

NN module utilities
-------------------

We've changed the APIs for applying function transforms over NN modules to make
them fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.
functorch.make_functional
^^^^^^^^^^^^^^^^^^^^^^^^^
:func:`torch.func.functional_call` is the replacement for
`functorch.make_functional <https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional>`_
and
`functorch.make_functional_with_buffers <https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers>`_.
However, it is not a drop-in replacement.
If you're in a hurry, you can use
`helper functions in this gist <https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf>`_
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using :func:`torch.func.functional_call` directly because it is a more explicit
and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts the parameters and the inputs to the model as arguments.
:func:`torch.func.functional_call` allows one to call the forward pass of an existing
module using replacement parameters, buffers, and inputs.

Here's an example of how to compute gradients of the parameters of a model using functorch
vs :mod:`torch.func`::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch

    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)

    def compute_loss(params, inputs, targets):
        prediction = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = functorch.grad(compute_loss)(params, inputs, targets)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch

    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here's an example of how to compute jacobians of model parameters::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch

    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)
    jacobians = functorch.jacrev(fmodel)(params, inputs)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    from torch.func import jacrev, functional_call

    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    # jacrev computes jacobians with respect to argnums=0 by default.
    # We set argnums=1 to compute jacobians with respect to params.
    jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that for memory consumption it is important that you only carry
around a single copy of your parameters. ``model.named_parameters()`` does not copy
the parameters. If your model training updates the parameters of the model
in-place, then the ``nn.Module`` that is your model holds the single copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of the parameters: the one in the
dictionary and the one in ``model``. In this case, you should make ``model``
stop holding memory by converting it to the meta device via ``model.to('meta')``;
a minimal sketch follows.
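
For instance, here is a minimal sketch of that pattern: the dictionary keeps the only copy
of the parameters, updates are done out-of-place, and the module's own storage is freed
(the hand-rolled SGD step below is just for illustration)::

    import torch
    from torch.func import functional_call, grad

    model = torch.nn.Linear(3, 3)
    # The dictionary now owns the only real copy of the parameters.
    params = {k: v.detach().clone() for k, v in model.named_parameters()}
    # Free the copy held by the module itself.
    model.to('meta')

    def compute_loss(params, inputs, targets):
        prediction = functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    inputs, targets = torch.randn(64, 3), torch.randn(64, 3)
    grads = grad(compute_loss)(params, inputs, targets)
    # Out-of-place update: build a new dictionary instead of mutating tensors.
    params = {k: p - 1e-2 * grads[k] for k, p in params.items()}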

functorch.combine_state_for_ensemble
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please use :func:`torch.func.stack_module_state` instead of
`functorch.combine_state_for_ensemble <https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html>`_.
:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, that can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call`
for ensembling.

For example, here is how to ensemble over a very simple model::

    import torch

    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3

    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, 3)

    # ---------------
    # using functorch
    # ---------------
    import functorch

    fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
    output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import copy

    # Construct a version of the model with no memory by putting the Tensors on
    # the meta device.
    base_model = copy.deepcopy(models[0])
    base_model.to('meta')

    params, buffers = torch.func.stack_module_state(models)

    # It is possible to vmap directly over torch.func.functional_call,
    # but wrapping it in a function makes it clearer what is going on.
    def call_single_model(params, buffers, data):
        return torch.func.functional_call(base_model, (params, buffers), (data,))

    output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

functorch.compile
-----------------

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
:func:`torch.compile` instead.
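
For instance, here is a minimal sketch of compiling a plain function with :func:`torch.compile`
(the function and shapes below are just for illustration)::

    import torch

    def f(x):
        return torch.sin(x) + torch.cos(x)

    # torch.compile replaces the functorch.compile / AOTAutograd frontend.
    compiled_f = torch.compile(f)

    x = torch.randn(64, 3)
    out = compiled_f(x)
    assert out.shape == (64, 3)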