1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
|
import subprocess
import click
def test(cmd, limit):
    """Run ``cmd`` once under a given tensorexpr_fuser limit and classify it.

    The command is executed through the shell with the environment variable
    ``PYTORCH_JIT_OPT_LIMIT`` set to ``tensorexpr_fuser=<limit>``. The
    command's stdout is echoed so the user can follow progress.

    Args:
        cmd: Shell command line to execute.
        limit: Number substituted into ``tensorexpr_fuser=<limit>``.

    Returns:
        -1 ("skip") if stdout or stderr contains ``INTERNAL ASSERT FAILED``,
        1 ("good") if the command exited with status 0,
        0 ("bad") for any other exit status.
    """
    import os

    opt_limit = f"tensorexpr_fuser={limit}"
    print(f"Testing PYTORCH_JIT_OPT_LIMIT={opt_limit} {cmd}")
    # Pass the limit through the environment rather than as a "VAR=value cmd"
    # prefix: a prefix only applies to the first simple command, so compound
    # lines such as "cd dir && python x.py" or "a | b" would otherwise run
    # their later commands without the limit.
    p = subprocess.run(
        cmd,
        shell=True,
        capture_output=True,
        encoding="utf-8",
        env={**os.environ, "PYTORCH_JIT_OPT_LIMIT": opt_limit},
    )
    print(p.stdout)

    # A run that hits this marker is inconclusive for bisection; report it as
    # "skip" regardless of the exit status.
    skip_marker = "INTERNAL ASSERT FAILED"
    if skip_marker in p.stdout or skip_marker in p.stderr:
        print("skip")
        return -1
    if p.returncode == 0:
        print("good")
        return 1
    print("bad")
    return 0
@click.command()
@click.option(
    "--cmd",
    required=True,
    help="Shell command to run for each candidate limit.",
)
def bisect(cmd):
    """Bisect the tensorexpr_fuser limit at which CMD starts failing.

    Limit 0 is assumed good and limit 10000 is assumed bad; neither endpoint
    is ever run. Each candidate limit is classified by ``test`` as good (1),
    bad (0) or skip (-1). Skipped limits are remembered and stepped over:
    the search first scans from the midpoint towards the bad end and, if
    everything there skips, scans from the midpoint back towards the good
    end. The final ``last good`` / ``first bad`` pair is printed.

    Args:
        cmd: Shell command line handed to ``test`` for every probe.
    """
    last_good = 0
    first_bad = 10000
    skips = set()

    def probe(limit):
        # Every test() call runs the command, so never re-run a limit that
        # is already known to skip.
        if limit in skips:
            return -1
        outcome = test(cmd, limit)
        if outcome == -1:
            skips.add(limit)
        return outcome

    # Test if there are any unskipped limits in (last_good, first_bad).
    def keep_going():
        return any(
            limit not in skips for limit in range(last_good + 1, first_bad)
        )

    while keep_going():
        test_limit = test_mid = (last_good + first_bad) // 2
        val = -1

        # Scan forward from mid towards bad. The bound is exclusive:
        # re-testing first_bad itself cannot narrow the range, and doing so
        # would repeat the same iteration forever when [mid, first_bad) is
        # all skips while unskipped limits remain below mid.
        while test_limit < first_bad and val == -1:
            val = probe(test_limit)
            if val == -1:
                test_limit += 1

        # If everything in [mid, bad) skipped, scan back towards good
        # (again exclusive of the already-classified endpoint).
        if val == -1:
            test_limit = test_mid - 1
            while test_limit > last_good and val == -1:
                val = probe(test_limit)
                if val == -1:
                    test_limit -= 1

        # A conclusive probe lies strictly inside (last_good, first_bad), so
        # each pass either tightens a bound or records at least one new skip.
        if val == 0:
            first_bad = test_limit
        elif val == 1:
            last_good = test_limit

    print(f"last good: {last_good}, first bad: {first_bad}")
# Script entry point: invoke the click-decorated ``bisect`` command only when
# this file is executed directly, not when it is imported.
if __name__ == "__main__":
    bisect()
|