1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
|
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import eval_frame
from torch._dynamo.hooks import Hooks
# Module-level global read by ``fn1`` and ``fn2``. The expected values asserted
# in ``NopTests`` (-7 and 27) depend on it being 10.
c = 10
def fn1(a, b):
    """Return ``a + b - c``, reading the module-level global ``c``.

    Fixture for ``NopTests.test1``. With ``c == 10``, ``fn1(1, 2)`` is -7.
    The body is presumably kept minimal so that the frame rewritten under
    ``with_debug_nops`` is simple. Changing its shape changes what is tested.
    """
    return a + b - c
def fn2(a, b):
    """Accumulate into a closed-over local through a nested ``nonlocal`` writer.

    ``modify`` adds ``a + b + c`` to ``x`` on each of two loop iterations, so
    the result is ``2 * (a + b + c) + 1``. With ``c == 10``, ``fn2(1, 2)`` is 27.
    Fixture for ``NopTests.test2``. It combines a closure over a cell variable
    with a loop, presumably so the NOP-inserting transform is exercised on both.
    """
    x = 0
    y = 1

    def modify():
        # Rebinds the enclosing function's ``x``; it does not create a new local.
        nonlocal x
        x += a + b + c

    for _ in range(2):
        modify()

    return x + y
def fn3():
    """Generator that yields 1, then 2, then is exhausted.

    Fixture for ``NopTests.test3``, which checks both values and the trailing
    ``StopIteration``.
    """
    yield 1
    yield 2
# Wrapper used both as a decorator (``@with_debug_nops``) and as a plain call
# (``with_debug_nops(fn)``). It is built from the private
# ``eval_frame._optimize_catch_errors`` with
# ``torch._dynamo.testing.debug_insert_nops`` as the callback, presumably so
# that frames run under it have NOPs inserted into their bytecode.
# NOTE(review): ``Hooks(None, None)`` appears to mean "no extra hooks".
# Confirm against ``torch._dynamo.hooks.Hooks``.
with_debug_nops = eval_frame._optimize_catch_errors(
    torch._dynamo.testing.debug_insert_nops, Hooks(None, None)
)
class NopTests(torch._dynamo.test_case.TestCase):
    """Verify that fixtures still compute correct results with debug NOPs inserted.

    ``test1`` and ``test2`` call their fixture twice, presumably so that both
    an initial and a repeated call go through the wrapper. ``test3`` steps a
    generator to exhaustion. ``test_extended_args`` wraps a very long
    generated lambda.
    """

    @with_debug_nops
    def test1(self):
        for _ in range(2):
            self.assertEqual(fn1(1, 2), -7)

    @with_debug_nops
    def test2(self):
        for _ in range(2):
            self.assertEqual(fn2(1, 2), 27)

    @with_debug_nops
    def test3(self):
        gen = fn3()
        for expected in (1, 2):
            self.assertEqual(next(gen), expected)
        # After its two yields the generator must be exhausted.
        self.assertRaises(StopIteration, lambda: next(gen))

    def test_extended_args(self):
        # 512 alternating "a"/"b" terms in each branch. This makes the
        # conditional's jumps long enough that they presumably need
        # EXTENDED_ARG prefixes, which is what the test name refers to.
        terms = "+".join(["a", "b"] * 256)
        lambda_src = (
            f"lambda a, b: ({terms}+a if a.sum() > 0 else {terms} - b)"
        )
        wrapped = with_debug_nops(eval(lambda_src))
        a = torch.ones(1)
        b = torch.ones(1)
        # a.sum() > 0 selects the first branch: 512 ones plus a gives 513.
        self.assertEqual(wrapped(a, b).sum(), 513)
if __name__ == "__main__":
    # Hand off to Dynamo's test runner. The module is already imported at the
    # top of the file, so no local import is needed.
    torch._dynamo.test_case.run_tests()
|