# Owner(s): ["module: inductor"]
import os
import re
import unittest
import torch
from torch import nn
from torch._dynamo.testing import reset_rng_state
from torch._inductor import config, test_operators
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.nn import functional as F
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfXpu,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
class TransformerSnippet(nn.Module):
def __init__(self) -> None:
super().__init__()
self.ln1 = nn.LayerNorm(64)
self.ln2 = nn.LayerNorm(64)
def forward(self, x1, x2):
x1 = F.dropout(x1, 0.1)
x2 = F.dropout(self.ln1(x2), 0.1)
return self.ln2(x1 + x2)
def example_inputs(self):
return (torch.randn(2, 64).to(GPU_TYPE), torch.randn(2, 64).to(GPU_TYPE))
def _contains_multi_kernel_code(wrapper_code: str):
return (
re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code)
is not None
)
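# Illustrative example (actual names are codegen-dependent) of the kind of
# wrapper-code line the regex above is meant to match:
#   multi_kernel_triton_red_fused_0 = async_compile.multi_kernel(...)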
def make_cpp_wrapper_test(orig_test, **extra_args):
"""
Wrap an existing test into a new test with cpp-wrapper enabled.
Make this a free function rather than a staticmethod in MultiKernelTest;
otherwise we get a "TypeError: 'staticmethod' object is not callable"
error in py3.8 (py3.10 works).
"""
@config.patch("cpp_wrapper", True)
@skipIfXpu(msg="cpp wrapper doesn't currently work on the XPU stack")
def fn(self):
# The same kernel may have been compiled by previous tests with
# cpp_wrapper disabled. Clear the cache so we re-compile the kernel
# with cpp_wrapper enabled.
from torch._inductor import codecache
codecache.PyCodeCache.cache_clear()
return orig_test(self, **extra_args)
return fn
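# The class-level config.patch below enables multi-kernel codegen (value taken
# from TORCHINDUCTOR_MULTI_KERNEL, defaulting to 1) together with
# benchmark_kernel; at runtime MultiKernelCall benchmarks the sub-kernels
# (see benchmark_sub_kernels, mocked in a test below) and picks the faster one.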
@config.patch(
{
"triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
"benchmark_kernel": True,
}
)
@instantiate_parametrized_tests
class MultiKernelTest(TestCase):
def test_softmax(self, expect_multi_kernel=True):
x = torch.rand(2, 1024).to(GPU_TYPE)
ref = torch.softmax(x, -1)
compiled_fn = torch.compile(torch.softmax)
act, wrapper_code = run_and_get_code(compiled_fn, x, -1)
# wrapper_code will contain 2 entries if cpp_wrapper=True:
# one for the first pass and one for the second pass.
# We mainly care about the wrapper for the final pass here.
wrapper_code = wrapper_code[-1]
self.assertEqual(ref, act)
if expect_multi_kernel:
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
else:
self.assertFalse(_contains_multi_kernel_code(wrapper_code))
@parametrize("force_kernel", (0, 1))
@unittest.mock.patch.dict(
os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}
)
def test_softmax_force_non_persistent_reduction(self, force_kernel):
"""
Force a specific sub-kernel to be picked by mocking the benchmark result.
"""
x = torch.rand(2, 1024).to(GPU_TYPE)
mock_latency = [0.2, 0.2]
mock_latency[force_kernel] = 0.1  # this makes sure force_kernel will be picked
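# Adding force_kernel to the output below presumably makes each parametrization
# compile a distinct graph, so the two cases don't share a previously compiled
# kernel.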
def f(x):
return torch.softmax(x, -1) + force_kernel
orig_run = MultiKernelCall.run
picked_kernel = None
def mock_run(self, *args, **kwargs):
out = orig_run(self, *args, **kwargs)
nonlocal picked_kernel
picked_kernel = self.picked_kernel
return out
with unittest.mock.patch.object(
MultiKernelCall, "run", mock_run
), unittest.mock.patch.object(
MultiKernelCall,
"benchmark_sub_kernels",
lambda *args, **kwargs: mock_latency,
):
torch.compile(f)(x)
self.assertEqual(picked_kernel, force_kernel)
@config.patch("warn_mix_layout", True)
def test_softmax_warn_mixed_layout(self):
self.test_softmax()
test_softmax_cpp_wrapper = make_cpp_wrapper_test(
test_softmax, expect_multi_kernel=True
)
def test_layernorm(self):
ln = nn.LayerNorm(1024).to(GPU_TYPE)
x = torch.rand(2, 1024).to(GPU_TYPE)
ref = ln(x)
act = torch.compile(ln)(x)
self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)
def test_inplace_update(self):
"""
Inductor generates an in-place kernel for the mul.
"""
def f(x, y):
return x.sum(dim=-1, keepdims=True) * (y @ y)
x = torch.rand(1024, 1024).to(GPU_TYPE)
y = torch.rand(1024, 1024).to(GPU_TYPE)
ref = f(x, y)
act = torch.compile(f)(x, y)
self.assertEqual(ref, act)
def test_transformer_snippet(self):
model = TransformerSnippet().to(GPU_TYPE)
x = model.example_inputs()
def f(*x):
y = model(*x)
return y
reset_rng_state()
ref = f(*x)
opt_f = torch.compile(f)
reset_rng_state()
act = opt_f(*x)
# Don't compare tensors when using the inductor random number generator,
# since its implementation differs from eager's. We should fall back to
# eager RNG if we want to test accuracy.
if config.fallback_random:
self.assertEqual(ref, act, atol=1e-4, rtol=1e-4)
def test_transformer_snippet_with_fallback_random(self):
"""
Same as test_transformer_snippet but falls back to the eager random
number generator so we can check accuracy.
"""
with config.patch("fallback_random", True):
self.test_transformer_snippet()
def test_batchnorm_training(self):
"""
For training, batchnorm tracks running mean/variance during the forward pass.
The kernel generated by inductor currently passes those tensors in twice as
arguments: once as input and once as output. They are not treated as in-out
arguments because they are considered graph inputs.
Multi-kernel previously assumed that we never pass the same argument multiple
times to a kernel. Whether or not we change inductor's behavior to ensure that,
it's better to make multi-kernel able to handle those cases.
"""
bn = nn.BatchNorm2d(3).to(GPU_TYPE)
@torch.compile
def f(x):
bn(x).sum().backward()
_, (wrapper_code, _) = run_and_get_code(
f, torch.randn(2, 3, 8, 8, device=GPU_TYPE)
)
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
def test_pass_same_arg_multi_times(self):
"""
A super simple example that simulates how BatchNorm updates the running
stats.
Inductor currently passes the same tensor multiple times to the generated
kernel: once as input and once as output.
Here is a paste of the generated kernel (without multi-kernel enabled):
https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9
"""
def f(x, y):
x = x.sum(dim=1, keepdim=False)
y.copy_(y * 0.9 + x * 0.1)
x = torch.randn(8, 16, device=GPU_TYPE)
y = torch.randn(8, device=GPU_TYPE)
y_ref = y.clone()
ref = f(x, y_ref)
act = torch.compile(f)(x, y)
self.assertEqual(y_ref, y)
def test_reduction_scratch_buffer(self, force_multi_kernel=1):
"""
The explicitly realized buffer in the test function will be passed in
as a scratch buffer for the non-persistent reduction kernel but
can be skipped for the persistent reduction kernel.
This causes the non-persistent and persistent reduction kernels to have
different argument lists.
Check the documentation around torch._inductor.config.triton.multi_kernel
for how to interpret the force_multi_kernel argument.
"""
def f(x):
x = x.sum(dim=-1, keepdim=True) + x
x = test_operators.realize(x)
x = x.sum(dim=-1, keepdim=True) + x
return x
x = torch.rand(16, 16, device=GPU_TYPE)
ref = f(x)
with config.patch("triton.multi_kernel", force_multi_kernel):
act = torch.compile(f)(x)
self.assertEqual(ref, act)
def test_split_scan(self, force_multi_kernel=1):
def f(x):
x = x.view(-1)
return torch.cumsum(x, 0)
x = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=GPU_TYPE)
expect = f(x)
with config.patch("triton.multi_kernel", force_multi_kernel):
actual = torch.compile(f)(x)
self.assertEqual(expect, actual)
def test_sort_disables_multi_kernel(self, force_multi_kernel=1):
"""
Sort currently requires a persistent kernel, so multi-kernel is not
possible. Make sure this falls back gracefully.
"""
def f(x):
return x.sort(-1).values
x = torch.rand(32, 32, device=GPU_TYPE)
expect = f(x)
with config.patch("triton.multi_kernel", force_multi_kernel):
actual = torch.compile(f)(x)
self.assertEqual(expect, actual)
# Use benchmarking to pick the faster kernel
test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
test_reduction_scratch_buffer, force_multi_kernel=1
)
# Force picking the persistent reduction. This is a good test since the
# persistent reduction uses fewer call arguments than the corresponding
# non-persistent reduction.
test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
)
# Force picking the non-persistent reduction.
test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
if HAS_GPU:
run_tests()