1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
|
import atexit
import shutil
import re
import textwrap
from typing import List, Optional, Tuple
from torch.utils.benchmark.utils.common import _make_temp_dir
from core.api import GroupedBenchmark, TimerArgs
from core.types import Definition, FlatIntermediateDefinition, Label
_TEMPDIR: Optional[str] = None
def get_temp_dir() -> str:
    """Return the shared scratch directory, creating it on first use.

    The path is cached in the module-level `_TEMPDIR`, so every call after
    the first returns the same string without touching the filesystem.

    On the first call the directory is obtained from `_make_temp_dir`
    (prefix "instruction_count_microbenchmarks", `gc_dev_shm=True`) and
    `shutil.rmtree` is registered with `atexit` so the directory is removed
    when the interpreter exits.
    """
    global _TEMPDIR

    # Fast path: the directory was already created by an earlier call.
    if _TEMPDIR is not None:
        return _TEMPDIR

    _TEMPDIR = _make_temp_dir(
        prefix="instruction_count_microbenchmarks",
        gc_dev_shm=True,
    )
    # Clean up exactly once, at interpreter shutdown.
    atexit.register(shutil.rmtree, path=_TEMPDIR)
    return _TEMPDIR
def _flatten(
    key_prefix: Label,
    sub_schema: Definition,
    result: FlatIntermediateDefinition
) -> None:
    """Recursively copy the leaves of `sub_schema` into `result`.

    Each key of `sub_schema` contributes a suffix to the label:
      - `None` contributes nothing (empty tuple),
      - a tuple of strings is used as-is,
      - a single string becomes a one-element tuple.

    The label of an entry is `key_prefix` plus that suffix. Values that are
    `TimerArgs` or `GroupedBenchmark` instances are stored in `result` under
    their label; any other value must be a dict and is descended into with
    the label as the new prefix.

    `result` is mutated in place. An `AssertionError` is raised for a key or
    value of an unexpected type, or when two leaves map to the same label.
    """
    for raw_key, node in sub_schema.items():
        # Normalize the key into a tuple-of-str suffix.
        if raw_key is None:
            suffix: Label = ()
        elif isinstance(raw_key, tuple):
            assert all(isinstance(part, str) for part in raw_key)
            suffix = raw_key
        else:
            assert isinstance(raw_key, str)
            suffix = (raw_key,)

        full_key: Label = key_prefix + suffix

        # Interior node: keep descending.
        if not isinstance(node, (TimerArgs, GroupedBenchmark)):
            assert isinstance(node, dict)
            _flatten(key_prefix=full_key, sub_schema=node, result=result)
            continue

        # Leaf node: record it, refusing to overwrite an earlier entry.
        assert full_key not in result, f"duplicate key: {full_key}"
        result[full_key] = node
def flatten(schema: Definition) -> FlatIntermediateDefinition:
    """Convert a nested definition into a flat one.

    See types.py for an explanation of nested vs. flat definitions.

    The work is delegated to `_flatten`, starting from an empty label prefix.
    The returned mapping is then re-checked: every key must be a tuple of
    strings and every value a `TimerArgs` or `GroupedBenchmark`; otherwise
    an `AssertionError` is raised.
    """
    flat: FlatIntermediateDefinition = {}
    _flatten(key_prefix=(), sub_schema=schema, result=flat)

    # Ensure that we produced a valid flat definition.
    for label, leaf in flat.items():
        assert isinstance(label, tuple)
        assert all(isinstance(part, str) for part in label)
        assert isinstance(leaf, (TimerArgs, GroupedBenchmark))

    return flat
def parse_stmts(stmts: str) -> Tuple[str, str]:
    """Helper function for side-by-side Python and C++ stmts.

    For more complex statements, it can be useful to see Python and C++ code
    side by side. To this end, we provide an **extremely restricted** way
    to define Python and C++ code side-by-side. The schema should be mostly
    self explanatory, with the following non-obvious caveats:
      - Width for the left (Python) column MUST be 40 characters.
      - The column separator is " | ", not "|". Whitespace matters.

    The input is dedented and stripped, then read as:
      line 0:  the column header ("Python", padding, "| C++"),
      line 1:  a separator row of 40 dashes, " | ", and 40 dashes,
      line 2+: code rows, each a 40 character Python cell followed by " |"
               and, optionally, a space and the C++ cell.

    Args:
        stmts: The side-by-side table as a single (possibly indented) string.

    Returns:
        A `(python_source, cpp_source)` tuple. Each element is the
        corresponding column's cells joined with newlines. Python cells keep
        their padding to 40 characters; a missing C++ cell becomes "".

    Raises:
        ValueError: If the header row or any code row does not match the
            expected layout.
        AssertionError: If there are fewer than three lines, the separator
            row is malformed, or a code row fails the round trip check.
    """
    stmts = textwrap.dedent(stmts).strip()
    lines: List[str] = stmts.splitlines(keepends=False)
    assert len(lines) >= 3, f"Invalid string:\n{stmts}"

    column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
    # The pipe must be escaped: a bare `|` is regex alternation, which would
    # accept any row that merely starts with 40 dashes and a space, or ends
    # with a space and 40 dashes. Trailing whitespace is tolerated, matching
    # the header pattern.
    separation_pattern = r"^[-]{40} \| [-]{40}\s*$"
    code_pattern = r"^(.{40}) \|($| (.*)$)"

    if re.search(column_header_pattern, lines[0]) is None:
        raise ValueError(
            f"Column header `{lines[0]}` "
            f"does not match pattern `{column_header_pattern}`")

    assert re.search(separation_pattern, lines[1]), (
        f"Separator `{lines[1]}` "
        f"does not match pattern `{separation_pattern}`")

    py_lines: List[str] = []
    cpp_lines: List[str] = []
    for line in lines[2:]:
        line_match = re.search(code_pattern, line)
        if line_match is None:
            raise ValueError(f"Invalid line `{line}`")

        py_cell = line_match.group(1)
        # Group 3 is None when the row ends right after the " |".
        cpp_cell = line_match.group(3) or ""
        py_lines.append(py_cell)
        cpp_lines.append(cpp_cell)

        # Make sure we can round trip for correctness.
        line_from_stmts = f"{py_cell:<40} | {cpp_cell:<40}".rstrip()
        assert line_from_stmts == line.rstrip(), f"Failed to round trip `{line}`"

    return "\n".join(py_lines), "\n".join(cpp_lines)
|