1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
#!/usr/bin/env python3
"""
This lint verifies that every Python test file (a file that matches test_*.py or
*_test.py in the test folder) has a main block which raises an exception or
calls run_tests (or run_rank, for distributed tests) to ensure that the test
will be run in OSS CI.
Takes ~2 minutes to run without multiprocessing, so the multiprocessing is
probably overkill.
"""
from __future__ import annotations
import argparse
import json
import multiprocessing as mp
from enum import Enum
from typing import NamedTuple
import libcst as cst
import libcst.matchers as m
LINTER_CODE = "TEST_HAS_MAIN"


class HasMainVisiter(cst.CSTVisitor):
    """Look for a top-level ``__main__`` guard that runs (or refuses) tests.

    After a module has been visited, ``found`` is True when some top-level
    ``if`` statement compares ``__name__`` with the string ``"__main__"``
    (either operand order, single or double quotes) and somewhere inside
    that ``if`` statement there is a ``raise`` or a call to ``run_tests`` /
    ``run_rank`` (called bare or as an attribute, e.g. ``x.run_tests()``).
    ``found`` is only ever set to True by a visit, never reset.
    """

    def __init__(self) -> None:
        super().__init__()
        self.found = False

    def visit_Module(self, node: cst.Module) -> bool:
        dunder_name = m.Name("__name__")
        main_literal = m.SimpleString('"__main__"') | m.SimpleString("'__main__'")

        def calls(func_name: str) -> m.Call:
            # Matches both `func_name(...)` and `<anything>.func_name(...)`.
            return m.Call(
                func=m.Name(func_name) | m.Attribute(attr=m.Name(func_name))
            )

        # Distributed tests (i.e. MultiProcContinuousTest) call `run_rank`
        # instead of `run_tests` in main, so both count as running tests.
        runs_or_raises = m.Raise() | calls("run_tests") | calls("run_rank")

        # Accept both `__name__ == "__main__"` and `"__main__" == __name__`.
        name_is_main = m.Comparison(
            left=dunder_name,
            comparisons=[
                m.ComparisonTarget(operator=m.Equal(), comparator=main_literal)
            ],
        )
        main_is_name = m.Comparison(
            left=main_literal,
            comparisons=[
                m.ComparisonTarget(operator=m.Equal(), comparator=dunder_name)
            ],
        )
        main_guard = m.If(test=name_is_main | main_is_name)

        has_runner = any(
            m.matches(stmt, main_guard) and m.findall(stmt, runs_or_raises)
            for stmt in node.children
        )
        if has_runner:
            self.found = True

        # Everything needed was inspected above; skip visiting children.
        return False
class LintSeverity(str, Enum):
    """Severity level attached to each emitted lint message.

    Subclassing ``str`` lets ``json.dumps`` emit members as their plain
    string values (e.g. ``"error"``).
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"
class LintMessage(NamedTuple):
    """One lint result; ``main`` prints each as a JSON object via ``_asdict()``.

    Field order determines the key order of that JSON output.
    """

    # Path of the file the message refers to.
    path: str | None
    # Location of the problem; this linter may leave these as None when the
    # finding applies to the whole file.
    line: int | None
    char: int | None
    # Identifier of the linter that produced the message (LINTER_CODE here).
    code: str
    severity: LintSeverity
    # Short bracketed identifier for the kind of finding, e.g. "[no-main]".
    name: str
    # Presumably the original / suggested replacement text for autofixes —
    # NOTE(review): confirm against the consumer; this linter sets both to None.
    original: str | None
    replacement: str | None
    # Human-readable explanation of the finding.
    description: str | None
def check_file(filename: str) -> list[LintMessage]:
    """Lint a single test file for a qualifying ``__main__`` block.

    Args:
        filename: Path of the Python file to check.

    Returns:
        An empty list when the file has a main block accepted by
        ``HasMainVisiter``; otherwise a single ``[no-main]`` error message.
        If the file cannot be parsed, a single ``[syntax-error]`` error is
        returned instead of raising, so that one unparsable file does not
        abort the whole (multiprocess) lint run.
    """
    # Read raw bytes and let libcst decode them: it honors a PEP 263 coding
    # declaration and otherwise defaults to UTF-8, rather than depending on
    # the locale's preferred encoding as a text-mode open() would.
    with open(filename, "rb") as f:
        source = f.read()

    try:
        module = cst.parse_module(source)
    except cst.ParserSyntaxError as err:
        return [
            LintMessage(
                path=filename,
                line=err.raw_line,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="[syntax-error]",
                original=None,
                replacement=None,
                description=f"Failed to parse file: {err}",
            )
        ]

    visitor = HasMainVisiter()
    module.visit(visitor)
    if visitor.found:
        return []

    message = (
        "Test files need to have a main block which either calls run_tests "
        + "(to ensure that the tests are run during OSS CI) or raises an exception "
        + "and added to the blocklist in test/run_test.py"
    )
    return [
        LintMessage(
            path=filename,
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ERROR,
            name="[no-main]",
            original=None,
            replacement=None,
            description=message,
        )
    ]
def main() -> None:
    """Lint every file given on the command line and print the results.

    Filenames come from positional arguments (or from an ``@file`` argument
    file, via ``fromfile_prefix_chars``). Files are checked in parallel, and
    each resulting lint message is printed to stdout as one JSON object per
    line, in the same order as the input files.
    """
    parser = argparse.ArgumentParser(
        description="test files should have main block linter",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )
    args = parser.parse_args()

    # Never start more workers than there are files (still capped at 8).
    # nargs="+" guarantees at least one file, so the pool size is >= 1.
    # The `with` block shuts the pool down even if a worker raises, instead
    # of leaving worker processes behind.
    with mp.Pool(min(8, len(args.filenames))) as pool:
        per_file_messages = pool.map(check_file, args.filenames)

    for file_messages in per_file_messages:
        for lint_message in file_messages:
            print(json.dumps(lint_message._asdict()), flush=True)


if __name__ == "__main__":
    main()
|