#!/usr/bin/env python3
"""
A wrapper over the benchmark infrastructure that generates commonly used commands,
parses results, and generates CSV files and graphs.
The script works from a manually written TABLE (see below). More commands can be
added in the future.
One example usage is
-> python benchmarks/runner.py --suites=torchbench --inference
This command will generate the commands for the default compilers (see DEFAULTS
below) for inference, run them and visualize the logs.
If you want to just print the commands, you could use the following command
-> python benchmarks/runner.py --print-run-commands --suites=torchbench --inference
Similarly, if you want to just visualize the already finished logs
-> python benchmarks/runner.py --visualize-logs --suites=torchbench --inference
If you want to test float16
-> python benchmarks/runner.py --suites=torchbench --inference --dtypes=float16
"""
import argparse
import dataclasses
import functools
import glob
import importlib
import io
import itertools
import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from os.path import abspath, exists
from random import randint
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import rcParams
from scipy.stats import gmean
from tabulate import tabulate
import torch
import torch._dynamo
rcParams.update({"figure.autolayout": True})
plt.rc("axes", axisbelow=True)
DEFAULT_OUTPUT_DIR = "benchmark_logs"
log = logging.getLogger(__name__)
TABLE = {
"training": {
"ts_nnc": "--training --speedup-ts ",
"ts_nvfuser": "--training --nvfuser --speedup-dynamo-ts ",
"eager": "--training --backend=eager ",
"aot_eager": "--training --backend=aot_eager ",
"cudagraphs": "--training --backend=cudagraphs ",
"aot_nvfuser": "--training --nvfuser --backend=aot_ts_nvfuser ",
"nvprims_nvfuser": "--training --backend=nvprims_nvfuser ",
"inductor": "--training --inductor ",
"inductor_no_cudagraphs": "--training --inductor --disable-cudagraphs ",
"inductor_max_autotune": "--training --inductor --inductor-compile-mode max-autotune ",
"inductor_max_autotune_no_cudagraphs": (
"--training --inductor --inductor-compile-mode max-autotune-no-cudagraphs --disable-cudagraphs "
),
},
"inference": {
"aot_eager": "--inference --backend=aot_eager ",
"eager": "--inference --backend=eager ",
"ts_nnc": "--inference --speedup-ts ",
"ts_nvfuser": "--inference -n100 --speedup-ts --nvfuser ",
"trt": "--inference -n100 --speedup-trt ",
"ts_nvfuser_cudagraphs": "--inference --backend=cudagraphs_ts ",
"inductor": "--inference -n50 --inductor ",
"inductor_no_cudagraphs": "--inference -n50 --inductor --disable-cudagraphs ",
"inductor_max_autotune": "--inference -n50 --inductor --inductor-compile-mode max-autotune ",
"inductor_max_autotune_no_cudagraphs": (
"--inference -n50 --inductor --inductor-compile-mode max-autotune-no-cudagraphs --disable-cudagraphs "
),
"torchscript-onnx": "--inference -n5 --torchscript-onnx",
"dynamo-onnx": "--inference -n5 --dynamo-onnx",
},
}
INFERENCE_COMPILERS = tuple(TABLE["inference"].keys())
TRAINING_COMPILERS = tuple(TABLE["training"].keys())
DEFAULTS = {
"training": [
"eager",
"aot_eager",
"inductor",
"inductor_no_cudagraphs",
],
"inference": [
"eager",
"aot_eager",
"inductor",
"inductor_no_cudagraphs",
],
"flag_compilers": {
"training": ["inductor", "inductor_no_cudagraphs"],
"inference": ["inductor", "inductor_no_cudagraphs"],
},
"dtypes": [
"float32",
],
"suites": ["torchbench", "huggingface", "timm_models"],
"devices": [
"cuda",
],
"quick": {
"torchbench": '-k "resnet..$"',
"huggingface": "-k Albert",
"timm_models": ' -k "^resnet" -k "^inception"',
},
}
DASHBOARD_DEFAULTS = {
"dashboard_image_uploader": "/fsx/users/anijain/bin/imgur.sh",
"dashboard_archive_path": "/data/home/anijain/cluster/cron_logs",
"dashboard_gh_cli_path": "/data/home/anijain/miniconda/bin/gh",
}
def flag_speedup(x):
    return x < 0.95


def flag_compilation_latency(x):
    return x > 120


def flag_compression_ratio(x):
    return x < 0.9


def flag_accuracy(x):
    return "pass" not in x


FLAG_FNS = {
    "speedup": flag_speedup,
    "compilation_latency": flag_compilation_latency,
    "compression_ratio": flag_compression_ratio,
    "accuracy": flag_accuracy,
}


def percentage(part, whole, decimals=2):
    if whole == 0:
        return 0
    return round(100 * float(part) / float(whole), decimals)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--devices", action="append", help="cpu or cuda")
parser.add_argument("--dtypes", action="append", help="float16/float32/amp")
parser.add_argument("--suites", action="append", help="huggingface/torchbench/timm_models")
parser.add_argument(
"--compilers",
action="append",
help=f"For --inference, options are {INFERENCE_COMPILERS}. For --training, options are {TRAINING_COMPILERS}",
)
parser.add_argument(
"--flag-compilers",
action="append",
help="List of compilers to flag issues. Same format as --compilers.",
)
parser.add_argument(
"--quick", action="store_true", help="Just runs one model. Helps in debugging"
)
parser.add_argument(
"--output-dir",
help="Choose the output directory to save the logs",
default=DEFAULT_OUTPUT_DIR,
)
parser.add_argument(
"--keep-output-dir",
action="store_true",
help="Do not cleanup the output directory before running",
)
# Choose either generation of commands, pretty parsing or e2e runs
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument(
"--print-run-commands",
"--print_run_commands",
action="store_true",
help="Generate the commands and save them to a run_*.sh script",
)
group.add_argument(
"--visualize-logs",
"--visualize_logs",
action="store_true",
help="Pretty print the log files and draw graphs",
)
group.add_argument(
"--run",
action="store_true",
default=True,
help="Generate the commands, run them, and parse the resulting logs",
)
parser.add_argument(
"--log-operator-inputs",
action="store_true",
default=False,
help="Log operator inputs",
)
parser.add_argument(
"--include-slowdowns",
"--include_slowdowns",
action="store_true",
default=False,
help="Include slowdowns in geomean performance speedup report. By default, slowdowns are ignored. "
"This is because one can always use eager if compile is not speeding things up",
)
parser.add_argument(
"--extra-args", default="", help="Append commandline with these args"
)
# Choose either inference or training
group_mode = parser.add_mutually_exclusive_group(required=True)
group_mode.add_argument(
"--inference", action="store_true", help="Only run inference related tasks"
)
group_mode.add_argument(
"--training", action="store_true", help="Only run training related tasks"
)
parser.add_argument(
"--base-sha",
help="commit id for the tested pytorch",
)
parser.add_argument(
"--total-partitions",
type=int,
help="Total number of partitions, to be passed to the actual benchmark script",
)
parser.add_argument(
"--partition-id",
type=int,
help="ID of partition, to be passed to the actual benchmark script",
)
parser.add_argument(
"--update-dashboard",
action="store_true",
default=False,
help="Update the dashboard",
)
parser.add_argument(
"--no-graphs",
action="store_true",
default=False,
help="Do not generate and upload metric graphs",
)
parser.add_argument(
"--no-update-archive",
action="store_true",
default=False,
help="Do not update lookup.csv or the log archive",
)
parser.add_argument(
"--no-gh-comment",
action="store_true",
default=False,
help="Do not write a comment to github",
)
parser.add_argument(
"--no-detect-regressions",
action="store_true",
default=False,
help="Do not compare to previous runs for regressions or metric graphs.",
)
parser.add_argument(
"--update-dashboard-test",
action="store_true",
default=False,
help="Equivalent to passing --no-graphs, --no-update-archive, and --no-gh-comment",
)
parser.add_argument(
"--dashboard-image-uploader",
default=DASHBOARD_DEFAULTS["dashboard_image_uploader"],
help="Image uploader command",
)
parser.add_argument(
"--dashboard-archive-path",
default=DASHBOARD_DEFAULTS["dashboard_archive_path"],
help="Archived directory path",
)
parser.add_argument(
"--archive-name",
help="Directory name under dashboard-archive-path to copy output-dir to. "
"If not provided, a generated name is used.",
)
parser.add_argument(
"--dashboard-gh-cli-path",
default=DASHBOARD_DEFAULTS["dashboard_gh_cli_path"],
help="Github CLI path",
)
parser.add_argument(
"--batch-size",
"--batch_size",
type=int,
default=None,
help="batch size for benchmarking",
)
parser.add_argument(
"--threads",
"-t",
type=int,
default=None,
help="number of threads to use for eager and inductor.",
)
launcher_group = parser.add_argument_group("CPU Launcher Parameters")
launcher_group.add_argument(
"--enable-cpu-launcher",
"--enable_cpu_launcher",
action="store_true",
default=False,
help="Use torch.backends.xeon.run_cpu to get the peak performance on Intel(R) Xeon(R) Scalable Processors.",
)
launcher_group.add_argument(
"--cpu-launcher-args",
"--cpu_launcher_args",
type=str,
default="",
help="Provide the args of torch.backends.xeon.run_cpu. "
"To look up what optional arguments this launcher offers: python -m torch.backends.xeon.run_cpu --help",
)
parser.add_argument(
"--no-cold-start-latency",
action="store_true",
default=False,
help="Do not include --cold-start-latency on inductor benchmarks",
)
parser.add_argument(
"--inductor-compile-mode",
default=None,
help="torch.compile mode argument for inductor runs.",
)
args = parser.parse_args()
return args
def get_mode(args):
    if args.inference:
        return "inference"
    return "training"


def get_skip_tests(suite, device, is_training: bool):
    """
    Generate a space-separated string of "-x <model>" flags that skips models
    with unusual setups.
    """
    skip_tests = set()
    original_dir = abspath(os.getcwd())
    module = importlib.import_module(suite)
    # Restore the working directory in case the import changed it
    os.chdir(original_dir)

    if suite == "torchbench":
        skip_tests.update(module.TorchBenchmarkRunner().skip_models)
        if is_training:
            skip_tests.update(
                module.TorchBenchmarkRunner().skip_not_suitable_for_training_models
            )
        if device == "cpu":
            skip_tests.update(module.TorchBenchmarkRunner().skip_models_for_cpu)
        elif device == "cuda":
            skip_tests.update(module.TorchBenchmarkRunner().skip_models_for_cuda)

    skip_tests = (f"-x {name}" for name in skip_tests)
    skip_str = " ".join(skip_tests)
    return skip_str


def generate_csv_name(args, dtype, suite, device, compiler, testing):
    mode = get_mode(args)
    return f"{compiler}_{suite}_{dtype}_{mode}_{device}_{testing}.csv"
def generate_commands(args, dtypes, suites, devices, compilers, output_dir):
    mode = get_mode(args)
    suites_str = "_".join(suites)
    devices_str = "_".join(devices)
    dtypes_str = "_".join(dtypes)
    compilers_str = "_".join(compilers)
    generated_file = (
        f"run_{mode}_{devices_str}_{dtypes_str}_{suites_str}_{compilers_str}.sh"
    )
    with open(generated_file, "w") as runfile:
        lines = []

        lines.append("#!/bin/bash")
        lines.append("set -x")
        lines.append("# Setup the output directory")
        if not args.keep_output_dir:
            lines.append(f"rm -rf {output_dir}")
        # It's ok if the output directory already exists
        lines.append(f"mkdir -p {output_dir}")
        lines.append("")

        for testing in ["performance", "accuracy"]:
            for iter in itertools.product(suites, devices, dtypes):
                suite, device, dtype = iter
                lines.append(
                    f"# Commands for {suite} for device={device}, dtype={dtype} for {mode} and for {testing} testing"
                )
                info = TABLE[mode]
                for compiler in compilers:
                    base_cmd = info[compiler]
                    output_filename = f"{output_dir}/{generate_csv_name(args, dtype, suite, device, compiler, testing)}"
                    launcher_cmd = "python"
                    if args.enable_cpu_launcher:
                        launcher_cmd = f"python -m torch.backends.xeon.run_cpu {args.cpu_launcher_args}"
                    cmd = f"{launcher_cmd} benchmarks/dynamo/{suite}.py --{testing} --{dtype} -d{device} --output={output_filename}"
                    cmd = f"{cmd} {base_cmd} {args.extra_args} --dashboard"
                    skip_tests_str = get_skip_tests(suite, device, args.training)
                    cmd = f"{cmd} {skip_tests_str}"

                    if args.log_operator_inputs:
                        cmd = f"{cmd} --log-operator-inputs"

                    if args.quick:
                        filters = DEFAULTS["quick"][suite]
                        cmd = f"{cmd} {filters}"

                    if (
                        compiler
                        in (
                            "inductor",
                            "inductor_no_cudagraphs",
                        )
                        and not args.no_cold_start_latency
                    ):
                        cmd = f"{cmd} --cold-start-latency"

                    if args.batch_size is not None:
                        cmd = f"{cmd} --batch-size {args.batch_size}"

                    if args.threads is not None:
                        cmd = f"{cmd} --threads {args.threads}"

                    if args.total_partitions is not None:
                        cmd = f"{cmd} --total-partitions {args.total_partitions}"

                    if args.partition_id is not None:
                        cmd = f"{cmd} --partition-id {args.partition_id}"

                    if args.inductor_compile_mode is not None:
                        cmd = f"{cmd} --inductor-compile-mode {args.inductor_compile_mode}"
                    lines.append(cmd)
                lines.append("")
        runfile.writelines([line + "\n" for line in lines])
    return generated_file
def generate_dropdown_comment(title, body):
    str_io = io.StringIO()
    str_io.write(f"{title}\n")
    str_io.write("<details>\n")
    str_io.write("<summary>see more</summary>\n")
    str_io.write(f"{body}")
    str_io.write("\n")
    str_io.write("</details>\n\n")
    return str_io.getvalue()
def build_summary(args):
    out_io = io.StringIO()

    def print_commit_hash(path, name):
        if args.base_sha is not None:
            if name == "pytorch":
                out_io.write(f"{name} commit: {args.base_sha}\n")
        elif exists(path):
            import git

            repo = git.Repo(path, search_parent_directories=True)
            sha = repo.head.object.hexsha
            date = repo.head.object.committed_datetime
            out_io.write(f"{name} commit: {sha}\n")
            out_io.write(f"{name} commit date: {date}\n")
        else:
            out_io.write(f"{name} Absent\n")

    def env_var(name):
        if name in os.environ:
            out_io.write(f"{name} = {os.environ[name]}\n")
        else:
            out_io.write(f"{name} = {None}\n")

    out_io.write("\n")
    out_io.write("### Run name ###\n")
    out_io.write(get_archive_name(args, args.dtypes[0]))
    out_io.write("\n")

    out_io.write("\n")
    out_io.write("### Commit hashes ###\n")
    print_commit_hash("../pytorch", "pytorch")
    print_commit_hash("../torchbenchmark", "torchbench")

    out_io.write("\n")
    out_io.write("### TorchDynamo config flags ###\n")
    for key in dir(torch._dynamo.config):
        val = getattr(torch._dynamo.config, key)
        if not key.startswith("__") and isinstance(val, bool):
            out_io.write(f"torch._dynamo.config.{key} = {val}\n")

    out_io.write("\n")
    out_io.write("### Torch version ###\n")
    out_io.write(f"torch: {torch.__version__}\n")

    out_io.write("\n")
    out_io.write("### Environment variables ###\n")
    env_var("TORCH_CUDA_ARCH_LIST")
    env_var("CUDA_HOME")
    env_var("USE_LLVM")

    if "cuda" in args.devices:
        out_io.write("\n")
        out_io.write("### GPU details ###\n")
        out_io.write(f"CUDNN VERSION: {torch.backends.cudnn.version()}\n")
        out_io.write(f"Number CUDA Devices: {torch.cuda.device_count()}\n")
        out_io.write(f"Device Name: {torch.cuda.get_device_name(0)}\n")
        out_io.write(
            f"Device Memory [GB]: {torch.cuda.get_device_properties(0).total_memory/1e9}\n"
        )

    title = "## Build Summary"
    comment = generate_dropdown_comment(title, out_io.getvalue())
    with open(f"{output_dir}/gh_build_summary.txt", "w") as gh_fh:
        gh_fh.write(comment)
@functools.lru_cache(None)
def archive_data(archive_name):
    if archive_name is not None:
        prefix_match = re.search(r"\w+(?=_performance)", archive_name)
        if prefix_match is not None:
            prefix = prefix_match.group(0)
        else:
            prefix = ""
        day_match = re.search(r"day_(\d+)_", archive_name)
        if day_match is not None:
            day = day_match.group(1)
        else:
            day = "000"
    else:
        now = datetime.now(tz=timezone(timedelta(hours=-8)))
        day = now.strftime("%j")
        prefix = now.strftime(f"day_{day}_%d_%m_%y")
    return day, prefix


@functools.lru_cache(None)
def default_archive_name(dtype):
    _, prefix = archive_data(None)
    return f"{prefix}_performance_{dtype}_{randint(100, 999)}"


def get_archive_name(args, dtype):
    return (
        default_archive_name(dtype) if args.archive_name is None else args.archive_name
    )


def archive(src_dir, dest_dir_prefix, archive_name, dtype):
    if archive_name is None:
        archive_name = default_archive_name(dtype)
    # Copy the folder to archived location
    dest = os.path.join(dest_dir_prefix, archive_name)
    shutil.copytree(src_dir, dest, dirs_exist_ok=True)
    print(f"copied contents of {src_dir} to {dest}")


def get_metric_title(metric):
    if metric == "speedup":
        return "Performance speedup"
    elif metric == "accuracy":
        return "Accuracy"
    elif metric == "compilation_latency":
        return "Compilation latency (sec)"
    elif metric == "compression_ratio":
        return "Peak Memory Compression Ratio"
    elif metric == "abs_latency":
        return "Absolute latency (ms)"
    raise RuntimeError("unknown metric")
class Parser:
    def __init__(
        self, suites, devices, dtypes, compilers, flag_compilers, mode, output_dir
    ):
        self.suites = suites
        self.devices = devices
        self.dtypes = dtypes
        self.compilers = compilers
        self.flag_compilers = flag_compilers
        self.output_dir = output_dir
        self.mode = mode

    def has_header(self, output_filename):
        header_present = False
        with open(output_filename) as f:
            line = f.readline()
            if "dev" in line:
                header_present = True
        return header_present
class ParsePerformanceLogs(Parser):
def __init__(
self,
suites,
devices,
dtypes,
compilers,
flag_compilers,
mode,
output_dir,
include_slowdowns=False,
):
super().__init__(
suites,
devices,
dtypes,
compilers,
flag_compilers,
mode,
output_dir,
)
self.parsed_frames = defaultdict(lambda: defaultdict(None))
self.untouched_parsed_frames = defaultdict(lambda: defaultdict(None))
self.metrics = [
"speedup",
"abs_latency",
"compilation_latency",
"compression_ratio",
]
self.bottom_k = 50
self.parse()
self.include_slowdowns = include_slowdowns
def plot_graph(self, df, title):
labels = df.columns.values.tolist()
labels = labels[3:]
df.plot(
x="name",
y=labels,
kind="bar",
width=0.65,
title=title,
ylabel="Speedup over eager",
xlabel="",
grid=True,
figsize=(max(len(df.index) / 4, 5), 10),
edgecolor="black",
)
plt.tight_layout()
plt.savefig(f"{self.output_dir}/{title}.png")
def read_csv(self, output_filename):
if self.has_header(output_filename):
return pd.read_csv(output_filename)
else:
return pd.read_csv(
output_filename,
names=[
"dev",
"name",
"batch_size",
"speedup",
"abs_latency",
"compilation_latency",
"compression_ratio",
],
header=None,
engine="python",
)
def parse(self):
self.extract_df("accuracy", "accuracy")
for metric in self.metrics:
self.extract_df(metric, "performance")
def clean_batch_sizes(self, frames):
# Reconcile batch sizes across frames when some entries are 0
if len(frames) == 1:
return frames
batch_sizes = frames[0]["batch_size"].to_list()
for frame in frames[1:]:
frame_batch_sizes = frame["batch_size"].to_list()
for idx, (batch_a, batch_b) in enumerate(
zip(batch_sizes, frame_batch_sizes)
):
assert batch_a == batch_b or batch_a == 0 or batch_b == 0, (
f"a={batch_a}, b={batch_b}"
)
batch_sizes[idx] = max(batch_a, batch_b)
for frame in frames:
frame["batch_size"] = batch_sizes
return frames
def extract_df(self, metric, testing):
for iter in itertools.product(self.suites, self.devices, self.dtypes):
suite, device, dtype = iter
frames = []
for compiler in self.compilers:
output_filename = f"{self.output_dir}/{compiler}_{suite}_{dtype}_{self.mode}_{device}_{testing}.csv"
df = self.read_csv(output_filename)
if metric not in df:
df.insert(len(df.columns), metric, np.nan)
df = df[["dev", "name", "batch_size", metric]]
df.rename(columns={metric: compiler}, inplace=True)
df["batch_size"] = df["batch_size"].astype(int)
frames.append(df)
# Merge the results
frames = self.clean_batch_sizes(frames)
if len(self.compilers) == 1:
df = frames[0]
else:
# Merge data frames
df = pd.merge(frames[0], frames[1], on=["dev", "name", "batch_size"])
for idx in range(2, len(frames)):
df = pd.merge(df, frames[idx], on=["dev", "name", "batch_size"])
if testing == "performance":
for compiler in self.compilers:
df[compiler] = pd.to_numeric(df[compiler], errors="coerce").fillna(
0
)
df_copy = df.copy()
df_copy = df_copy.sort_values(
by=list(reversed(self.compilers)), ascending=False
)
if "inductor" in self.compilers:
df_copy = df_copy.sort_values(by="inductor", ascending=False)
self.untouched_parsed_frames[suite][metric] = df_copy
if testing == "performance":
df_accuracy = self.parsed_frames[suite]["accuracy"]
perf_rows = []
for model_name in df["name"]:
perf_row = df[df["name"] == model_name].copy()
acc_row = df_accuracy[df_accuracy["name"] == model_name]
for compiler in self.compilers:
if not perf_row.empty:
if acc_row.empty:
perf_row[compiler] = 0.0
elif acc_row[compiler].iloc[0] in (
"model_fail_to_load",
"eager_fail_to_run",
):
perf_row = pd.DataFrame()
elif acc_row[compiler].iloc[0] not in (
"pass",
"pass_due_to_skip",
):
perf_row[compiler] = 0.0
if not perf_row.empty:
perf_rows.append(perf_row)
df = pd.concat(perf_rows)
df = df.sort_values(by=list(reversed(self.compilers)), ascending=False)
if "inductor" in self.compilers:
df = df.sort_values(by="inductor", ascending=False)
self.parsed_frames[suite][metric] = df
def get_passing_entries(self, compiler, df):
return df[compiler][df[compiler] > 0]
def comp_time(self, compiler, df):
df = self.get_passing_entries(compiler, df)
# df = df.sort_values(by=compiler, ascending=False)[compiler][: self.bottom_k]
if df.empty:
return "0.0"
return f"{df.mean():.2f}"
def geomean(self, compiler, df):
cleaned_df = self.get_passing_entries(compiler, df)
if not self.include_slowdowns:
cleaned_df = cleaned_df.clip(1)
if cleaned_df.empty:
return "0.0x"
return f"{gmean(cleaned_df):.2f}x"
def passrate(self, compiler, df):
total = len(df.index)
passing = df[df[compiler] > 0.0][compiler].count()
perc = int(percentage(passing, total, decimals=0))
return f"{perc}%, {passing}/{total}"
def memory(self, compiler, df):
df = self.get_passing_entries(compiler, df)
df = df.fillna(0)
df = df[df > 0]
if df.empty:
return "0.0x"
return f"{df.mean():.2f}x"
def exec_summary_df(self, fn, metric):
"""
Generate a table with passrate and geomean perf
"""
cols = {}
cols["Compiler"] = self.compilers
for suite in self.suites:
df = self.parsed_frames[suite][metric]
# speedups = [self.geomean(compiler, df) for compiler in self.compilers]
speedups = [fn(compiler, df) for compiler in self.compilers]
col = pd.Series(data=speedups, index=self.compilers)
cols[suite] = col
df = pd.DataFrame(cols)
df = df.fillna(0)
df.to_csv(os.path.join(self.output_dir, f"{fn.__name__}.csv"))
return df
def exec_summary_text(self, caption, fn, metric):
df = self.exec_summary_df(fn, metric)
tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never")
str_io = io.StringIO()
str_io.write(f"{caption}")
str_io.write("~~~\n")
str_io.write(f"{tabform}\n")
str_io.write("~~~\n")
return str_io.getvalue()
def generate_executive_summary(self):
machine = "A100 GPUs"
if "cpu" in self.devices:
get_machine_cmd = "lscpu| grep 'Model name' | awk -F':' '{print $2}'"
machine = subprocess.getstatusoutput(get_machine_cmd)[1].strip()
description = (
"We evaluate different backends "
"across three benchmark suites - torchbench, huggingface and timm. We run "
"these experiments on "
+ machine
+ ". Each experiment runs one iteration of forward pass "
"and backward pass for training and forward pass only for inference. "
"For accuracy, we check the numerical correctness of forward pass outputs and gradients "
"by comparing with native pytorch. We measure speedup "
"by normalizing against the performance of native pytorch. We report mean "
"compilation latency numbers and peak memory footprint reduction ratio. \n\n"
"Caveats\n"
"1) Batch size has been reduced to work around OOM errors. Work is in progress to "
"reduce peak memory footprint.\n"
"2) Experiments do not cover dynamic shapes.\n"
"3) Experimental setup does not have optimizer.\n\n"
)
comment = generate_dropdown_comment("", description)
str_io = io.StringIO()
str_io.write("\n")
str_io.write("## Executive Summary ##\n")
str_io.write(comment)
speedup_caption = "Geometric mean speedup \n"
speedup_summary = self.exec_summary_text(
speedup_caption, self.geomean, "speedup"
)
passrate_caption = "Passrate\n"
passrate_summary = self.exec_summary_text(
passrate_caption, self.passrate, "speedup"
)
comp_time_caption = "Mean compilation time (seconds)\n"
comp_time_summary = self.exec_summary_text(
comp_time_caption, self.comp_time, "compilation_latency"
)
peak_memory_caption = (
"Peak memory footprint compression ratio (higher is better)\n"
)
peak_memory_summary = self.exec_summary_text(
peak_memory_caption, self.memory, "compression_ratio"
)
str_io.write(
"To measure performance, compilation latency and memory footprint reduction, "
"we remove the models that fail accuracy checks.\n\n"
)
str_io.write(passrate_summary)
str_io.write(speedup_summary)
str_io.write(comp_time_summary)
str_io.write(peak_memory_summary)
self.executive_summary = str_io.getvalue()
def flag_bad_entries(self, suite, metric, flag_fn):
df = self.untouched_parsed_frames[suite][metric]
df = df.drop("dev", axis=1)
df = df.rename(columns={"batch_size": "bs"})
# apply flag_fn elementwise to flag_compilers columns,
# if one element fails, the entire row is flagged
flag = np.logical_or.reduce(
df[self.flag_compilers].applymap(flag_fn),
axis=1,
)
df = df[flag]
df = df.assign(suite=suite)
return df.reindex(columns=["suite", "name"] + self.flag_compilers)
def generate_warnings(self):
title = "## Warnings ##"
body = (
"We flag models where:\n\n"
" - accuracy fails\n"
" - speedup < 0.95x (NOTE: 0.0 speedup typically signifies a failure in the performance test)\n"
" - compilation latency > 120 sec.\n"
" - compression ratio < 0.9\n"
"\n"
)
for metric in [
"accuracy",
"speedup",
"compilation_latency",
"compression_ratio",
]:
dfs = []
for suite in self.suites:
dfs.append(self.flag_bad_entries(suite, metric, FLAG_FNS[metric]))
df = pd.concat(dfs, axis=0)
if df.empty:
continue
tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never")
str_io = io.StringIO()
str_io.write("\n")
str_io.write(get_metric_title(metric) + " warnings\n")
str_io.write("~~~\n")
str_io.write(f"{tabform}\n")
str_io.write("~~~\n")
body += str_io.getvalue()
comment = generate_dropdown_comment(title, body)
return comment
def prepare_message(self, suite):
title = f"## {suite} suite with {self.dtypes[0]} precision ##"
body = ""
for metric in [
"speedup",
"accuracy",
"compilation_latency",
"compression_ratio",
"abs_latency",
]:
df = self.untouched_parsed_frames[suite][metric]
df = df.drop("dev", axis=1)
df = df.rename(columns={"batch_size": "bs"})
tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never")
str_io = io.StringIO()
str_io.write("\n")
str_io.write(get_metric_title(metric) + "\n")
str_io.write("~~~\n")
str_io.write(f"{tabform}\n")
str_io.write("~~~\n")
body += str_io.getvalue()
comment = generate_dropdown_comment(title, body)
return comment
def gen_summary_files(self):
self.generate_executive_summary()
for suite in self.suites:
self.plot_graph(
self.untouched_parsed_frames[suite]["speedup"],
f"{suite}_{self.dtypes[0]}",
)
with open(f"{self.output_dir}/gh_title.txt", "w") as gh_fh:
str_io = io.StringIO()
str_io.write("\n")
str_io.write(f"# Performance Dashboard for {self.dtypes[0]} precision ##\n")
str_io.write("\n")
gh_fh.write(str_io.getvalue())
with open(f"{self.output_dir}/gh_executive_summary.txt", "w") as gh_fh:
gh_fh.write(self.executive_summary)
with open(f"{self.output_dir}/gh_warnings.txt", "w") as gh_fh:
warnings_body = self.generate_warnings()
gh_fh.write(warnings_body)
str_io = io.StringIO()
for suite in self.suites:
str_io.write(self.prepare_message(suite))
str_io.write("\n")
with open(f"{self.output_dir}/gh_{self.mode}.txt", "w") as gh_fh:
gh_fh.write(str_io.getvalue())
def parse_logs(args, dtypes, suites, devices, compilers, flag_compilers, output_dir):
mode = get_mode(args)
build_summary(args)
include_slowdowns = args.include_slowdowns
parser_class = ParsePerformanceLogs
parser = parser_class(
suites,
devices,
dtypes,
compilers,
flag_compilers,
mode,
output_dir,
include_slowdowns,
)
parser.gen_summary_files()
return
@dataclasses.dataclass
class LogInfo:
# Day of the year this log was generated
day: str
# Directory path where all logs are present
dir_path: str
def get_date(log_info):
return datetime.strptime(f"{log_info.day}", "%j").strftime("%m-%d")
def find_last_2_with_filenames(lookup_file, dashboard_archive_path, dtype, filenames):
"""
Walk the lookup file from its last entry backwards and return the first
two archived performance report directories (for the given dtype) that
contain every file in filenames. Returns None if fewer than two exist.
"""
df = pd.read_csv(lookup_file, names=("day", "mode", "prec", "path"))
df = df[df["mode"] == "performance"]
df = df[df["prec"] == dtype]
df = df[::-1]
last2 = []
for path in df["path"]:
output_dir = os.path.join(dashboard_archive_path, path)
fullpaths = [
os.path.join(dashboard_archive_path, path, name) for name in filenames
]
if all(os.path.exists(fullpath) for fullpath in fullpaths):
last2.append(output_dir)
if len(last2) >= 2:
return last2
return None
class SummaryStatDiffer:
def __init__(self, args):
self.args = args
self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv")
assert os.path.exists(self.lookup_file)
def generate_diff(self, last2, filename, caption):
"""
Tabulate the previous and current values from filename for each
flagged compiler and suite, preceded by caption.
"""
df_cur, df_prev = (pd.read_csv(os.path.join(path, filename)) for path in last2)
df_merge = df_cur.merge(df_prev, on="Compiler", suffixes=("_cur", "_prev"))
data = {col: [] for col in ("compiler", "suite", "prev_value", "cur_value")}
for _, row in df_merge.iterrows():
if row["Compiler"] in self.args.flag_compilers:
for suite in self.args.suites:
if suite + "_prev" not in row or suite + "_cur" not in row:
continue
data["compiler"].append(row["Compiler"])
data["suite"].append(suite)
data["prev_value"].append(row[suite + "_prev"])
data["cur_value"].append(row[suite + "_cur"])
df = pd.DataFrame(data)
tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never")
str_io = io.StringIO()
str_io.write("\n")
str_io.write(f"{caption}\n")
str_io.write("~~~\n")
str_io.write(f"{tabform}\n")
str_io.write("~~~\n")
return str_io.getvalue()
def generate_comment(self):
title = "## Summary Statistics Diff ##\n"
body = (
"For each relevant compiler, we compare the summary statistics "
"for the 2 most recent reports that actually run the compiler.\n\n"
)
dtype = self.args.dtypes[0]
last2 = find_last_2_with_filenames(
self.lookup_file,
self.args.dashboard_archive_path,
dtype,
["geomean.csv", "passrate.csv"],
)
if last2 is None:
body += "Could not find the 2 most recent reports.\n\n"
else:
for state, path in zip(("Current", "Previous"), last2):
body += f"{state} report name: {path}\n\n"
body += self.generate_diff(last2, "passrate.csv", "Passrate diff")
body += self.generate_diff(
last2, "geomean.csv", "Geometric mean speedup diff"
)
comment = generate_dropdown_comment(title, body)
with open(f"{self.args.output_dir}/gh_summary_diff.txt", "w") as gh_fh:
gh_fh.write(comment)
class RegressionDetector:
"""
Compares the most recent 2 benchmarks to find previously unflagged models
that are now flagged.
"""
def __init__(self, args):
self.args = args
self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv")
assert os.path.exists(self.lookup_file)
def generate_comment(self):
title = "## Recent Regressions ##\n"
body = (
"For each relevant compiler, we compare the most recent 2 reports "
"(that actually run the compiler) to find previously unflagged "
"models that are now flagged as problematic (according to the "
"'Warnings' section).\n\n"
)
dtype = self.args.dtypes[0]
device = self.args.devices[0]
for suite in self.args.suites:
body += f"### Regressions for {suite} ###\n"
last2 = {}
for compiler in self.args.flag_compilers:
filenames = [
generate_csv_name(
self.args, dtype, suite, device, compiler, testing
)
for testing in ["performance", "accuracy"]
]
compiler_last2 = find_last_2_with_filenames(
self.lookup_file, self.args.dashboard_archive_path, dtype, filenames
)
if compiler_last2 is not None:
last2[compiler] = [
ParsePerformanceLogs(
[suite],
[device],
[dtype],
[compiler],
[compiler],
get_mode(self.args),
output_dir,
)
for output_dir in compiler_last2
]
for state, path in zip(("Current", "Previous"), compiler_last2):
body += (
f"{state} report name (compiler: {compiler}, "
f"suite: {suite}): {path}\n\n"
)
regressions_present = False
for metric in [
"accuracy",
"speedup",
"compilation_latency",
"compression_ratio",
]:
dfs = []
for compiler in self.args.flag_compilers:
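# last2 only has entries for compilers with two usable reports,
# so use .get() to avoid a KeyError for the others.
if last2.get(compiler) is None: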
continue
df_cur, df_prev = (
last2[compiler][i].untouched_parsed_frames[suite][metric]
for i in (0, 1)
)
df_merge = df_cur.merge(
df_prev, on="name", suffixes=("_cur", "_prev")
)
flag_fn = FLAG_FNS[metric]
flag = np.logical_and(
df_merge[compiler + "_prev"].apply(
lambda x: not pd.isna(x) and not flag_fn(x)
),
df_merge[compiler + "_cur"].apply(
lambda x: not pd.isna(x) and flag_fn(x)
),
)
df_bad = df_merge[flag]
dfs.append(
pd.DataFrame(
data={
"compiler": compiler,
"name": df_bad["name"],
"prev_status": df_bad[compiler + "_prev"],
"cur_status": df_bad[compiler + "_cur"],
}
)
)
if not dfs:
continue
df = pd.concat(dfs, axis=0)
if df.empty:
continue
regressions_present = True
tabform = tabulate(
df, headers="keys", tablefmt="pretty", showindex="never"
)
str_io = io.StringIO()
str_io.write("\n")
str_io.write(f"{get_metric_title(metric)} regressions\n")
str_io.write("~~~\n")
str_io.write(f"{tabform}\n")
str_io.write("~~~\n")
body += str_io.getvalue()
if not regressions_present:
body += "No regressions found.\n"
comment = generate_dropdown_comment(title, body)
with open(f"{self.args.output_dir}/gh_metric_regression.txt", "w") as gh_fh:
gh_fh.write(comment)
class RegressionTracker:
"""
Plots progress of different metrics over time to detect regressions.
"""
def __init__(self, args):
self.args = args
self.suites = self.args.suites
self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv")
assert os.path.exists(self.lookup_file)
self.k = 10
def find_last_k(self):
"""
Find the last k pairs of (day number, log_path)
"""
dtype = self.args.dtypes[0]
df = pd.read_csv(self.lookup_file, names=("day", "mode", "prec", "path"))
df = df[df["mode"] == "performance"]
df = df[df["prec"] == dtype]
log_infos = []
for day, path in zip(df["day"], df["path"]):
log_infos.append(LogInfo(day, path))
assert len(log_infos) >= self.k
log_infos = log_infos[len(log_infos) - self.k :]
return log_infos
def generate_comment(self):
title = "## Metrics over time ##\n"
str_io = io.StringIO()
if not self.args.update_dashboard_test and not self.args.no_graphs:
for name in glob.glob(self.args.output_dir + "/*over_time.png"):
output = (
subprocess.check_output([self.args.dashboard_image_uploader, name])
.decode("ascii")
.rstrip()
)
str_io.write(f"\n{name} : ![]({output})\n")
comment = generate_dropdown_comment(title, str_io.getvalue())
with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh:
gh_fh.write(comment)
def diff(self):
log_infos = self.find_last_k()
for metric in ["geomean", "passrate", "comp_time", "memory"]:
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
for idx, suite in enumerate(self.suites):
dfs = []
for log_info in log_infos:
dir_path = os.path.join(
self.args.dashboard_archive_path, log_info.dir_path
)
assert os.path.exists(dir_path)
gmean_filename = os.path.join(dir_path, f"{metric}.csv")
if not os.path.exists(gmean_filename):
continue
df = pd.read_csv(gmean_filename)
if suite not in df:
continue
if metric in ("geomean", "memory"):
df[suite] = df[suite].str.replace("x", "").astype(float)
elif metric == "passrate":
df[suite] = df[suite].str.split("%").str[0].astype(float)
df.insert(0, "day", get_date(log_info))
df = df.pivot(index="day", columns="Compiler", values=suite)
# Interim stage when both inductor_cudagraphs and inductor exist
df = df.rename(columns={"inductor_cudagraphs": "inductor"})
for col_name in df.columns:
if col_name not in self.args.compilers:
df = df.drop(columns=[col_name])
dfs.append(df)
df = pd.concat(dfs)
df = df.interpolate(method="linear")
ax = df.plot(
ax=axes[idx],
kind="line",
ylabel=metric,
xlabel="Date",
grid=True,
ylim=0 if metric == "passrate" else 0.8,
title=suite,
style=".-",
legend=False,
)
ax.legend(loc="lower right", ncol=2)
plt.tight_layout()
plt.savefig(os.path.join(self.args.output_dir, f"{metric}_over_time.png"))
self.generate_comment()
class DashboardUpdater:
"""
Aggregates the information and posts a comment to the Performance Dashboard.
https://github.com/pytorch/torchdynamo/issues/681
"""
def __init__(self, args):
self.args = args
self.output_dir = args.output_dir
self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv")
assert os.path.exists(self.lookup_file)
try:
if not self.args.update_dashboard_test and not self.args.no_update_archive:
self.update_lookup_file()
except subprocess.CalledProcessError:
sys.stderr.write("failed to update lookup file\n")
def update_lookup_file(self):
dtype = self.args.dtypes[0]
day, _ = archive_data(self.args.archive_name)
target_dir = get_archive_name(self.args, dtype)
# Append the folder of the archived logs to the lookup csv
subprocess.check_call(
f'echo "{day},performance,{dtype},{target_dir}" >> {self.lookup_file}',
shell=True,
)
def archive(self):
dtype = self.args.dtypes[0]
# Copy the folder to archived location
archive(
self.output_dir,
self.args.dashboard_archive_path,
self.args.archive_name,
dtype,
)
def upload_graphs(self):
title = "## Performance graphs ##\n"
str_io = io.StringIO()
if not self.args.update_dashboard_test and not self.args.no_graphs:
for name in glob.glob(self.output_dir + "/*png"):
if "over_time" not in name:
output = (
subprocess.check_output(
[self.args.dashboard_image_uploader, name]
)
.decode("ascii")
.rstrip()
)
str_io.write(f"\n{name} : ![]({output})\n")
comment = generate_dropdown_comment(title, str_io.getvalue())
with open(f"{self.output_dir}/gh_graphs.txt", "w") as gh_fh:
gh_fh.write(comment)
def gen_comment(self):
files = [
"gh_title.txt",
"gh_executive_summary.txt",
"gh_summary_diff.txt",
"gh_warnings.txt",
"gh_regression.txt",
"gh_metric_regression.txt",
"gh_training.txt" if self.args.training else "gh_inference.txt",
"gh_graphs.txt",
"gh_build_summary.txt",
]
all_lines = []
for f in files:
try:
with open(os.path.join(self.output_dir, f)) as fh:
all_lines.extend(fh.readlines())
except FileNotFoundError:
pass
return "\n".join([x.rstrip() for x in all_lines])
def comment_on_gh(self, comment):
"""
Send a comment to the dashboard
"""
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
f.write(comment)
filename = f.name
issue_number = "93794"
if self.args.dtypes[0] == "float32":
issue_number = "93518"
subprocess.check_call(
[
self.args.dashboard_gh_cli_path,
"issue",
"comment",
"--repo=https://github.com/pytorch/pytorch.git",
issue_number,
"-F",
filename,
]
)
os.remove(filename)
def update(self):
self.upload_graphs()
if not self.args.no_detect_regressions:
SummaryStatDiffer(self.args).generate_comment()
RegressionDetector(self.args).generate_comment()
try:
RegressionTracker(self.args).diff()
except Exception:
logging.exception("Failed to generate metrics-over-time plots")
with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh:
gh_fh.write("")
comment = self.gen_comment()
print(comment)
if not self.args.update_dashboard_test:
if not self.args.no_gh_comment:
self.comment_on_gh(comment)
if not self.args.no_update_archive:
self.archive()
if __name__ == "__main__":
args = parse_args()
def extract(key):
return DEFAULTS[key] if getattr(args, key, None) is None else getattr(args, key)
dtypes = extract("dtypes")
suites = extract("suites")
devices = extract("devices")
if args.inference:
compilers = DEFAULTS["inference"] if args.compilers is None else args.compilers
flag_compilers = (
DEFAULTS["flag_compilers"]["inference"]
if args.flag_compilers is None
else args.flag_compilers
)
else:
assert args.training
compilers = DEFAULTS["training"] if args.compilers is None else args.compilers
flag_compilers = (
DEFAULTS["flag_compilers"]["training"]
if args.flag_compilers is None
else args.flag_compilers
)
output_dir = args.output_dir
args.compilers = compilers
args.devices = devices
args.dtypes = dtypes
flag_compilers = list(set(flag_compilers) & set(compilers))
args.flag_compilers = flag_compilers
args.suites = suites
if args.print_run_commands:
generated_file = generate_commands(
args, dtypes, suites, devices, compilers, output_dir
)
print(
f"Run commands were generated in file {generated_file}. Please run (bash {generated_file})."
)
elif args.visualize_logs:
parse_logs(args, dtypes, suites, devices, compilers, flag_compilers, output_dir)
elif args.run:
generated_file = generate_commands(
args, dtypes, suites, devices, compilers, output_dir
)
# generate memoized archive name now so that the date is reflective
# of when the run started
get_archive_name(args, dtypes[0])
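# TODO - Do we need to worry about segfaults
# NOTE: os.system() reports a failing command through its return value and
# does not raise, so the except branch below will not catch a non-zero exit.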
try:
os.system(f"bash {generated_file}")
except Exception as e:
print(
f"Running commands failed. Please run manually (bash {generated_file}) and inspect the errors."
)
raise e
if not args.log_operator_inputs:
if not args.no_update_archive:
archive(
output_dir,
args.dashboard_archive_path,
args.archive_name,
dtypes[0],
)
parse_logs(
args, dtypes, suites, devices, compilers, flag_compilers, output_dir
)
if not args.no_update_archive:
archive(
output_dir,
args.dashboard_archive_path,
args.archive_name,
dtypes[0],
)
if args.update_dashboard:
DashboardUpdater(args).update()