1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
|
"""File invoked through subprocess to actually carry out measurements.
`worker/main.py` is deliberately isolated from the rest of the benchmark
infrastructure. Other parts of the benchmark rely on this file, but
`worker/` has only one Python file and does not import ANYTHING from the rest
of the benchmark suite. The reason that this is important is that we can't
rely on paths to access the other files (namely `core.api`) since a source
command might change the CWD. It also helps keep startup time down by limiting
spurious definition work.
The life of a worker is very simple:
It receives a file containing a `WorkerTimerArgs` telling it what to run,
and writes a `WorkerOutput` result back to the same file.
Because this file only expects to run in a child context, error handling means
plumbing failures up to the caller, not raising in this process.
"""
import argparse
import dataclasses
import io
import os
import pickle
import timeit
import traceback
from typing import Any, Tuple, Union, TYPE_CHECKING
import sys
if TYPE_CHECKING:
# Benchmark utils are only partially strict compliant, so MyPy won't follow
# imports using the public namespace. (Due to an exclusion rule in
# mypy-strict.ini)
from torch.utils.benchmark.utils.timer import Language, Timer
from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import CallgrindStats
else:
from torch.utils.benchmark import CallgrindStats, Language, Timer
# Absolute path to this worker script.
WORKER_PATH = os.path.abspath(__file__)
# =============================================================================
# == Interface ================================================================
# =============================================================================
# While the point of this is mainly to collect instruction counts, we're going
# to have to compile C++ timers anyway (as they're used as a check before
# calling Valgrind), so we may as well grab wall times for reference. They
# are comparatively inexpensive.
# Passed as `min_run_time` to `Timer.blocked_autorange` in `_run`.
MIN_RUN_TIME = 5
# Repeats are inexpensive as long as they are all run in the same process. This
# also lets us filter outliers (e.g. malloc arena reorganization), so we don't
# need a high CALLGRIND_NUMBER to get good data.
# Passed as `number` and `repeats` to `Timer.collect_callgrind` in `_run`.
CALLGRIND_NUMBER = 100
CALLGRIND_REPEATS = 5
@dataclasses.dataclass(frozen=True)
class WorkerTimerArgs:
    """Picklable bundle of the arguments used to construct a `Timer`.

    Instances play two parts:
      1) They are the simple interface for defining a benchmark. (The
         advanced interfaces are core.api.GroupedStmts and
         core.api.GroupedModules.)
      2) They are the message used to control a worker. `Timer` is not
         pickleable, so the main process hands each worker a
         `WorkerTimerArgs` describing what to run instead.
    """

    # Statement to measure; the only field without a default.
    stmt: str
    setup: str = "pass"
    global_setup: str = ""
    num_threads: int = 1
    language: Language = Language.PYTHON
@dataclasses.dataclass(frozen=True)
class WorkerOutput:
    """Measurements a worker hands back after a successful run.

    Only plain values are returned, which reduces communication between the
    main process and workers.
    """

    wall_times: Tuple[float, ...]
    instructions: Tuple[int, ...]
@dataclasses.dataclass(frozen=True)
class WorkerFailure:
    """Record of a failed worker run.

    The failure is carried as the string contents of the Exception rather
    than the Exception object itself, for two reasons:
      1) Depending on the type thrown, the exception may or may not be
         pickleable.
      2) Re-throwing in the main process would lose the true stack trace.
    """

    failure_trace: str
class WorkerUnpickler(pickle.Unpickler):
    """Unpickler that resolves the worker interface types by name."""

    def find_class(self, module: str, name: str) -> Any:
        """Return the class pickle should use for `module.name`.

        The main runner knows a symbol `foo` from this file as
        `worker.main.foo`, while the worker (run as a standalone file) knows
        the very same symbol as `__main__.foo`. To let pickle treat both as
        one symbol, the interface types are matched on `name` alone and
        `module` is ignored for them. Every other lookup is delegated to the
        default `pickle.Unpickler.find_class`.
        """
        # Only the interface dataclasses defined in this file are remapped.
        interface_types = (WorkerTimerArgs, WorkerOutput, WorkerFailure)
        for interface_type in interface_types:
            if interface_type.__name__ == name:
                return interface_type

        return super().find_class(module, name)

    def load_input(self) -> WorkerTimerArgs:
        """Load one object and assert that it is a `WorkerTimerArgs`."""
        loaded = self.load()
        assert isinstance(loaded, WorkerTimerArgs)
        return loaded

    def load_output(self) -> Union[WorkerTimerArgs, WorkerOutput, WorkerFailure]:
        """Load one object and assert that it is one of the interface types.

        Convenience method for type safe loading.
        """
        loaded = self.load()
        accepted_types = (WorkerTimerArgs, WorkerOutput, WorkerFailure)
        assert isinstance(loaded, accepted_types)
        return loaded
# =============================================================================
# == Execution ================================================================
# =============================================================================
def _run(timer_args: WorkerTimerArgs) -> WorkerOutput:
    """Build a `Timer` from `timer_args` and collect both kinds of measurement.

    Wall times come from `Timer.blocked_autorange` and instruction counts from
    `Timer.collect_callgrind`; both are returned as plain tuples inside a
    `WorkerOutput`.
    """
    timer = Timer(
        stmt=timer_args.stmt,
        # An empty setup string falls back to a no-op statement.
        setup=timer_args.setup or "pass",
        global_setup=timer_args.global_setup,

        # Prevent NotImplementedError on GPU builds and C++ snippets.
        timer=timeit.default_timer,
        num_threads=timer_args.num_threads,
        language=timer_args.language,
    )

    measurement = timer.blocked_autorange(min_run_time=MIN_RUN_TIME)
    stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind(
        number=CALLGRIND_NUMBER,
        collect_baseline=False,
        repeats=CALLGRIND_REPEATS,
        retain_out_file=False,
    )

    wall_times = tuple(measurement.times)
    instructions = tuple(stat.counts(denoise=True) for stat in stats)
    return WorkerOutput(wall_times=wall_times, instructions=instructions)
def main(communication_file: str) -> None:
    """Run the benchmark described in `communication_file` and write back the result.

    The file is read as a pickled `WorkerTimerArgs`, the measurement is run,
    and the same file is then overwritten with a pickled `WorkerOutput` on
    success or `WorkerFailure` (carrying the formatted traceback) on any
    failure. Failures are reported through the file rather than raised.

    Args:
        communication_file: Path of the file used to exchange pickled
            messages with the parent process.

    Raises:
        SystemExit: If a KeyboardInterrupt arrives while loading or running.
    """
    result: Union[WorkerOutput, WorkerFailure]
    try:
        with open(communication_file, "rb") as f:
            # NOTE(review): unpickling can execute arbitrary code. This
            # assumes the file is only ever written by the parent runner.
            # `load_input` already asserts the payload type.
            timer_args: WorkerTimerArgs = WorkerUnpickler(f).load_input()
        result = _run(timer_args)

    except KeyboardInterrupt:
        # Runner process sent SIGINT.
        sys.exit()

    except BaseException:
        # Deliberately broad: every failure is plumbed up to the caller as a
        # string instead of being raised in this process.
        result = WorkerFailure(failure_trace=traceback.format_exc())

    # A bare filename has an empty directory component, and
    # `os.path.exists("")` is always False. Fall back to the current
    # directory so such a path is not mistaken for an orphaned worker.
    parent_dir = os.path.dirname(communication_file) or os.curdir
    if not os.path.exists(parent_dir):
        # This worker is an orphan, and the parent has already cleaned up the
        # working directory. In that case we can simply exit.
        print(f"Orphaned worker {os.getpid()} exiting.")
        return

    with open(communication_file, "wb") as f:
        pickle.dump(result, f)
if __name__ == '__main__':
    # Script entry point: read the communication file path from the command
    # line and hand it to `main`.
    parser = argparse.ArgumentParser()
    # `required=True` turns a missing flag into a clear usage error instead of
    # passing `None` to `main`, which would end in an opaque TypeError.
    parser.add_argument('--communication_file', type=str, required=True)
    communication_file = parser.parse_args().communication_file
    main(communication_file)
|