1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
|
# Owner(s): ["module: unknown"]
import unittest
from dataclasses import dataclass
from typing import Any, Callable, cast, Tuple, Union
import torch
from torch import nn, optim
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)
@dataclass(init=True, repr=True, eq=True)
class ConvArgs:
    """Construction arguments for ``SimpleCNN``.

    Attributes are positional in declaration order, so ``ConvArgs(size, n)``
    sets ``image_size=size`` and ``num_classes=n``.
    """

    # Side length of the square input images (used for both height and width).
    image_size: int
    # Number of output logits produced by the final linear layer.
    num_classes: int
class SimpleCNN(nn.Module):
    """Small convolutional classifier used as a runtime-estimation workload.

    Layout: conv1(5x5)+pool -> conv2(5x5)+pool -> conv3(3x3) -> conv4(3x3)+pool
    -> fc1 -> fc2 -> fc3.  Inputs are expected as
    ``(batch, 3, image_size, image_size)`` and the output is
    ``(batch, num_classes)``.
    """

    def __init__(self, conv_args: "ConvArgs"):
        super().__init__()
        image_size = conv_args.image_size
        num_classes = conv_args.num_classes
        self.image_size = image_size
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3)
        self.fc1_size = self._calculate_fc1_size()
        self.fc1 = nn.Linear(self.fc1_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def _calculate_fc1_size(self) -> int:
        """Return the number of per-sample features entering ``fc1``.

        Mirrors the spatial arithmetic of ``forward`` (unpadded stride-1
        convolutions, 2x2 max-pooling with floor division) and multiplies the
        final spatial area by the channel count of ``conv4``.

        Raises:
            ValueError: if ``image_size`` is too small to leave any spatial
                extent after the conv/pool stack.
        """
        size = self.image_size
        size = (size - 5 + 1) // 2  # conv1 and pool
        size = (size - 5 + 1) // 2  # conv2 and pool
        size = size - 3 + 1  # conv3
        size = (size - 3 + 1) // 2  # conv4 and pool
        if size <= 0:
            # A non-positive size would otherwise be squared below into a
            # bogus (possibly positive) feature count.
            raise ValueError(
                f"image_size={self.image_size} is too small for SimpleCNN: "
                "no spatial extent remains after the conv/pool stack"
            )
        # Use conv4's real channel count so fc1 matches the flattened map.
        return self.conv4.out_channels * size * size

    def forward(self, x):
        """Map ``(batch, 3, H, W)`` images to ``(batch, num_classes)`` logits."""
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = nn.functional.relu(self.conv3(x))
        x = self.pool(nn.functional.relu(self.conv4(x)))
        # Flatten per sample (keeping the batch dimension) so any size
        # mismatch surfaces as an error in fc1 instead of silently
        # regrouping samples across the batch.
        x = x.flatten(1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)
class TestRuntimeEstimator(TestCase):
    """Compare RuntimeEstimator predictions against measured CUDA runtimes.

    Each test times one training step of a real model on the GPU, then
    rebuilds the same model under ``FakeTensorMode`` and asks
    ``RuntimeEstimator`` for estimates in ``operator-level-benchmark`` and
    ``operator-level-cost-model`` modes.
    """

    def _train_step(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        inp: torch.Tensor,
    ):
        """Run one forward/backward/optimizer-step iteration on ``inp``."""
        out = model(inp)
        loss = out.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    def _measure_actual_cuda_time(
        self,
        func: Callable,
        args: Tuple[Any, ...],
    ) -> float:
        """Return the mean time of ``func(*args)`` measured with CUDA events.

        ``func`` is first run a few times untimed as warm-up, then timed over
        several iterations; the result is ``elapsed_time`` divided by the
        iteration count (milliseconds, as reported by ``Event.elapsed_time``).
        """
        warmup_iters, actual_iters = 2, 5
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        for _ in range(warmup_iters):
            func(*args)
        start_event.record()
        for _ in range(actual_iters):
            func(*args)
        end_event.record()
        # elapsed_time is only valid once both recorded events have completed.
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event) / actual_iters

    def _runtime_estimate(
        self,
        estimate_mode: str,
        func: Callable,
        args: Tuple[Any, ...],
    ) -> float:
        """Return RuntimeEstimator's ``total_runtime`` for one ``func(*args)``.

        ``estimate_mode`` is forwarded as ``estimate_mode_type``.
        """
        # Optimizer init step: run once outside the estimator so that
        # optimizer state created on the first step() is not part of the
        # estimated step.
        func(*args)
        runtime_estimator = RuntimeEstimator()
        with runtime_estimator(estimate_mode_type=estimate_mode):
            func(*args)
        return runtime_estimator.total_runtime

    def _init_model_and_args(
        self,
        model_type: str,
        model_args: Union[ConvArgs, ModelArgs],
        bsz: int,
    ) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]:
        """Build a model, its optimizer and a random input on the current CUDA device.

        Args:
            model_type: ``"Transformer"`` (uses ``ModelArgs``) or ``"CNN"``
                (uses ``ConvArgs``).
            model_args: configuration matching ``model_type``.
            bsz: batch size of the generated input.

        Returns:
            ``(model, optimizer, inp)``, usable directly as ``_train_step`` args.

        Raises:
            NotImplementedError: for any other ``model_type``.
        """
        dev = torch.cuda.current_device()
        if model_type == "Transformer":
            model_args = cast(ModelArgs, model_args)
            with torch.device(dev):
                model = Transformer(model_args)
            optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True)
            inp = torch.randint(
                0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev
            )
        elif model_type == "CNN":
            model_args = cast(ConvArgs, model_args)
            with torch.device(dev):
                model = SimpleCNN(model_args)
            optimizer = optim.SGD(model.parameters(), lr=1e-2, foreach=True)
            inp = torch.randn(
                bsz, 3, model_args.image_size, model_args.image_size, device=dev
            )
        else:
            raise NotImplementedError(
                f"Unsupported model_type {model_type!r}: "
                "only Transformer and CNN are supported"
            )
        return (model, optimizer, inp)

    def _compare_estimates(
        self,
        model_type: str,
        model_args: Union[ConvArgs, ModelArgs],
        bsz: int,
    ) -> Tuple[float, float]:
        """Time a real training step and print both estimator predictions.

        Returns:
            ``(benchmark_accuracy, roofline_accuracy)``, each the ratio of the
            measured runtime to the corresponding estimate.
        """
        args = self._init_model_and_args(model_type, model_args, bsz)
        actual_runtime = self._measure_actual_cuda_time(self._train_step, args)
        # Rebuild the model with fake tensors so estimation allocates no
        # real device memory for parameters/activations.
        with FakeTensorMode():
            fake_args = self._init_model_and_args(model_type, model_args, bsz)
            benchmark_estimate = self._runtime_estimate(
                "operator-level-benchmark", self._train_step, fake_args
            )
            roofline_estimate = self._runtime_estimate(
                "operator-level-cost-model", self._train_step, fake_args
            )
        benchmark_accuracy = actual_runtime / benchmark_estimate
        roofline_accuracy = actual_runtime / roofline_estimate
        print(
            f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} "
            f"Accuracy: {benchmark_accuracy}\n"
            f"Actual: {actual_runtime} Roofline Estimate: {roofline_estimate} "
            f"Accuracy: {roofline_accuracy}"
        )
        return benchmark_accuracy, roofline_accuracy

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_transformer_runtime(self):
        """Runs a basic GPT-2 model"""
        vocab_size = 8192
        bsz, seq_len = 8, 1024
        model_args = ModelArgs(
            n_layers=4,
            n_heads=12,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dim=768,
            dropout_p=0.1,
        )
        self._compare_estimates("Transformer", model_args, bsz)
        # Neither accuracy is asserted: measured runtimes are highly variable
        # in CI. Tolerances considered for the actual/estimate ratio (~1.0):
        # benchmark delta=0.2, roofline delta=0.3.

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_conv_model_runtime(self):
        """Runs a simple CNN model"""
        num_classes = 100
        bsz, img_sz = 256, 128
        model_args = ConvArgs(img_sz, num_classes)
        self._compare_estimates("CNN", model_args, bsz)
        # Neither accuracy is asserted: measured runtimes are highly variable
        # in CI. Tolerances considered for the actual/estimate ratio (~1.0):
        # benchmark delta=0.2, roofline delta=0.4.
if __name__ == "__main__":
    # Only when executed as a script: hand control to PyTorch's test runner
    # (run_tests from torch.testing._internal.common_utils).
    run_tests()
|