1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
import argparse
import random
import time
from abc import abstractmethod
from typing import Any
from tqdm import tqdm # type: ignore[import-untyped]
import torch
class BenchmarkRunner:
    """
    BenchmarkRunner is a base class for all benchmark runners. It provides an interface to run benchmarks in order to
    collect data with AutoHeuristic.

    Subclasses implement ``create_input`` (build one input tuple) and ``run_benchmark`` (run the benchmark on the
    elements of that tuple). ``run`` parses the command line, configures ``torch._inductor.config`` either to collect
    AutoHeuristic data or to use the learned heuristic, and then drives ``main``.
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name: AutoHeuristic name. ``run`` writes it to ``autoheuristic_collect`` (data collection) or to
                ``autoheuristic_use`` (``--use-heuristic``).
        """
        self.name = name
        self.parser = argparse.ArgumentParser()
        self.add_base_arguments()
        # Holds the parsed command-line namespace once ``run`` has been called; None before that.
        self.args: argparse.Namespace | None = None

    def add_base_arguments(self) -> None:
        """Register the command-line options shared by every benchmark runner on ``self.parser``."""
        self.parser.add_argument(
            "--device",
            type=int,
            default=None,
            help="torch.cuda.set_device(device) will be used",
        )
        self.parser.add_argument(
            "--use-heuristic",
            action="store_true",
            help="Use learned heuristic instead of collecting data.",
        )
        self.parser.add_argument(
            "-o",
            type=str,
            default="ah_data.txt",
            help="Path to file where AutoHeuristic will log results.",
        )
        self.parser.add_argument(
            "--num-samples",
            type=int,
            default=1000,
            help="Number of samples to collect.",
        )
        self.parser.add_argument(
            "--num-reps",
            type=int,
            default=3,
            help="Number of measurements to collect for each input.",
        )

    def run(self) -> None:
        """
        Parse the command line, configure AutoHeuristic, and run ``main``.

        With ``--use-heuristic`` the heuristic named ``self.name`` is used and collection is disabled. Otherwise data
        for ``self.name`` is collected. In both cases ``autoheuristic_log_path`` is set from ``-o``.
        """
        # Parse first so that --help or an invalid argument exits before any global torch state is modified.
        args = self.parser.parse_args()
        # Make the parsed options available to subclasses (e.g. inside create_input / run_benchmark).
        self.args = args

        torch.set_default_device("cuda")
        if args.use_heuristic:
            torch._inductor.config.autoheuristic_use = self.name
            torch._inductor.config.autoheuristic_collect = ""
        else:
            torch._inductor.config.autoheuristic_use = ""
            torch._inductor.config.autoheuristic_collect = self.name
        torch._inductor.config.autoheuristic_log_path = args.o
        if args.device is not None:
            torch.cuda.set_device(args.device)
        # Time-based seed, so each invocation draws a different set of inputs.
        random.seed(time.time())
        self.main(args.num_samples, args.num_reps)

    # NOTE: this class does not derive from abc.ABC, so @abstractmethod is not enforced when a subclass is
    # instantiated. The stubs therefore raise, so that a missing override fails loudly instead of silently doing
    # nothing.
    @abstractmethod
    def run_benchmark(self, *args: Any) -> None:
        """Run one benchmark measurement on the unpacked elements of a tuple returned by ``create_input``."""
        raise NotImplementedError(f"{type(self).__name__} must implement run_benchmark()")

    @abstractmethod
    def create_input(self) -> tuple[Any, ...]:
        """Create one input; its elements are passed positionally to ``run_benchmark``."""
        raise NotImplementedError(f"{type(self).__name__} must implement create_input()")

    def main(self, num_samples: int, num_reps: int) -> None:
        """
        Create ``num_samples`` inputs and call ``run_benchmark`` ``num_reps`` times on each one.

        Args:
            num_samples: Number of distinct inputs to create (one ``create_input`` call each).
            num_reps: Number of ``run_benchmark`` calls per input.
        """
        for _ in tqdm(range(num_samples)):
            inputs = self.create_input()
            for _ in range(num_reps):
                self.run_benchmark(*inputs)
|