1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
|
"""
Helpers for tests.
"""
import ast
from typing import Any, Tuple, Callable, TypeVar
from ast_decompiler import decompile
from ast_decompiler.check import check as check
import difflib
import sys
# Major version number of the running interpreter (e.g. 3); compared against
# in only_on_version.
VERSION = sys.version_info[0]
# Type variable for decorators that return a callable of the same type they
# were given; bounded to callables whose return value is None.
_CallableT = TypeVar("_CallableT", bound=Callable[..., None])
def assert_decompiles(
    code: str,
    result: str,
    do_check: bool = True,
    indentation: int = 4,
    line_length: int = 100,
    starting_indentation: int = 0,
) -> None:
    """Assert that ``code``, when parsed, decompiles into exactly ``result``.

    Args:
        code: Python source; it is parsed with ``ast.parse`` and the resulting
            tree is passed to ``decompile``.
        result: The exact source text the decompiler is expected to produce.
        do_check: If true, the decompiled output is also passed to
            ``ast_decompiler.check.check``.
        indentation: Forwarded to ``decompile``.
        line_length: Forwarded to ``decompile``.
        starting_indentation: Forwarded to ``decompile``.

    Raises:
        AssertionError: If the decompiled output differs from ``result``. The
            expected output, the actual output and a unified diff are printed
            before raising.
    """
    decompile_result = decompile(
        ast.parse(code),
        indentation=indentation,
        line_length=line_length,
        starting_indentation=starting_indentation,
    )
    if do_check:
        check(decompile_result)
    if result == decompile_result:
        return
    print(">>> expected")
    print(result)
    print(">>> actual")
    print(decompile_result)
    print(">>> diff")
    # The inputs come from splitlines() and carry no trailing newlines, so use
    # lineterm="" -- the default "\n" would make print() emit an extra blank
    # line after each ---/+++/@@ header line.
    for line in difflib.unified_diff(
        result.splitlines(),
        decompile_result.splitlines(),
        fromfile="expected",
        tofile="actual",
        lineterm="",
    ):
        print(line)
    # Raise explicitly rather than `assert False`: assert statements are
    # removed when Python runs with -O, which would let a mismatch pass.
    raise AssertionError(f"failed to decompile {code}")
def only_on_version(py_version: int) -> Callable[[_CallableT], _CallableT]:
    """Decorator that runs a test only if the Python major version matches.

    Args:
        py_version: Major Python version on which the decorated test should
            run; compared against ``VERSION`` (``sys.version_info.major``).

    Returns:
        A decorator. On the matching major version it returns the test
        unchanged. Otherwise it replaces the test with a stub that accepts any
        positional and keyword arguments and returns None. The call therefore
        returns normally; a runner will presumably report it as passing, not
        as skipped.
    """
    if py_version != VERSION:
        def decorator(fn: Callable[..., Any]) -> Callable[..., None]:
            # Accept **kwargs as well (matching skip_after's stub) so the stub
            # never raises TypeError when invoked with keyword arguments.
            return lambda *args, **kwargs: None
    else:
        def decorator(fn: _CallableT) -> _CallableT:
            return fn
    return decorator
def skip_before(py_version: Tuple[int, int]) -> Callable[[_CallableT], _CallableT]:
    """Decorator that skips a test on Python versions before py_version.

    Args:
        py_version: ``(major, minor)`` version from which the test should run.
            Because tuple comparison treats a longer tuple with an equal prefix
            as greater, every release of ``py_version`` itself
            (e.g. 3.8.x for ``(3, 8)``) counts as "not before" and runs the
            test.

    Returns:
        A decorator. On older interpreters it replaces the test with a stub
        that accepts any positional and keyword arguments and returns None.
        The call therefore returns normally; a runner will presumably report
        it as passing, not as skipped. Otherwise the test is returned
        unchanged.
    """
    if sys.version_info < py_version:
        def decorator(fn: Callable[..., Any]) -> Callable[..., None]:
            # Accept **kwargs as well (matching skip_after's stub) so the stub
            # never raises TypeError when invoked with keyword arguments.
            return lambda *args, **kwargs: None
    else:
        def decorator(fn: _CallableT) -> _CallableT:
            return fn
    return decorator
def skip_after(py_version: Tuple[int, int]) -> Callable[[_CallableT], _CallableT]:
    """Decorator that skips a test on Python versions after py_version.

    Args:
        py_version: Last version on which the test should run, e.g.
            ``(3, 8)``. Only as many leading components of
            ``sys.version_info`` as ``py_version`` has are compared, so
            ``skip_after((3, 8))`` runs the test on every 3.8.x release and
            skips it from 3.9 on.

    Returns:
        A decorator. On newer interpreters it replaces the test with a stub
        that accepts any positional and keyword arguments and returns None.
        The call therefore returns normally; a runner will presumably report
        it as passing, not as skipped. Otherwise the test is returned
        unchanged.
    """
    # Truncate before comparing. The full version_info, e.g.
    # (3, 8, 10, 'final', 0), compares greater than the shorter (3, 8), which
    # would wrongly skip the test on the very version that was named.
    if sys.version_info[: len(py_version)] > py_version:
        def decorator(fn: Callable[..., Any]) -> Callable[..., None]:
            return lambda *args, **kwargs: None
    else:
        def decorator(fn: _CallableT) -> _CallableT:
            return fn
    return decorator
|