1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
|
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from functorch.experimental.cond import cond
from torch.fx.experimental.proxy_tensor import make_fx
import unittest
class TestControlFlow(TestCase):
    """Eager (untraced) checks for ``cond``."""

    def test_cond_no_trace(self):
        """With a plain Python ``False`` predicate, ``cond`` yields the false branch's result."""
        inp = torch.randn(4)
        # Branch callables take the operands unpacked from the list below.
        out = cond(False, lambda t: t.sin(), lambda t: t.cos(), [inp])
        expected = torch.cos(inp)
        self.assertEqual(out, expected)
class TestControlFlowTraced(TestCase):
    """Checks for ``cond`` inside functions traced with ``make_fx``.

    Most tests trace once with one set of predicate values and then replay the
    resulting GraphModule with other predicate values. A correct trace must
    therefore keep both branches in the graph rather than baking in whichever
    branch was taken at trace time.
    """

    @staticmethod
    def _normalize_code(code):
        """Return ``code`` with every whitespace character removed.

        ``GraphModule.code`` emits whitespace that is awkward to match exactly.
        Generated code is therefore compared against expected text after
        stripping all whitespace from both sides.
        """
        return "".join(code.split())

    def test_cond_traced_not_nested(self):
        """A single traced ``cond`` dispatches on the runtime predicate tensor."""
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        # Trace with the predicate False, then replay with both values.
        graph = make_fx(f)(x, torch.tensor(False))
        result_true = graph(x, torch.tensor(True))
        result_false = graph(x, torch.tensor(False))

        self.assertFalse(torch.allclose(result_true, result_false))
        self.assertEqual(result_true, torch.sin(x))
        self.assertEqual(result_false, torch.cos(x))

    @unittest.expectedFailure
    def test_cond_nested_traced(self):
        """A ``cond`` nested inside a traced ``cond`` branch honours both predicates."""
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(x, pred2):
            z = cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def f(x, pred, pred2):
            return cond(pred, true_fn, false_fn, [x, pred2])

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        # pred=True,  pred2=True   -> x + (x * x)
        result_true_true = graph(x, torch.tensor(True), torch.tensor(True))
        # pred=True,  pred2=False  -> x + (x + x)
        result_true_false = graph(x, torch.tensor(True), torch.tensor(False))
        # pred=False, pred2 is ignored -> cos(x)
        result_false_true = graph(x, torch.tensor(False), torch.tensor(True))
        result_false_false = graph(x, torch.tensor(False), torch.tensor(False))

        self.assertNotEqual(result_true_true, result_true_false)
        self.assertFalse(torch.allclose(result_false_true, result_true_true))
        self.assertEqual(result_false_true, result_false_false)
        self.assertEqual(result_true_true, (x * x) + x)
        self.assertEqual(result_true_false, x + x + x)
        self.assertEqual(result_false_true, torch.cos(x))

    @unittest.expectedFailure
    def test_cond_nested_traced_other_inputs(self):
        """A nested traced ``cond`` computes with the replayed input, not the tracing input."""
        def true_nested(y):
            return y * y

        def false_nested(y):
            return y + y

        def true_fn(k, pred2):
            z = cond(pred2, true_nested, false_nested, [k])
            return torch.add(torch.tensor([.25, .25]), z)

        def false_fn(k, _):
            return k.cos()

        def f(k, pred, pred2):
            return cond(pred, true_fn, false_fn, [k, pred2])

        # Trace with one input and both predicates False ...
        x = torch.tensor([0.5, 0.5])
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
        offset = torch.tensor([0.25, 0.25])

        # ... then replay with different inputs and both predicates True.
        a = torch.tensor([1.0, 1.0])
        result_a = graph(a, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_a, (a * a) + offset)

        b = torch.tensor([2.0, 2.0])
        result_b = graph(b, torch.tensor(True), torch.tensor(True))
        self.assertEqual(result_b, (b * b) + offset)

    @unittest.expectedFailure
    def test_cond_nested_traced_multi(self):
        """Two sibling ``cond`` calls each lower to their own true/false subgraph pair."""
        def true_a(y):
            return y * y

        def false_a(y):
            return y + y

        def true_b(y, z):
            return y + z

        def false_b(y, z):
            return y * z

        def f(x, pred, pred2):
            a_out = cond(pred, true_a, false_a, [x])
            b_out = cond(pred2, true_b, false_b, [x, x])
            return a_out + b_out

        x = torch.randn(4)
        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))

        # Brittle, yet, delicious: pins the exact generated code (modulo whitespace).
        expected_code = """
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional = torch.ops.cond(pred_1, true_graph_0, false_graph_0, [[x_1]]);
pred_1 = true_graph_0 = false_graph_0 = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
conditional_1 = torch.ops.cond(pred2_1, true_graph_1, false_graph_1, [[x_1, x_1]]);
pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
return add
"""
        self.assertEqual(
            self._normalize_code(graph.code),
            self._normalize_code(expected_code),
        )

        expected_true_graph_code = """
def forward(self, flat_args):
flat_args_1, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec)
mul = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1); flat_args_1 = None
return pytree.tree_unflatten([mul], self._out_spec)
"""
        self.assertEqual(
            self._normalize_code(graph.true_graph_0.code),
            self._normalize_code(expected_true_graph_code),
        )

    def test_assert_on_mismatch_type_size(self):
        """Tracing raises AssertionError when one branch returns a tensor and the other a 2-tuple."""
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return (x, x)

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        with self.assertRaises(AssertionError):
            make_fx(f)(x, torch.tensor(False))

    @unittest.expectedFailure
    def test_assert_on_mismatch_tensor_size(self):
        """Tracing should raise AssertionError when the branches return different shapes.

        Here the shapes are (4,) vs. (10, 10). The test is currently marked
        expectedFailure.
        """
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return torch.zeros([10, 10])

        def f(x, y):
            return cond(y, true_fn, false_fn, [x])

        x = torch.randn(4)
        with self.assertRaises(AssertionError):
            make_fx(f)(x, torch.tensor(False))
if __name__ == '__main__':
    # Entry point when this module is executed directly as a script.
    run_tests()
|