# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo
import torch._dynamo.test_case
@torch._dynamo.config.patch("capture_scalar_outputs", True)
class ViewTests(torch._dynamo.test_case.TestCase):
def test_view_to_2d(self):
@torch.compile(fullgraph=True, backend="eager")
def f(t, _u0):
u0 = t[0].item()
u1 = t[1].item()
torch._check_is_size(u0)
torch._check_is_size(u1)
n = u0 * u1
a = torch.randn(n)
return a.view(-1, _u0)
t = torch.tensor([2, 4], dtype=torch.int32)
f(t, 2)
def test_view_to_1d(self):
@torch.compile(fullgraph=True, backend="eager")
def f(t, _n):
u0 = t[0].item()
u1 = t[1].item()
torch._check_is_size(u0)
torch._check_is_size(u1)
a = torch.randn(u0, u1)
return a.view(_n)
t = torch.tensor([2, 4], dtype=torch.int32)
f(t, 8)
if __name__ == "__main__":
    # Executed directly as a script: run this file's tests through the
    # run_tests entry point of torch._dynamo.test_case.
    torch._dynamo.test_case.run_tests()