1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
|
# Owner(s): ["module: inductor"]
import contextlib
from unittest import skipIf
import torch
import torch.distributed as dist
from torch._inductor import config, metrics
from torch._inductor.comm_analysis import estimate_nccl_collective_runtime
from torch._inductor.compile_fx import compile_fx, compile_fx_inner
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import is_collective
from torch.testing._internal.common_device_type import expectedFailureXPU
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
# Short handles for the operator namespaces the tests call into: the ATen
# namespace, the legacy functional-collectives namespace, and the newer
# ``_c10d_functional`` namespace.
aten, c10d, _c10d = (
    torch.ops.aten,
    torch.ops.c10d_functional,
    torch.ops._c10d_functional,
)
def compile_but_use_eager(gm, example_inputs):
    """Dynamo backend that runs Inductor's compile pipeline but executes eagerly.

    ``compile_fx`` is driven as usual, but its inner compile step is replaced
    by a hook. The hook calls ``compile_fx_inner`` and ignores its result,
    then returns the original graph module, so the graph still runs eagerly.
    Presumably this is how ``calculate_runtime`` gets
    ``metrics.node_runtimes`` populated without relying on the compiled
    artifact.
    """

    def _compile_then_discard(gm, *args, **kwargs):
        # Invoke the real inner compile only for its side effects; the
        # compiled callable it returns is deliberately dropped.
        compile_fx_inner(gm, *args, **kwargs)
        return gm

    return compile_fx(gm, example_inputs, inner_compile=_compile_then_discard)
def calculate_runtime(f, *args) -> float:
    """
    Compile ``f`` with the ``compile_but_use_eager`` backend, call it on
    ``args``, and return the sum of the second element of every entry in
    ``metrics.node_runtimes`` (0.0 when there are none).

    Assumes all inputs are fp32
    """
    metrics.reset()
    compiled = torch.compile(f, backend=compile_but_use_eager)
    compiled(*args)
    # Dump the raw (node, runtime) pairs; handy when an assertion on the
    # returned total fails.
    print(metrics.node_runtimes)
    return sum((entry[1] for entry in metrics.node_runtimes), 0.0)
# Default device for the tests in this file: it is T()'s default ``device``
# and every test class's ``device`` attribute. Taken from GPU_TYPE
# (torch.testing._internal.inductor_utils).
DEVICE = GPU_TYPE
def T(*size, dtype=torch.float32, device=DEVICE, grad=False) -> torch.Tensor:
    """Return a ``torch.randn`` tensor whose shape is the positional ``size`` args.

    ``dtype`` defaults to fp32 and ``device`` to ``DEVICE``; ``grad`` is
    forwarded as ``requires_grad``.
    """
    factory_kwargs = {"dtype": dtype, "device": device, "requires_grad": grad}
    return torch.randn(size, **factory_kwargs)
class TestCase(InductorTestCase):
    """
    Helper methods to compare runtime estimates against 0. Since these
    estimates are hardware dependent, stronger comparisons may fail depending
    on the host's specs.

    atol/rtol must be provided explicitly with each call, since
    precision/rel_tol overrides are not always utilized.
    """

    device = DEVICE

    def setUp(self):
        super().setUp()
        # These tests check metrics.node_runtimes and we don't save / restore
        # those in the FX graph cache, so the remote FX graph cache is turned
        # off for the duration of each test.
        self._test_snode_stack = contextlib.ExitStack()
        self._test_snode_stack.enter_context(
            config.patch({"fx_graph_remote_cache": False})
        )

    def tearDown(self):
        # Undo the config patch entered in setUp before the base teardown.
        self._test_snode_stack.close()
        super().tearDown()

    def assertZero(self, x: float):
        """Assert ``x`` is a float exactly equal to 0.0 (atol=rtol=0)."""
        # assertIsInstance (unlike a bare ``assert``) still runs under -O.
        self.assertIsInstance(x, float)
        super().assertEqual(x, 0.0, atol=0, rtol=0)

    def assertNotZero(self, x: float):
        """Assert ``x`` is a float that is not exactly 0.0 (atol=rtol=0)."""
        self.assertIsInstance(x, float)
        super().assertNotEqual(x, 0.0, atol=0, rtol=0)
class UnsupportedTests(TestCase):
    """Programs whose summed runtime estimate is expected to be exactly zero:
    an identity function on the default device and on CPU."""

    device = DEVICE

    def test_no_op(self):
        def identity(a):
            return a

        self.assertZero(calculate_runtime(identity, T(10, 10)))

    def test_no_cuda(self):
        def identity(a):
            return a

        cpu_input = torch.randn((10, 10), device="cpu")
        self.assertZero(calculate_runtime(identity, cpu_input))
class ComputeBoundedTests(TestCase):
    """Convolution and matmul programs must receive a nonzero runtime estimate."""

    device = DEVICE

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_conv1d(self):
        def conv(x, w):
            return torch.nn.functional.conv1d(x, w)

        x, w = T(33, 16, 30), T(20, 16, 5)
        self.assertNotZero(calculate_runtime(conv, x, w))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_conv2d(self):
        def conv(x, w):
            return torch.nn.functional.conv2d(x, w, padding=1)

        x, w = T(8, 4, 3, 3), T(1, 4, 5, 5)
        self.assertNotZero(calculate_runtime(conv, x, w))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_conv2d_transpose(self):
        def conv_t(x, w):
            return torch.nn.functional.conv_transpose2d(x, w, padding=1)

        x, w = T(8, 1, 1, 1), T(1, 4, 5, 5)
        self.assertNotZero(calculate_runtime(conv_t, x, w))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_conv3d(self):
        def conv(x, w):
            return torch.nn.functional.conv3d(x, w)

        x, w = T(20, 16, 50, 10, 20), T(33, 16, 3, 3, 3)
        self.assertNotZero(calculate_runtime(conv, x, w))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_mm(self):
        def matmul(a, b):
            return torch.mm(a, b)

        a, b = T(10, 10), T(10, 10)
        self.assertNotZero(calculate_runtime(matmul, a, b))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_addmm(self):
        def fused(a, b, c):
            return torch.addmm(a, b, c)

        operands = [T(10, 10) for _ in range(3)]
        self.assertNotZero(calculate_runtime(fused, *operands))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_bmm(self):
        def batched(a, b):
            return torch.bmm(a, b)

        a, b = T(10, 10, 10), T(10, 10, 10)
        self.assertNotZero(calculate_runtime(batched, a, b))
class MemoryBoundedTests(TestCase):
    """Pointwise and reduction programs must receive a nonzero runtime estimate."""

    device = DEVICE

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_relu(self):
        def relu(a):
            return torch.nn.functional.relu(a)

        self.assertNotZero(calculate_runtime(relu, T(10, 10)))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_horizontal_reduction_pointwise(self):
        def sum_and_cos(a):
            row_sums = a.sum(dim=1)
            cosines = a.cos()
            return row_sums, cosines

        self.assertNotZero(calculate_runtime(sum_and_cos, T(10, 10)))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_pointwise(self):
        def cosine(x):
            return x.cos()

        self.assertNotZero(calculate_runtime(cosine, T(10)))

    # lack of profiler on XPU
    @expectedFailureXPU
    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_dynamic(self):
        def cosine(x):
            return x.cos()

        self.assertNotZero(calculate_runtime(cosine, T(10)))
@skipIf(not dist.is_available(), "requires distributed")
class TestCommAnalysis(TestCase):
    """
    Checks that collective kernels compiled by Inductor get a nonzero runtime
    estimate, using a ``fake``-backend process group (rank 0 of WORLD_SIZE).
    """

    device = DEVICE
    WORLD_SIZE: int = 8
    # Derived from WORLD_SIZE so the two can't drift apart.
    RANKS = list(range(WORLD_SIZE))

    def _verify_runtime_estimation(self, fn, inps):
        """Compile and run ``fn(*inps)`` and check every collective's estimate.

        A ``fake`` process group is initialized around the run and always
        destroyed afterwards. For each collective node recorded in
        ``metrics.node_runtimes``, both a direct
        ``estimate_nccl_collective_runtime`` call and the runtime Inductor
        recorded must be nonzero. The test fails if no collective is found.
        """
        from torch.testing._internal.distributed.fake_pg import FakeStore

        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=self.WORLD_SIZE, store=store
        )
        try:
            metrics.reset()
            torch.compile(fn)(*inps)
            found_collective = False
            for snode, runtime in metrics.node_runtimes:
                if not is_collective(snode.node):
                    continue
                found_collective = True
                # Inductor swallows errors from snode runtime estimations.
                # We call estimate_nccl_collective_runtime in a white-box
                # fashion here so potential issues can be surfaced in tests.
                est = estimate_nccl_collective_runtime(snode.node)
                self.assertNotZero(est)
                # Also make sure estimate_nccl_collective_runtime works
                # correctly in inductor.
                self.assertNotZero(runtime)
            # Make sure a collective kernel is found in graph
            self.assertTrue(found_collective)
        finally:
            dist.destroy_process_group()

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_legacy_all_reduce(self):
        def fn(x):
            r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_legacy_all_reduce_coalesced(self):
        def fn(x):
            rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_legacy_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = c10d.all_gather_into_tensor_coalesced(
                x,
                "",
                self.RANKS,
                self.WORLD_SIZE,
            )
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_all_reduce(self):
        def fn(x):
            r = _c10d.all_reduce(x, "sum", "0")
            return _c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_all_reduce_coalesced(self):
        def fn(x):
            rs = _c10d.all_reduce_coalesced(x, "sum", "0")
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_all_gather_into_tensor(self):
        def fn(x):
            # The non-coalesced op yields one tensor; wait on it directly
            # (as test_all_reduce does) instead of iterating over its rows,
            # which would wait on row slices rather than the collective output.
            r = _c10d.all_gather_into_tensor(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return _c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.all_gather_into_tensor_coalesced(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_reduce_scatter_tensor(self):
        def fn(x):
            # Single-tensor collective: wait on the result itself rather than
            # on slices obtained by iterating over it.
            r = _c10d.reduce_scatter_tensor(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return _c10d.wait_tensor(r)

        inp = T(self.WORLD_SIZE, 10)
        self._verify_runtime_estimation(fn, (inp,))

    # lack of profiler on XPU
    @expectedFailureXPU
    def test_reduce_scatter_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.reduce_scatter_tensor_coalesced(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(self.WORLD_SIZE, 10), T(self.WORLD_SIZE, 15)]
        self._verify_runtime_estimation(fn, (inp,))
if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    # Tests only run when HAS_GPU is true; otherwise the script finishes
    # without running anything.
    if HAS_GPU:
        run_tests(needs="filelock")
|