# Owner(s): ["module: autograd"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck
class TestAutogradComplex(TestCase):
    """Autograd tests for ``torch.view_as_real`` / ``torch.view_as_complex`` views."""

    def test_view_func_for_complex_views(self):
        """Compare gradients of in-place ops on complex/real views with a reference.

        Each case builds two independent leaf tensors ``x`` and ``y`` holding
        the same values, pushes ``x`` through a chain of views followed by an
        in-place ``mul_``, pushes ``y`` through a reference computation, and
        asserts that both leaves end up with the same gradient.
        """
        # case 1: both parent and child have view_func
        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
        y = x.detach().requires_grad_(True)

        x0 = x.clone()
        x1 = torch.view_as_complex(x0)
        x2 = torch.view_as_real(x1)
        x2.mul_(2)
        x2.sum().backward()

        # Reference: the same scaling applied directly, without any views.
        y0 = y.clone()
        y0.mul_(2)
        y0.sum().backward()

        self.assertEqual(x.grad, y.grad)

        # case 2: parent has view_func but child does not
        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
        y = x.detach().requires_grad_(True)

        def fn(a):
            b = a.clone()
            b1 = torch.view_as_complex(b)
            b2 = b1.reshape(b1.numel())
            return b2

        # ``fn`` returns a complex tensor, so its sum is a complex scalar.
        # ``backward()`` can only create the implicit gradient for a real
        # scalar output, hence the ``abs()`` before calling it.
        x0 = fn(x)
        x0.mul_(2)
        x0.sum().abs().backward()

        # Reference: out-of-place multiplication on the same view chain.
        y0 = fn(y)
        y1 = y0.mul(2)
        y1.sum().abs().backward()

        self.assertEqual(x.grad, y.grad)

        # case 3: parent does not have a view_func but child does
        x = torch.randn(10, dtype=torch.cdouble, requires_grad=True)
        y = x.detach().requires_grad_(True)

        def fn(a, dim0_size=5):
            b = a.clone()
            b1 = b.reshape(dim0_size, 2)
            b2 = torch.view_as_real(b1)
            return b2

        # Here ``fn`` ends in ``view_as_real``, so the sum is already real.
        x0 = fn(x)
        x0.mul_(2)
        x0.sum().backward()

        y0 = fn(y)
        y1 = y0.mul(2)
        y1.sum().backward()

        self.assertEqual(x.grad, y.grad)

    def test_view_with_multi_output(self):
        """In-place updates of ``unbind`` outputs of a complex view must raise.

        The check runs twice on the same base tensor: once while the base does
        not require grad and once after ``requires_grad`` has been enabled.
        """
        x = torch.randn(2, 2, 2, dtype=torch.double)

        for requires_grad in (False, True):
            with self.subTest(requires_grad=requires_grad):
                x.requires_grad_(requires_grad)
                x1 = torch.view_as_complex(x)
                # Taking an invalid view should always be allowed as long as it is not
                # modified inplace
                res = x1.unbind(0)

                with self.assertRaisesRegex(
                    RuntimeError, "output of a function that returns multiple views"
                ):
                    res[0] += torch.rand(2, requires_grad=True)

    def test_as_identity(self):
        """``view_as_complex`` followed by ``view_as_real`` acts like an identity.

        Selecting index 0 along the last dim of the complex view and viewing
        the result as real again must produce the same gradient as selecting
        index 0 along the second-to-last dim of the real tensor directly.
        """
        # view_as_real and view_as_complex behavior should be like an identity
        def func(z):
            z_ = torch.view_as_complex(z)
            z_select = torch.select(z_, z_.dim() - 1, 0)
            z_select_real = torch.view_as_real(z_select)
            return z_select_real.sum()

        z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True)
        gradcheck(func, [z])
        func(z).backward()

        # Reference: the equivalent selection done purely on the real tensor.
        z1 = z.clone().detach().requires_grad_(True)
        torch.select(z1, z1.dim() - 2, 0).sum().backward()

        self.assertEqual(z.grad, z1.grad)

    # The method used to be called ``as_identity``; without the ``test_``
    # prefix it was never collected by the test runner. The old name is kept
    # as an alias so existing references to it still resolve.
    as_identity = test_as_identity
# Hand control to the imported test runner when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
|