# mypy: allow-untyped-defs
import torch


class MyAutogradFunction(torch.autograd.Function):
    # Toy autograd function: forward clones the input, and backward returns a
    # deliberately custom gradient (grad_output + 1).
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output + 1


class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We recommend
    using `allow_in_graph` to mitigate this problem (see the sketch after this class).
    """

    def forward(self, x):
        return MyAutogradFunction.apply(x)
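

# A minimal sketch of the `allow_in_graph` mitigation recommended in the docstring
# above. This is illustrative and not part of the original example: the underlying
# API is torch._dynamo.allow_in_graph, and whether to register the Function class
# itself or MyAutogradFunction.apply can depend on the PyTorch version.
import torch._dynamo

torch._dynamo.allow_in_graph(MyAutogradFunction)
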
example_args = (torch.randn(3, 2),)
model = AutogradFunction()
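

# A minimal usage sketch (an assumption about how this example is exercised, not
# part of the original file). torch.export captures the forward computation; the
# custom backward() above is not part of the exported graph.
if __name__ == "__main__":
    exported_program = torch.export.export(model, example_args)
    print(exported_program)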