1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
|
# Owner(s): ["module: inductor"]
import copy
import os
import random
import torch
from torch import nn
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
# When true (the default), verify_accuracy wraps models in DDP to introduce
# graph breaks; set USE_DDP_WRAPPER to any value other than "1" to have the
# models insert manual graph breaks instead.
USE_DDP_WRAPPER = os.getenv("USE_DDP_WRAPPER", "1") == "1"
class Model2Conv(nn.Module):
    """Two bias-free, stride-2, 3x3 convolutions (3 -> dim -> dim channels).

    Args:
        dim: number of output channels of both convolutions.
        manual_graph_break: if True, ``forward`` calls
            ``torch._dynamo.graph_break()`` between the two convolutions.
    """

    def __init__(self, dim=512, manual_graph_break=False):
        super().__init__()
        # conv1 must be constructed before conv2 so parameter initialization
        # consumes the RNG in the same order.
        self.conv1 = nn.Conv2d(3, dim, kernel_size=3, stride=2, bias=False)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=2, bias=False)
        self.manual_graph_break = manual_graph_break

    def forward(self, x):
        hidden = self.conv1(x)
        if self.manual_graph_break:
            # Force a Dynamo graph break between the two convolutions.
            torch._dynamo.graph_break()
        return self.conv2(hidden)

    def get_example_inputs(self):
        """Return a 1-tuple holding a uniform-random (2, 3, 16, 16) tensor."""
        sample = torch.rand(2, 3, 16, 16)
        return (sample,)
@skipIfXpu(msg="ccl doesn't currently work on the XPU stack")
class TestLayoutOptim(TestCase):
    """Accuracy and layout tests for Inductor's layout optimization.

    A world-size-1 process group is created for the whole class so that
    tests can wrap models in ``DistributedDataParallel``.
    """

    @classmethod
    def setUpClass(cls):
        """Initialize a single-rank process group on a random localhost port.

        Raises:
            unittest.SkipTest: if ``GPU_TYPE`` has no known distributed backend.
            RuntimeError: if every initialization attempt fails.
        """
        import unittest

        import torch.distributed as dist

        # Resolve the backend before touching any state, so an unsupported
        # GPU_TYPE skips the class cleanly instead of leaving `backend` unbound.
        backend = {"cuda": "nccl", "xpu": "ccl"}.get(GPU_TYPE)
        if backend is None:
            raise unittest.SkipTest(f"no distributed backend for {GPU_TYPE!r}")

        super().setUpClass()

        # Don't use a fixed port, so repeated/stress runs don't collide on it.
        tot_retry = 5
        for retry_no in range(tot_retry):
            try:
                port = random.randint(10000, 60000)
                dist.init_process_group(
                    backend=backend,
                    init_method=f"tcp://localhost:{port}",
                    world_size=1,
                    rank=0,
                )
                break
            except RuntimeError:
                # e.g. the randomly chosen port may already be in use; retry
                # with a new port unless this was the last attempt.
                if retry_no == tot_retry - 1:
                    raise

    @classmethod
    def tearDownClass(cls):
        """Destroy the process group created by ``setUpClass``."""
        import torch.distributed as dist

        try:
            if dist.is_initialized():
                dist.destroy_process_group()
        finally:
            super().tearDownClass()

    def verify_accuracy(
        self, model_class, use_ddp_wrapper=USE_DDP_WRAPPER, is_train=False
    ):
        """Check a compiled model against eager, using an fp64 reference.

        Args:
            model_class: module class constructible as
                ``model_class(manual_graph_break=...)`` that exposes
                ``get_example_inputs()``.
            use_ddp_wrapper: if True, wrap the module in DDP to introduce graph
                breaks; otherwise the model is asked to break manually.
            is_train: if True, compare parameter gradients after
                ``output.sum().backward()`` instead of forward outputs.
        """

        # There are 2 ways to introduce graph breaks: manually, or via DDP.
        # If DDP is not used, the model inserts the graph break itself.
        def wrap_mod(m):
            if not is_train:
                return m

            def f(*inp):
                x = m(*inp)
                x.sum().backward()

                # Parameters that received no gradient contribute zeros.
                grads = []
                for param in m.parameters():
                    grad = param.grad
                    if grad is None:
                        grad = torch.zeros_like(param)
                    grads.append(grad)
                return grads

            return f

        manual_graph_break = not use_ddp_wrapper
        mod = model_class(manual_graph_break=manual_graph_break).to(GPU_TYPE)
        inp = [t.to(GPU_TYPE) for t in mod.get_example_inputs()]
        expected_out = wrap_mod(mod)(*inp)

        fp64_mod = copy.deepcopy(mod).to(torch.float64)
        fp64_inp = [t.to(torch.float64) for t in copy.deepcopy(inp)]
        fp64_out = wrap_mod(fp64_mod)(*fp64_inp)

        if use_ddp_wrapper:
            from torch.nn.parallel import DistributedDataParallel as DDP

            ddp_wrapped_mod = DDP(mod)
            opt_mod = torch.compile(wrap_mod(ddp_wrapped_mod))
        else:
            opt_mod = torch.compile(wrap_mod(mod))

        # In training mode expected_out holds mod's live .grad tensors. Without
        # resetting them, the compiled backward would accumulate into those very
        # tensors, so expected_out and actual_out would alias each other and
        # the comparison below would be vacuous.
        mod.zero_grad(set_to_none=True)
        actual_out = opt_mod(*inp)

        if not is_train:
            expected_sum = expected_out.sum()
            actual_sum = actual_out.sum()
            print(f"Expected sum {expected_sum}, actual sum {actual_sum}")
        self.assertTrue(same(expected_out, actual_out, fp64_ref=fp64_out))

    def verify_accuracy_for_infer(self, *args, **kwargs):
        """``verify_accuracy`` comparing forward outputs."""
        self.verify_accuracy(*args, **kwargs, is_train=False)

    def verify_accuracy_for_train(self, *args, **kwargs):
        """``verify_accuracy`` comparing parameter gradients."""
        self.verify_accuracy(*args, **kwargs, is_train=True)

    def test_2conv_with_graph_break(self):
        """
        Make sure graph break does not cause any accuracy issue.
        """
        self.verify_accuracy_for_infer(Model2Conv)

    def test_3conv_with_graph_break(self):
        class Model(nn.Module):
            def __init__(
                self, dim=512, patch_size=7, kernel_size=7, manual_graph_break=False
            ):
                super().__init__()
                self.seq = nn.Sequential(
                    nn.Conv2d(
                        3, dim, kernel_size=patch_size, stride=patch_size, bias=False
                    ),
                    nn.Conv2d(
                        dim, dim, kernel_size, groups=dim, padding="same", bias=False
                    ),
                )
                self.conv = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
                self.manual_graph_break = manual_graph_break

            def forward(self, x):
                x = self.seq(x)
                if self.manual_graph_break:
                    torch._dynamo.graph_break()
                x = self.conv(x)
                return x

            def get_example_inputs(self):
                return (torch.randn(2, 3, 16, 16),)

        self.verify_accuracy_for_infer(Model)

    @torch.no_grad()
    def test_keep_output_layout_infer(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(
                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
                )

            def forward(self, x):
                x = self.conv(x)
                return x

            def get_example_inputs(self):
                return (torch.randn(2, 3, 5, 5),)

        mod = Model().to(GPU_TYPE)
        inp = [t.to(GPU_TYPE) for t in mod.get_example_inputs()]
        out = mod(*inp)

        opt_mod = torch.compile(mod)
        opt_out = opt_mod(*inp)

        # We should be able to do view on eager output
        out.view(5, -1)

        # We should be able to do view on the output of the optimized module
        # Note that if the output is channels last, the view op will fail.
        opt_out.view(5, -1)

    def test_keep_output_layout_with_freezing(self):
        with config.patch(
            {
                "freezing": True,
            }
        ):
            self.test_keep_output_layout_infer()

    def test_training_acc(self):
        self.verify_accuracy_for_train(Model2Conv)

    def test_mutate_view(self):
        """
        The GraphModule passed to GraphLowering init method is like:
        https://gist.github.com/shunting314/07228313fd017e2267101ff32edc6d64

        It shows that we will call copy_ to update the argument in the end. This
        guarantees the correctness.
        """

        @torch.compile
        def f(x):
            y = x.view(3, 2)
            y.mul_(2)

        x = torch.ones(2, 3).to(GPU_TYPE)
        f(x)
        self.assertTrue(torch.equal(x, torch.ones(2, 3).to(GPU_TYPE) * 2))

    def test_mutate_base(self):
        """
        The GraphModule passed to GraphLowering init method is like:
        https://gist.github.com/shunting314/fd60fe11d1f844c6db76aba7b06811bc

        It shows that the output of the graph is the mul node which contains
        the update we applied to the base tensor.
        """

        @torch.compile
        def f(x):
            y = x.view(3, 2)
            x.mul_(2)
            return y

        x = torch.ones(2, 3).to(GPU_TYPE)
        y = f(x)
        self.assertTrue(torch.equal(y, torch.ones(3, 2).to(GPU_TYPE) * 2))

    @tf32_off()
    def test_mutate_base_for_conv_output(self):
        class Model(nn.Module):
            def __init__(self, manual_graph_break=False):
                super().__init__()
                self.conv = nn.Conv2d(3, 512, kernel_size=3, stride=2, bias=False)

            def forward(self, x):
                x = self.conv(x)
                y = x.view(-1)
                x.mul_(2)
                return y

            def get_example_inputs(self):
                return (torch.rand(2, 3, 16, 16),)

        self.verify_accuracy_for_infer(Model)

    @tf32_off()
    def test_mutate_view_for_conv_output(self):
        class Model(nn.Module):
            def __init__(self, manual_graph_break=False):
                super().__init__()
                self.conv = nn.Conv2d(3, 512, kernel_size=3, stride=2, bias=False)

            def forward(self, x):
                x = self.conv(x)
                y = x.view(-1)
                y.mul_(2)
                return x

            def get_example_inputs(self):
                return (torch.rand(2, 3, 16, 16),)

        self.verify_accuracy_for_infer(Model)

    def test_dynamic_shape_specialization(self):
        """
        Previously in aot_autograd.py we compare strides of FakeTensor
        with real tensor. That cause dynamic dimensions of the FakeTensor
        being specialized to static shapes. This test protects against that.
        """

        def f(a, b):
            x = a.sin()
            y = b.cos()
            z = x + y
            return z

        for size in [4, 8, 16]:
            a = torch.randn(2, size, requires_grad=True).to(GPU_TYPE)
            b = torch.randn(2, size).to(GPU_TYPE)
            actual = torch.compile(f, dynamic=True)(a, b)
            self.assertTrue(torch.allclose(f(a, b), actual))

            # Trigger the compiling of the backward graph
            actual.sum().backward()

    def test_nll_loss_backward(self):
        """
        Repro for issue https://github.com/pytorch/pytorch/issues/120759

        The CUDA implementation of aten.nll_loss2d_backward.default requires
        the self tensor (whose layout will be used to create grad_input)
        to be contiguous. Layout optimization may change the self tensor's layout
        and cause failure. We fix that by adding layout constraints to the
        fallback of aten.nll_loss2d_backward.default .
        """

        class MyModel(torch.nn.Module):
            def __init__(self, input_dim, num_classes):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, num_classes, 3, 1, padding="same")
                self.out = torch.nn.Linear(input_dim * num_classes, num_classes)

            def forward(self, x: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
                x = self.conv(x)
                b, c, t, f = x.size()
                x = self.out(x.reshape(b, t, c * f))
                logits = x.reshape(x.size(0), x.size(2), x.size(1))
                loss = torch.nn.functional.cross_entropy(logits, targets)
                return loss

        device = GPU_TYPE
        batch_size = 48
        seq_len = 144
        input_dim = 39
        num_classes = 111

        model = MyModel(input_dim, num_classes)
        model.to(device)

        opt_model = torch.compile(model)

        x = torch.ones((batch_size, 1, seq_len, input_dim), device=device)
        targets = torch.randint(
            0, num_classes - 1, (batch_size, seq_len), device=device, dtype=torch.int64
        )

        # Forward and backward must go through the compiled model: the
        # backward is where the nll_loss2d_backward fallback is exercised.
        loss = opt_model(x, targets)
        loss.backward()

        ref = model(x, targets)
        self.assertTrue(same(ref, loss))
if __name__ == "__main__" and HAS_GPU:
    # These tests require a GPU; without one, nothing is run.
    run_tests()
|