1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
import argparse
import functools
import traceback
from typing import Callable, List, Optional, Tuple
from torch.utils.jit.log_extract import (
extract_ir,
load_graph_and_inputs,
run_baseline_no_fusion,
run_nnc,
run_nvfuser,
)
"""
Usage:
1. Run your script and pipe into a log file
PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
2. Run log_extract:
log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static
You can also print the list of extracted IR graphs:
log_extract.py log.txt --output
Passing in --graphs 0 2 will only run graphs 0 and 2
"""
def test_runners(
    graphs: List[str],
    runners: List[Tuple[str, Callable]],
    graph_set: Optional[List[int]],
):
    """Run every selected graph through each runner and print the timings.

    Args:
        graphs: IR strings, one per extracted graph; the position in the list
            is the graph index shown in the output.
        runners: ``(display_name, fn)`` pairs. Each ``fn`` is called as
            ``fn(ir, inputs)`` and its return value is printed as a time in ms.
        graph_set: Indices of the graphs to run. ``None`` or an empty list
            means every graph is run.

    For each graph, every runner after the first successful one is also
    reported as a percentage improvement over the previous successful runner.
    A ``RuntimeError`` raised by a runner is printed with its traceback and
    the remaining runners are still tried.
    """
    # Build the lookup once; membership is then O(1) per graph.
    selected = set(graph_set) if graph_set else None

    for i, ir in enumerate(graphs):
        # Filter first so that excluded graphs are never loaded.
        if selected is not None and i not in selected:
            continue

        _, inputs = load_graph_and_inputs(ir)
        print(f"Running Graph {i}")

        prev_result = None
        prev_runner_name = None
        for runner_name, runner_fn in runners:
            # Only the runner call itself is expected to fail.
            try:
                result = runner_fn(ir, inputs)
            except RuntimeError:
                print(f" Graph {i} failed for {runner_name} :", traceback.format_exc())
                continue

            # Both timings must be non-zero for the ratio to be meaningful;
            # a zero result would otherwise raise ZeroDivisionError.
            if prev_result and result:
                improvement = (prev_result / result - 1) * 100
                print(
                    f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
                )
            else:
                print(f"{runner_name} : {result:.6f} ms")

            prev_result = result
            prev_runner_name = runner_name
def _add_toggle(
    parser: argparse.ArgumentParser, flag: str, on_help: str, off_help: str
) -> None:
    """Register a ``--<flag>`` / ``--no-<flag>`` switch pair that defaults to off."""
    dest = flag.replace("-", "_")
    parser.add_argument(f"--{flag}", dest=dest, action="store_true", help=on_help)
    parser.add_argument(
        f"--no-{flag}", dest=dest, action="store_false", help=off_help
    )
    parser.set_defaults(**{dest: False})


def run(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: extract IR from a log, then benchmark and/or print it.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    The positional ``filename`` is passed to ``extract_ir``. Each enabled
    benchmark flag adds one runner, in the fixed order baseline, NNC dynamic,
    NNC static, NVFuser. ``--graphs`` restricts both benchmarking and
    ``--output`` to the listed graph indices.
    """
    parser = argparse.ArgumentParser(
        description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
    )
    parser.add_argument("filename", help="Filename of log file")
    _add_toggle(parser, "nvfuser", "benchmark nvfuser", "DON'T benchmark nvfuser")
    _add_toggle(
        parser, "nnc-static", "benchmark nnc static", "DON'T benchmark nnc static"
    )
    _add_toggle(
        parser,
        "nnc-dynamic",
        "nnc with dynamic shapes",
        "DON'T benchmark nnc with dynamic shapes",
    )
    _add_toggle(parser, "baseline", "benchmark baseline", "DON'T benchmark baseline")
    _add_toggle(parser, "output", "Output graph IR", "DON'T output graph IR")
    parser.add_argument(
        "--graphs", nargs="+", type=int, help="Run only specified graph indices"
    )
    args = parser.parse_args(argv)

    graphs = extract_ir(args.filename)
    # Normalise "flag absent" and "empty list" to None, meaning all graphs.
    graph_set = args.graphs or None

    # The order here is the order in which runners are timed and compared.
    options = []
    if args.baseline:
        options.append(("Baseline no fusion", run_baseline_no_fusion))
    if args.nnc_dynamic:
        options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
    if args.nnc_static:
        options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
    if args.nvfuser:
        options.append(("NVFuser", run_nvfuser))

    # With no runner selected there is nothing to time; skipping the loop
    # avoids loading every graph and keeps --output free of progress lines.
    if options:
        test_runners(graphs, options, graph_set)

    if args.output:
        # Emit the IR as a list of triple-quoted string literals.
        quoted = [
            '"""' + ir + '"""'
            for i, ir in enumerate(graphs)
            if not graph_set or i in graph_set
        ]
        print("[" + ", ".join(quoted) + "]")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run()
|