# Owner(s): ["module: dynamo"]

import unittest
from contextlib import contextmanager
from importlib import import_module

import torch
import torch._prims_common as utils
from torch._dynamo.utils import preserve_rng_state
from torch._inductor import config
from torch._inductor.compiler_bisector import CompilerBisector
from torch._inductor.test_case import TestCase
from torch.library import _scoped_library, Library
from torch.testing._internal.inductor_utils import HAS_CUDA


aten = torch.ops.aten
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


@requires_cuda
class TestCompilerBisector(TestCase):
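    """Tests for CompilerBisector.do_bisect, which bisects over compile
    backends and their subsystems to isolate the component responsible for
    an accuracy failure. Each test plants a known bug, then checks that the
    bisector blames the expected backend/subsystem.
    """
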
    test_ns = "_test_bisector"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        if hasattr(self, "lib"):
            del self.lib.m
            del self.lib

    def get_op(self, name):
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib

    def test_bad_decomp(self):
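        """Swap in a decomposition for aten.exponential that produces NaNs and
        check that the bisector pins the failure on the decomposition
        subsystem of the aot_eager_decomp_partition backend."""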
        # Make sure torch._inductor.compile_fx is imported before it is patched below
        mod = import_module("torch._inductor.compile_fx")

        def bad_exp_decomp(self, rate=1, generator=None):
            assert generator is None
            torch._check(
                not utils.is_complex_dtype(self.dtype)
                and not utils.is_integer_dtype(self.dtype)
                and not utils.is_boolean_dtype(self.dtype),
                lambda: f"Exponential distribution is a continuous probability distribution. \
                dtype must be a floating point but you specified {self.dtype}",
            )
            torch._check(
                rate > 0.0,
                lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
            )
            return torch.rand_like(self) * float("nan")

        @contextmanager
        def patch_exp_decomp():
            from torch._inductor.compile_fx import select_decomp_table as old_decomp

            def get_decomp():
                out = old_decomp()
                out = out.copy()
                out[aten.exponential.default] = bad_exp_decomp
                return out

            torch._inductor.compile_fx.select_decomp_table = get_decomp
            try:
                yield
            finally:
                torch._inductor.compile_fx.select_decomp_table = old_decomp

        def vq(x):
            return (x + 3).exponential_() * 10.5

        def test_fn():
            torch._dynamo.reset()
            with patch_exp_decomp():
                vq_compiled = torch.compile(vq)
                x = torch.randn(4, 400, 256).cuda()
                with torch._dynamo.utils.preserve_rng_state():
                    out = vq(x)
                out_compiled = vq_compiled(x)
            return not out_compiled.isnan().any()

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "aot_eager_decomp_partition")
        self.assertEqual(out.subsystem, "decomposition")
        self.assertEqual(out.bisect_number, 1)
        self.assertTrue("aten.exponential" in out.debug_info)

    def test_joint_graph(self):
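        """Install a joint-graph custom post pass that rewrites x + 1 into
        x + 2 and check that the bisector attributes the numeric mismatch to
        inductor's joint_graph_passes subsystem."""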
        def pass_fn(graph: torch.fx.Graph):
            nodes = graph.find_nodes(
                op="call_function", target=torch.ops.aten.add.Tensor
            )
            assert len(nodes) == 1
            args = list(nodes[0].args)
            args[1] = 2
            nodes[0].args = tuple(args)

        config.joint_custom_post_pass = pass_fn

        def foo(x):
            return x + 1

        def test_fn():
            torch._dynamo.reset()
            inp = torch.rand([10], device="cuda")
            out = foo(inp)
            out_c = torch.compile(foo)(inp)
            return torch.allclose(out, out_c)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "joint_graph_passes")
        self.assertEqual(out.bisect_number, 4)
        self.assertTrue("joint_custom_post_pass" in out.debug_info)

    def test_rng(self):
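        """Eager mode and inductor generate random numbers differently, so
        compiled output diverges even under a preserved RNG state; the
        bisector should flag the inductor_fallback_random subsystem."""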
        def foo():
            return torch.rand([10], device="cuda") + 1

        def test_fn():
            torch._dynamo.reset()
            with preserve_rng_state():
                out = foo()
            with preserve_rng_state():
                out_c = torch.compile(foo)()
            return torch.allclose(out, out_c)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "inductor_fallback_random")
        self.assertTrue("inductor_fallback_random" in out.debug_info)

    def test_crossref(self):
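        """Define a custom op whose Meta kernel returns a view of its input
        while the real CPU kernel returns a fresh tensor; the crossref
        checking backend (aot_eager_decomp_partition_crossref) should catch
        the mismatch."""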
test_ns = "bisect_ops"
with _scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define("foo(Tensor x) -> Tensor")
op = self.get_op("foo")
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
# Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
with torch._C._AutoDispatchBelowAutograd():
with torch._C._ExcludeDispatchKeyGuard(
torch._C.DispatchKeySet(
torch._C.DispatchKey.ADInplaceOrView
)
):
return op(x)
@staticmethod
def backward(ctx, gx):
return gx
def foo_impl(x):
return x.view_as(x).clone()
def foo_meta(x):
return x.view_as(x)
lib.impl("foo", Foo.apply, "Autograd")
lib.impl("foo", foo_impl, "CPU")
lib.impl("foo", foo_meta, "Meta")
x = torch.tensor(3.14159 / 3, requires_grad=True)
def test_fn():
torch._dynamo.reset()
try:
torch.testing.assert_allclose(torch.compile(op)(x), op(x))
except Exception:
return False
return True
out = CompilerBisector.do_bisect(test_fn)
self.assertEqual(out.backend, "aot_eager_decomp_partition_crossref")
    def test_emulate_precision_casts(self):
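        """Eager bf16 rounds intermediates to bfloat16 while inductor keeps
        them in higher precision, so this scale computation can differ
        bitwise; the bisector should point at
        inductor_emulate_precision_casts."""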
        def test_fn():
            torch._dynamo.reset()

            def calculate_scale(inp):
                amax = torch.abs(torch.max(inp))
                scale = 448.0 / torch.clamp(amax, min=1e-12)
                scale = scale.to(torch.float32)
                return scale

            dtype = torch.bfloat16
            torch.manual_seed(0)
            inp = torch.randn(16, 16, 768, dtype=dtype, device="cuda")
            eager_scale = calculate_scale(inp)
            compile_scale = torch.compile(calculate_scale)(inp)
            return torch.equal(eager_scale, compile_scale)

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "inductor_emulate_precision_casts")

    def test_bad_lowering(self):
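        """Corrupt the relu lowering via the triton.inject_relu_bug_TESTING_ONLY
        config and check that the bisector isolates the bad lowering within
        inductor's lowerings subsystem."""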
        def test_fn():
            torch._dynamo.reset()
            with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"):

                def my_func(x):
                    return ((x * -1) - 0.01).relu()

                inp = torch.rand([100], device="cuda")
                return torch.allclose(torch.compile(my_func)(inp), my_func(inp))

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "inductor")
        self.assertEqual(out.subsystem, "lowerings")
        self.assertEqual(out.bisect_number, 2)
        self.assertTrue("relu" in out.debug_info)

    def test_eager_backend(self):
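        """A test_fn that always fails should implicate the very first
        backend tried (eager), with no subsystem to blame."""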
        def test_fn():
            return False

        out = CompilerBisector.do_bisect(test_fn)
        self.assertEqual(out.backend, "eager")
        self.assertEqual(out.subsystem, None)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()