1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396
|
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import same
from torch._inductor import ir, scheduler
from torch._inductor.comm_analysis import (
baseLat,
hwLat,
llMaxBws,
NCCL_ALGO,
NCCL_HW,
NCCL_PROTO,
NVIDIA_GPU_TYPE,
)
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
at_least_x_gpu,
DynamoDistributedMultiProcTestCase,
requires_nccl,
)
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
def get_snode_runtime_for_reorder_compute_test(snode):
# NOTE: custom cost model to show that the compute reordering algorithm is working
# Collective kernels
if isinstance(snode.node, ir._CollectiveKernel):
return 100
elif isinstance(snode.node, ir._WaitKernel):
return 0
# High-arithmetic-intensity compute kernels
elif isinstance(snode.node, ir.ExternKernel):
return 5
# All other kernels
return 1
def create_grouped_node_for_allreduce_and_its_deps(snodes):
    """Custom pre-fusion pass that groups the lone all_reduce_ with its producer.

    Finds the single scheduler node wrapping an in-place
    ``_c10d_functional.all_reduce_`` collective. Then looks up the single
    scheduler node that produces its input, and wraps the pair in a
    ``scheduler.GroupedSchedulerNode``.

    Returns:
        A new list whose first element is the grouped node, followed by every
        node of ``snodes`` not absorbed into the group, in original order.

    Raises:
        AssertionError: if there is not exactly one such all_reduce_, or it
            does not have exactly one input.
        KeyError: if the all_reduce_ input is not produced by a node in
            ``snodes``.
    """
    snode_by_name = {s.node.name: s for s in snodes}

    candidates = [
        s
        for s in snodes
        if isinstance(s.node, ir._CollectiveKernel)
        and s.node.op_overload == torch.ops._c10d_functional.all_reduce_.default
    ]
    assert len(candidates) == 1
    ar_snode = candidates[0]

    producers = [snode_by_name[buf.name] for buf in ar_snode.node.inputs]
    assert len(producers) == 1
    producer = producers[0]

    grouped = scheduler.GroupedSchedulerNode.create([producer, ar_snode])
    # The group leads the schedule; members of the group are dropped from the tail.
    remaining = [s for s in snodes if s not in grouped.snodes]
    return [grouped, *remaining]
@requires_nccl()
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Run correctness checks in the multi-proc runner; mark each test with the
    minimum number of GPUs it needs to run under.
    """

    def get_world_trs(self):
        """Return ``tag``/``ranks``/``group_size`` kwargs covering every rank."""
        return {
            "tag": "",
            "ranks": list(range(self.world_size)),
            "group_size": self.world_size,
        }

    @property
    def world_size(self) -> int:
        # hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
        # works around issue with skipif<2 and workers with unpredictable #s gpu
        return 2

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
        ],
    )
    def test_sink_waits(self):
        def func(a):
            ar = _functional_collectives.all_reduce(a, "sum", "0")
            b = torch.matmul(a, a)
            return torch.matmul(ar, b)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the wait_tensor is sunk below the 1st matmul but
            # above the 2nd matmul.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_locality", False)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "raise_comms",
        ],
    )
    def test_raise_comms(self):
        def func(a):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            return torch.matmul(d, e)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs)
            # Verify that the all_reduce_ has been raised above the 2nd matmul
            # but below the 1st matmul. Note that the all_reduce_ directly
            # writes to the output buffer of the 1st matmul, which is an input
            # to the first relu. Therefore, the all_reduce_ should be scheduled
            # after the first relu.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs)
            correct = func(inputs)
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "sink_waits",
            "raise_comms",
        ],
    )
    def test_sink_waits_raise_comms(self):
        def func(a, *, tag, ranks, group_size):
            b = torch.matmul(a, a)
            c = torch.relu(b)
            d = torch.matmul(c, c)
            e = _functional_collectives.all_reduce(b, "sum", "0")
            f = torch.relu(d)
            g = torch.matmul(f, f)
            return torch.mm(e, g)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Things to verify:
            # - The clone prologue of the all_reduce_ should not be fused with
            #   any relus.
            # - The all_reduce_ and its prologue should be raised above the 2nd
            #   matmul but below the 1st matmul.
            # - The wait_tensor should be sunk below the 3rd matmul but above
            #   the 4th matmul.
            (
                FileCheck()
                .check("extern_kernels.mm")
                .check("triton_poi_fused_all_reduce_0")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    def _check_reorder_compute_for_overlap(self):
        """Shared body of the ``reorder_compute_for_overlap`` tests.

        Compiles a graph with two all_reduces, verifies the generated code
        schedules compute around the collectives in the expected order, and
        checks the compiled result against eager. The calling test is
        responsible for applying the config patches (pass list and, optionally,
        a custom ``estimate_op_runtime``). The leading underscore keeps unittest
        from collecting this as a test.
        """

        def func(a, *, tag, ranks, group_size):
            ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
            g = torch.matmul(a, a)
            c = torch.relu(a)
            d = torch.matmul(c, c)
            f = d * c * ar
            fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
            e = torch.matmul(d + ar + fr, g)
            return (e,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # NOTE: after scheduling the first all_reduce:
            # 1. we first schedule the ops (c and d) that ARE required for second all_reduce but DO NOT depend on first all_reduce.
            # 2. then, we schedule the ops (g) that ARE NOT required for second all_reduce and DO NOT depend on first all_reduce.
            # 3. then, we schedule the ops (f) that ARE required for second all_reduce and DO depend on first all_reduce.
            # and then, we schedule the second all_reduce. And then schedule all ops that depend on second all_reduce.
            (
                FileCheck()
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("triton_poi_fused_relu")
                .check("extern_kernels.mm")
                .check("extern_kernels.mm")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_all_reduce_mul")
                .check("torch.ops._c10d_functional.all_reduce_.default")
                .check("torch.ops._c10d_functional.wait_tensor.default")
                .check("triton_poi_fused_add")
                .check("extern_kernels.mm")
                .run(code)
            )
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    def test_reorder_compute_for_overlap(self):
        self._check_reorder_compute_for_overlap()

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "allow_buffer_reuse", True)
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(torch._inductor.config, "reorder_for_compute_comm_overlap", True)
    @patch.object(
        torch._inductor.config,
        "reorder_for_compute_comm_overlap_passes",
        [
            "reorder_compute_for_overlap",
        ],
    )
    @patch.object(
        torch._inductor.config,
        "estimate_op_runtime",
        get_snode_runtime_for_reorder_compute_test,
    )
    def test_reorder_compute_for_overlap_custom_runtime_estimation(self):
        # Same graph and expected schedule as test_reorder_compute_for_overlap,
        # but driven by the custom cost model patched in above.
        self._check_reorder_compute_for_overlap()

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @skipIfRocm
    # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
    @patch.object(torch._inductor.config, "compile_threads", 1)
    @patch.object(
        torch._inductor.config,
        "_pre_fusion_custom_pass",
        create_grouped_node_for_allreduce_and_its_deps,
    )
    def test_grouped_scheduler_node(self):
        def func(a, *, tag, ranks, group_size):
            add = a + a
            div = add / a
            ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
            # Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
            # but here in this unit test, we intentionally put `add`, `div` and `ar` computation
            # into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
            mul = a * a
            mm = torch.matmul(mul, ar)
            return (mm,)

        with _dynamo_dist_per_rank_init(
            self.rank, self.world_size, fake_pg=not at_least_x_gpu(2)
        ):
            inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank
            compiled = torch.compile(func)
            code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
            # Expectations:
            # 1. `add = a + a` and `div = add / a` are still fused, which means fusion
            #    still happens among nodes within a GroupedSchedulerNode.
            # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
            #    GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
            FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check(
                "_c10d_functional.all_reduce_."
            ).check("triton_poi_fused_mul_1.").run(code)
            out = compiled(inputs, **self.get_world_trs())
            correct = func(inputs, **self.get_world_trs())
            self.assertTrue(same(out, correct))

    def test_nccl_heuristics(self):
        # The NCCL latency/bandwidth tables are indexed by algorithm, protocol,
        # hardware and GPU type; their dimensions must agree with those enums.
        # unittest assertions (rather than bare `assert`) survive `python -O`
        # and report the mismatching lengths on failure.
        self.assertEqual(len(baseLat), len(NCCL_ALGO))
        for row in baseLat:
            self.assertEqual(len(row), len(NCCL_PROTO))
        self.assertEqual(len(hwLat), len(NCCL_HW))
        for per_hw in hwLat:
            self.assertEqual(len(per_hw), len(NCCL_ALGO))
            for per_algo in per_hw:
                self.assertEqual(len(per_algo), len(NCCL_PROTO))
        self.assertEqual(len(llMaxBws), len(NVIDIA_GPU_TYPE))
if __name__ == "__main__":
    # When executed directly as a script, hand off to the dynamo test-case runner.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()
|