# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
from __future__ import annotations
import ast
import sys
from typing import Iterator, NamedTuple
class Flake8ASTErrorInfo(NamedTuple):
    """One reported violation, as a 4-tuple.

    The tuple order is (line_number, offset, msg, cls).
    """

    # Line of the offending statement (taken from the AST node's ``lineno``).
    line_number: int
    # Column of the offending statement (taken from the node's ``col_offset``).
    offset: int
    # Human-readable message, starting with the error code.
    msg: str
    # Type of the object that produced the report.
    cls: type  # unused
class Visitor(ast.NodeVisitor):
    """AST visitor that records every ``raise`` not using an accepted wrapper.

    A ``raise`` statement is accepted when its operand is a call whose callee
    is either

    * an attribute access ending in ``wrap_error`` or ``index_error``
      (e.g. ``ak._errors.wrap_error(...)``), or
    * a bare name ``wrap_error``, ``index_error`` or ``ImportError``.

    Every other ``raise`` (including one whose operand is not a call at all)
    is appended to ``self.errors`` as a ``Flake8ASTErrorInfo``.
    """

    msg = "AK101 exception must be wrapped in ak._errors.*_error"

    def __init__(self):
        # Violations found so far, in visiting order.
        self.errors: list[Flake8ASTErrorInfo] = []

    def visit_Raise(self, node):
        """Check one ``raise`` statement and record it if it is not accepted."""
        raised = node.exc
        if isinstance(raised, ast.Call):
            callee = raised.func
            if isinstance(callee, ast.Attribute):
                accepted = callee.attr in {"wrap_error", "index_error"}
            elif isinstance(callee, ast.Name):
                accepted = callee.id in {"wrap_error", "index_error", "ImportError"}
            else:
                accepted = False

            if accepted:
                # Accepted statements are still traversed, as before.
                return self.generic_visit(node)

        violation = Flake8ASTErrorInfo(
            node.lineno, node.col_offset, self.msg, type(self)
        )
        self.errors.append(violation)
class AwkwardASTPlugin:
    """Checker object wrapping a parsed module tree.

    Exposes ``name``/``version`` class attributes, takes the AST in its
    constructor, and reports findings lazily through ``run()``.
    """

    name = "flake8_awkward"
    version = "0.0.0"

    def __init__(self, tree: ast.AST) -> None:
        # Keep the tree; nothing is analysed until ``run()`` is iterated.
        self._tree = tree

    def run(self) -> Iterator[Flake8ASTErrorInfo]:
        """Walk the stored tree and yield each recorded violation in order."""
        collector = Visitor()
        collector.visit(self._tree)
        yield from collector.errors
def main(path):
    """Check one Python source file and print every violation found.

    Each finding is printed to standard output as
    ``<path>:<line>:<column> <message>``.

    Args:
        path: Location of the Python source file to check.

    Raises:
        OSError: If the file cannot be opened or read.
        SyntaxError: If the file content is not valid Python.
    """
    # Read raw bytes and let the parser decode them: it honours a BOM or a
    # PEP 263 coding declaration and otherwise assumes UTF-8. Text mode would
    # decode with the platform's locale encoding instead.
    with open(path, "rb") as f:
        code = f.read()

    # Passing the filename makes any SyntaxError point at the real file.
    node = ast.parse(code, filename=path)
    plugin = AwkwardASTPlugin(node)
    for err in plugin.run():
        print(f"{path}:{err.line_number}:{err.offset} {err.msg}")
if __name__ == "__main__":
    # Command-line use: check each path given as an argument, in order.
    for source_path in sys.argv[1:]:
        main(source_path)
|