1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
# Owner(s): ["oncall: jit"]
import torch
from torch.testing._internal.jit_utils import JitTestCase
from typing import List
class TestAutodiffJit(JitTestCase):
def test_undefined_tensor_lists(self):
def fn(tensor_list: List[torch.Tensor], add_tensor):
cat = torch.cat(tensor_list, dim=1)
r = torch.sin(cat + add_tensor)
return r
fn_s = torch.jit.script(fn)
a = torch.rand((3, 6), requires_grad=True)
b = torch.rand((3, 10), requires_grad=True)
x = [a, b]
y = torch.rand((3, 16), requires_grad=True)
ret = fn_s(x, y)
ret.sum().backward()
ret = fn_s(x, y)
ret.sum().backward()
ret = fn_s(x, y)
s = ret.sum()
# backward_fn expects 2 inputs: (grad_output, current_grad_r)
# current_grad_r is provided because we need to add this contribution
# to grad_r when we return it.
backward_fn = s.grad_fn.next_functions[0][0]
# check behavior with defined tensor
grad_out = torch.rand((3, 16))
grad_inputs = backward_fn(grad_out, None)
# expect 3 tensors: grad_y, grad_a, grad_b
self.assertEqual(3, len(grad_inputs))
for x in grad_inputs:
self.assertTrue(isinstance(x, torch.Tensor))
# now test with undefined grad_out
grad_inputs = backward_fn(None, None)
# expect all of them to be None
self.assertEqual(3, len(grad_inputs))
for x in grad_inputs:
if x is not None:
self.assertEqual(0, torch.max(torch.abs(x)).item())
def test_requires_grad_outputs(self):
# outputs should require_grad only if eager outputs would require_grad.
def fn(a, b, c):
return a.relu() + b.relu(), c.relu()
a = torch.rand((10, 10), requires_grad=False)
b = torch.rand((10, 10), requires_grad=False)
c = torch.rand((10, 10), requires_grad=True)
fn_s = torch.jit.script(fn)
for i in range(4):
x, y = fn_s(a, b, c)
self.assertFalse(x.requires_grad)
self.assertTrue(y.requires_grad)
def test_requires_grad_outputs_profiled_twice(self):
# the value "r" is used twice, by gammaln and by entr, so it is profiled twice.
# So during autodiff graph formation the profile nodes are unmerged because
# they are aliasing. Then the DifferentiableGraph doesn't have a profile
# node on the output. The requires_grad info should then be added onto the
# output value (otherwise autodiff will make the output require_grad).
# Note: this relies on gammaln and entr not having autodiff implementations.
def fn(a, b, c):
r = a.relu().relu()
return torch.special.gammaln(r), torch.special.entr(r), c.cos().relu()
fn_s = torch.jit.script(fn)
a = torch.rand((10, 10), requires_grad=False)
b = torch.rand((10, 10), requires_grad=False)
c = torch.rand((10, 10), requires_grad=True)
for i in range(4):
x_s, y_s, z_s = fn_s(a, b, c)
x, y, z = fn(a, b, c)
self.assertEqual(x_s.requires_grad, x.requires_grad)
self.assertEqual(y_s.requires_grad, y.requires_grad)
self.assertEqual(z_s.requires_grad, z.requires_grad)
def test_requires_grad_outputs_side_effects(self):
# same as above, but also add a CallFunction in between.
@torch.jit.ignore
def python_fn(x):
return x.relu()
def fn(a, b, c):
r = a.relu().relu()
z = python_fn(r)
return torch.relu(r), torch.nn.functional.gelu(r), c.cos().relu()
fn_s = torch.jit.script(fn)
a = torch.rand((10, 10), requires_grad=False)
b = torch.rand((10, 10), requires_grad=False)
c = torch.rand((10, 10), requires_grad=True)
for i in range(4):
x_s, y_s, z_s = fn_s(a, b, c)
x, y, z = fn(a, b, c)
self.assertEqual(x_s.requires_grad, x.requires_grad)
self.assertEqual(y_s.requires_grad, y.requires_grad)
self.assertEqual(z_s.requires_grad, z.requires_grad)
def test_autodiff_requires_grad_nograd(self):
@torch.jit.ignore
def python_fn(x):
return x.relu()
def fn(a, b, c):
x = a.sin().relu()
y = python_fn(b)
with torch.no_grad():
z = x + c
return x, y, z
fn_s = torch.jit.script(fn)
a = torch.rand((10, 10), requires_grad=True)
b = torch.rand((10, 10), requires_grad=True)
c = torch.rand((10, 10), requires_grad=True)
for i in range(4):
x_s, y_s, z_s = fn_s(a, b, c)
x, y, z = fn(a, b, c)
self.assertEqual(x_s.requires_grad, x.requires_grad)
self.assertEqual(y_s.requires_grad, y.requires_grad)
self.assertEqual(z_s.requires_grad, z.requires_grad)
|