1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380
|
#!/usr/bin/env python3
import csv
import itertools
import sys
import time
import warnings
from contextlib import nullcontext
import click
import numpy as np
from operator_inp_utils import OperatorInputsLoader
from tqdm import tqdm
import torch
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._dynamo.testing import same
from torch._inductor.compile_fx import compile_fx
from torch._inductor.decomposition import decompositions
from torch._inductor.lowering import lowerings
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import gen_gm_and_inputs
from torch.utils._pytree import tree_map_only
# Shorthand for the aten operator namespace used throughout this file.
aten = torch.ops.aten
# Toggled by the --profile CLI flag in benchmark(); gates all
# torch.profiler record_function / profiler-context usage.
profile_enabled = False
# Named inductor config presets selectable via --inductor-config; each value
# is a config_patches dict passed to compile_fx.
inductor_config_options = {
    "halide": {"cpu_backend": "halide", "cuda_backend": "halide"},
    "autotune": {
        "max_autotune_pointwise": True,
        "max_autotune": True,
        "max_autotune_gemm": True,
        "coordinate_descent_tuning": True,
    },
}
def maybe_record_function(name):
    """Return a profiler region marker for `name` when profiling is enabled,
    otherwise a no-op context manager."""
    if profile_enabled:
        return torch.profiler.record_function(name)
    return nullcontext()
def compute_speedups(
    operator, models, example_inputs, repeats, accuracy_checking=False, device="cuda"
):
    """Time each callable in `models` on `example_inputs` and return the
    per-model median latency (seconds) across `repeats` passes.

    `models[0]` is treated as the reference (eager) implementation; when
    `accuracy_checking` is set, every other model's output is compared
    against it before timing.
    """
    expected = models[0](*example_inputs)
    if accuracy_checking:
        for model in models[1:]:
            actual = model(*example_inputs)
            # change to assert later
            # NOTE(review): torch._dynamo.testing.same typically *returns*
            # False on mismatch rather than raising AssertionError — confirm
            # this except path can actually fire.
            try:
                same(actual, expected, cos_similarity=True, equal_nan=True)
            except AssertionError as e:
                print(e)
                print(f"Accuracy check failed: {operator}")
                print((expected[0] - actual[0]).abs().max())
    # timings[rep, m] = latency of models[m] on repetition `rep`
    timings = np.zeros((repeats, len(models)), np.float64)
    for rep in range(repeats):
        with maybe_record_function(f"rep_{rep}"):
            # interleave the runs to handle frequency scaling and load changes
            for m, model in enumerate(models):
                with maybe_record_function(f"model_{m}"):
                    if device == "cuda":
                        # untimed call before measurement (presumably warm-up)
                        model(*example_inputs)
                        # benchmarker.benchmark_gpu() clears L2 cache to hide the latency of CPU launch time
                        # along with cuda synchronization
                        timings[rep, m] = benchmarker.benchmark_gpu(
                            lambda: model(*example_inputs)
                        )
                    else:
                        # lazy import: only needed on the CPU path
                        from torch._inductor.utils import timed
                        timings[rep, m] = timed(model, example_inputs)
    # median over repetitions, one entry per model
    return np.median(timings, axis=0)
def strip_overloads(gm):
    """Rewrite every node target in `gm` from a specific OpOverload to its
    OpOverloadPacket, then recompile the module.

    Args:
        gm (fx.GraphModule): graph module to modify in place
    """
    for node in gm.graph.nodes:
        target = node.target
        if isinstance(target, torch._ops.OpOverload):
            node.target = target.overloadpacket
    gm.recompile()
def convert_to_jit(gm, gm_args):
    """Convert `gm` to TorchScript: strip op overloads, prefer scripting,
    and fall back to tracing with `gm_args` if scripting fails."""
    strip_overloads(gm)
    scripted = None
    try:
        scripted = torch.jit.script(gm)
    except Exception:
        pass
    if scripted is not None:
        return scripted
    return torch.jit.trace(gm, gm_args)
def to_channels_last(ten):
    """Convert a 4-D tensor to channels-last memory format; return any
    other tensor unchanged."""
    if ten.ndim == 4:
        return ten.to(memory_format=torch.channels_last)
    return ten
def microbenchmark(
    operator,
    args,
    kwargs,
    accuracy_checking,
    repeats,
    inductor_configs,
    measure_nvfuser,
    device,
):
    """Benchmark a single (operator, inputs) sample.

    Builds an fx graph for the op, compiles it once per inductor config
    (and optionally as nvfuser via jit + cudagraphs), then returns the
    per-backend median timings with eager at index 0.
    """
    gm, gm_args = gen_gm_and_inputs(operator, args, kwargs)
    # Make convolution_backward scriptable so torch.jit.script(gm) can work.
    torch.jit._builtins._register_builtin(
        torch.ops.aten.convolution_backward.default, "aten::convolution_backward"
    )
    models = [gm]  # slot 0 is the eager baseline
    for config_patches in inductor_configs:
        start = time.perf_counter()
        models.append(compile_fx(gm, gm_args, config_patches=config_patches))
        compile_time = time.perf_counter() - start
        if compile_time > 10:
            print(f"slow compile inductor {compile_time:.1f}s {config_patches}")
    if measure_nvfuser:
        jit_gm = convert_to_jit(gm, gm_args)
        models.append(
            cudagraphs_inner(jit_gm, gm_args, copy_outputs=False, copy_inputs=False)
        )
    if accuracy_checking:
        # one pass is enough when we only care about correctness
        repeats = 1
    return compute_speedups(
        operator, models, gm_args, repeats, accuracy_checking, device
    )
# Percentile points reported for every speedup column.
quantiles_thresholds = (0.2, 0.5, 0.8)


def quantiles(timings):
    """Return `timings` summarized at the 20th/50th/80th percentiles
    as a plain Python list."""
    points = np.quantile(timings, quantiles_thresholds)
    return points.tolist()
def skip_operator(operator):
    """Return True when `operator` should be excluded from the benchmark run.

    Skipped categories: ops whose input generation is not implemented yet,
    pure non-compute ops, ops with no inductor lowering/decomposition, and
    convolutions.
    """
    op_str = str(operator)

    # Input generation for these is not yet implemented.
    input_gen_nyi = (
        "aten.gather.default",
        "nll_loss",
        "aten.index",
        "aten.scatter_",
        "masked_fill_.Scalar",
    )
    if any(pattern in op_str for pattern in input_gen_nyi):
        # maybe disable aten.native_layer_norm.default
        # TODO - inputs cannot be randomly initialized, causes cuda failures
        print(f"Skipping {operator}, input generator nyi")
        return True

    # not covered by other non-compute operator heuristics
    if operator == torch.ops.aten._unsafe_view.default:
        print(f"Skipping {operator}, non compute operator")
        return True

    # some of inductor registered to the OpOverload, some registered to OpOverloadPacket
    candidates = [operator]
    if isinstance(operator, torch._ops.OpOverload):
        candidates.append(operator.overloadpacket)

    # TODO - skip benchmarking fallbacks. for some ops we have both lowerings and fallbacks
    # so its not clear just from operator what will be lowered.
    if not any(op in decompositions or op in lowerings for op in candidates):
        print(f"Skipping {operator}, no inductor impl")
        return True

    return "convolution" in op_str
@click.command()
@click.option(
    "--suite",
    help="suite to load inps from: options: timm, huggingface, torchbench",
    default="torchbench",
)
@click.option("--op", help="operator overload to benchmark", default="all")
@click.option("--dtype", help="dtype to benchmark", default="float32")
@click.option("--max-samples", help="max samples per op", default=15)
@click.option("--accuracy-checking", help="check accuracy", default=False)
@click.option(
    "--repeats", help="how many times to repeat for perf measurement", default=3
)
@click.option(
    "--inductor-config",
    multiple=True,
    help="Custom inductor config, options: " + ", ".join(inductor_config_options),
)
@click.option(
    "--measure-nvfuser/--no-measure-nvfuser",
    help="default we only measure inductor",
    default=False,
)
@click.option("--device", help="cpu or cuda", default="cuda")
@click.option("--inp-file", help="use custom input file instead of suite", default=None)
@click.option("--start-idx", help="specify start index of samples", default=0)
@click.option(
    "--channels-last", help="force inputs to channels last", is_flag=True, default=False
)
@click.option("--profile", help="profile the benchmark", is_flag=True, default=False)
def benchmark(
    suite,
    op,
    dtype,
    max_samples,
    accuracy_checking,
    repeats,
    inductor_config,
    measure_nvfuser,
    device,
    inp_file,
    start_idx,
    channels_last,
    profile,
):
    """Benchmark aten operators: eager vs inductor configs (and optionally
    nvfuser).

    Loads recorded operator inputs for the chosen suite (or a custom input
    file), times each backend per input sample via microbenchmark(), and
    prints per-operator speedup quantiles. When --op=all, also writes a CSV
    summary file named operatorbench_<suite>_<dtype>.csv.
    """
    warnings.filterwarnings("ignore", module="torch.jit._check")
    torch.set_float32_matmul_precision("high")
    global profile_enabled
    # Pick the input-sample source: explicit file beats suite selection.
    if inp_file is not None:
        loader = OperatorInputsLoader(inp_file)
    else:
        assert suite in ("timm", "huggingface", "torchbench"), f"got {suite}"
        if suite == "timm":
            loader = OperatorInputsLoader.get_timm_loader()
        elif suite == "huggingface":
            loader = OperatorInputsLoader.get_huggingface_loader()
        else:
            loader = OperatorInputsLoader.get_torchbench_loader()
    assert dtype in ("float16", "float32"), f"got {dtype}"
    # Backend 0 is plain inductor (empty config patch); extra named presets
    # are appended in command-line order.
    inductor_configs = [{}]
    backend_names = ["inductor"]
    for name in inductor_config or ():
        backend_names.append(name)
        inductor_configs.append(inductor_config_options[name])
    if measure_nvfuser:
        backend_names.append("nvfuser")
    # With exactly two backends, also report their head-to-head ratio column.
    compare2 = len(backend_names) == 2
    if compare2:
        a, b = backend_names
        backend_names.append(f"{a}/{b}")
    output_fd = None
    output_csv = None
    if op == "all":
        filename = f"operatorbench_{suite}_{dtype}.csv"
        # newline="" per the csv module docs, so the writer controls line endings
        output_fd = open(filename, "w", newline="")
        output_csv = csv.writer(output_fd)
        output_csv.writerow(
            [
                "operator",
                *[
                    f"{a} {b}"
                    for a, b in itertools.product(
                        backend_names,
                        [f"{x * 100:.0f}th" for x in quantiles_thresholds],
                    )
                ],
                "elapsed",
                *map("{} abs".format, ["eager", *backend_names]),
            ]
        )
    dtype = torch.float16 if dtype == "float16" else torch.float32
    if op == "all":
        ops = loader.get_all_ops()
    else:
        # NOTE: eval of a CLI-provided string (e.g. "torch.ops.aten.mm.default").
        # Acceptable for a local benchmarking tool; never expose to untrusted input.
        ops = [eval(op)]
    # --start-idx shifts the window, so fetch start_idx extra samples.
    max_samples = max_samples + start_idx
    profile_enabled = profile
    for operator in ops:
        if skip_operator(operator):
            continue
        start = time.perf_counter()
        inp_gen = loader.get_inputs_for_operator(operator, dtype=dtype, device=device)
        timings = []
        inputs_list = []
        # Materialize up to max_samples inputs (hard cap guards runaway loaders).
        for _ in range(min(max_samples, 1000000)):
            try:
                inps = next(inp_gen)
                inputs_list.append(inps)
            except StopIteration:
                break
        profiler_context = (
            torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                record_shapes=False,
                profile_memory=False,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    f"./log/operator_{operator}", use_gzip=True
                ),
            )
            if profile_enabled
            else nullcontext()
        )
        with profiler_context:
            for i, inps in enumerate(tqdm(inputs_list[start_idx:], desc=str(operator))):
                if inps is None:
                    break
                args, kwargs = inps
                if channels_last:
                    args, kwargs = tree_map_only(
                        torch.Tensor, to_channels_last, (args, kwargs)
                    )
                try:
                    with maybe_record_function(f"iter_{i}"):
                        # aten, nvfuser, inductor
                        timings.append(
                            microbenchmark(
                                operator,
                                args,
                                kwargs,
                                accuracy_checking,
                                repeats,
                                inductor_configs,
                                measure_nvfuser,
                                device,
                            )
                        )
                except Exception as e:
                    print(f"error {operator} input {i}: {type(e).__name__}: {e}")
                    # comment out this line to avoid blocking other tests
                    # raise e
        if not timings:
            continue
        # rows = samples, columns = backends (column 0 is eager)
        timings = np.stack(timings)
        speedups = [
            quantiles(timings[:, 0] / timings[:, x]) for x in range(1, timings.shape[1])
        ]
        if compare2:
            speedups.append(quantiles(timings[:, 1] / timings[:, 2]))
        assert len(backend_names) == len(speedups)
        row = [f"{operator}"]
        sys.stdout.write(f"{operator}: ")
        for backend, (low, mid, high) in zip(backend_names, speedups):
            sys.stdout.write(f"{backend}={mid:.4f}x ({low:.4f}-{high:.4f}) ")
            row.extend(map("{:.6f}".format, [low, mid, high]))
        elapsed = time.perf_counter() - start
        # was "{elapsed:1f}" (width-1 spec); one decimal place was intended
        row.append(f"{elapsed:.1f}")
        row.extend(map("{:.8f}".format, np.mean(timings, axis=0).tolist()))
        sys.stdout.write(f"took {elapsed:.0f}s\n")
        sys.stdout.flush()
        if output_csv:
            output_csv.writerow(row)
            output_fd.flush()
    if output_fd:
        # report the actual CSV path (the old message printed a placeholder)
        print(f"Wrote {output_fd.name}")
        output_fd.close()
if __name__ == "__main__":
    # Entry point: dispatch to the click-decorated CLI command.
    benchmark()
|