1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
|
import sys
import textwrap
from pathlib import Path
def check(path):
    """Check a test file for common issues with pytest->pytorch conversion.

    Prints the file name as a header, then one report (via
    ``report_violation``) per offending line. Reported line numbers are
    1-based so they match what editors and tracebacks show.

    Flags:
      * module-level ``def test...`` functions,
      * ``class Test...`` lines that do not mention ``TestCase``,
      * leftover ``pytest.mark`` / ``pytest.xfail`` / ``pytest.skip`` /
        ``pytest.param`` usages,
      * ``@parametrize`` outside a test class, or inside a class that is not
        decorated with ``@instantiate_parametrized_tests``.

    Parameters
    ----------
    path : pathlib.Path
        The test file to scan.
    """
    print(path.name)
    print("=" * len(path.name), "\n")
    # Path.read_text uses universal newlines, so "\r\n" is already normalised
    # to "\n"; be explicit about the encoding instead of relying on the locale.
    src = path.read_text(encoding="utf-8").split("\n")
    for num, line in enumerate(src):
        if is_comment(line):
            continue
        lineno = num + 1  # enumerate is 0-based; report human line numbers
        # module level test functions
        if line.startswith("def test"):
            report_violation(line, lineno, header="Module-level test function")
        # test classes must inherit from TestCase
        if line.startswith("class Test") and "TestCase" not in line:
            report_violation(
                line, lineno, header="Test class does not inherit from TestCase"
            )
        # last vestiges of pytest-specific stuff
        if "pytest.mark" in line:
            report_violation(line, lineno, header="pytest.mark.something")
        for part in ("pytest.xfail", "pytest.skip", "pytest.param"):
            if part in line:
                report_violation(line, lineno, header=f"stray {part}")
        if textwrap.dedent(line).startswith("@parametrize"):
            _check_parametrize(src, num, line, lineno)


def _check_parametrize(src, num, line, lineno):
    """Report a ``@parametrize`` at ``src[num]`` not inside an instantiated test class."""
    class_idx = _find_enclosing_test_class(src, num)
    if class_idx is None:
        report_violation(line, lineno, "off-class parametrize")
        # Without an enclosing class there is no decorator to inspect.
        return
    if not _has_instantiation_decorator(src, class_idx):
        report_violation(
            line,
            lineno,
            f"missing instantiation of parametrized tests in {src[class_idx]}?",
        )


def _find_enclosing_test_class(src, start):
    """Return the index of the nearest ``class Test`` line at or above ``start``.

    Lines indented by 8 or more columns are skipped: as a heuristic they are
    presumably inner classes rather than the test class itself. Returns None
    when no candidate is found.
    """
    for idx in range(start, -1, -1):
        candidate = src[idx]
        indent = len(candidate) - len(candidate.lstrip())
        if "class Test" in candidate and indent < 8:
            return idx
    return None


def _has_instantiation_decorator(src, class_idx):
    """Return True if ``@instantiate_parametrized_tests`` decorates ``src[class_idx]``.

    Walks up the contiguous run of decorator lines directly above the class,
    so the decorator is found even when stacked with other decorators or
    when the class (and its decorators) are indented. Returns False when the
    class is on the first line.
    """
    idx = class_idx - 1
    while idx >= 0:
        stripped = src[idx].lstrip()
        if not stripped.startswith("@"):
            break
        if stripped.startswith("@instantiate_parametrized_tests"):
            return True
        idx -= 1
    return False
def is_comment(line):
return textwrap.dedent(line).startswith("#")
def report_violation(line, lineno, header):
    """Print one violation record: a ``>>>>`` header line, then the offending source line."""
    message = ">>>> line {} : {}\n {}\n".format(lineno, header, line)
    print(message)
if __name__ == "__main__":
    # Expect exactly one argument: a test file, or a directory of test files.
    if len(sys.argv) != 2:
        # sys.exit(<str>) prints the message to stderr and exits with status 1,
        # which is friendlier for a CLI than an uncaught ValueError traceback.
        sys.exit("Usage : python check_tests_conform path/to/file/or/dir")
    path = Path(sys.argv[1])
    if not path.exists():
        sys.exit(f"No such file or directory: {path}")
    if path.is_dir():
        # Run for all test*.py files directly in the directory (no subdirs);
        # sort so the report order is deterministic across filesystems.
        for this_path in sorted(path.glob("test*.py")):
            check(this_path)
    else:
        check(path)
|