File: generate_tests_from_examples.py

package info (click to toggle)
python-apischema 0.18.3-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,608 kB
  • sloc: python: 15,266; sh: 7; makefile: 7
file content (108 lines) | stat: -rwxr-xr-x 3,977 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
import os
import re
import sys
from itertools import takewhile
from pathlib import Path
from shutil import rmtree
from typing import Iterable, Iterator, Tuple

ROOT_DIR = Path(__file__).parent.parent

EXAMPLES_PATH = ROOT_DIR / "examples"
GENERATED_PATH = ROOT_DIR / "tests" / "__generated__"
# scripts/test_wrapper.py is split at its first line starting with "##": the
# lines before it are written at the top of every generated test module and
# the lines after it at the bottom. takewhile consumes the marker line without
# yielding it, so "##\n" is re-added on both sides.
# The file is read as UTF-8 explicitly (Python's default source encoding)
# instead of relying on the locale's default encoding.
with open(
    ROOT_DIR / "scripts" / "test_wrapper.py", encoding="utf-8"
) as wrapper_file:
    before_lines = [
        *takewhile(lambda line: not line.startswith("##"), wrapper_file),
        "##\n",
    ]
    after_lines = ["##\n", *wrapper_file]


def iter_paths() -> Iterator[Tuple[Path, Path]]:
    """Yield ``(example_path, test_path)`` for every example module.

    ``__init__.py`` files are skipped. The directory layout below
    EXAMPLES_PATH is mirrored below GENERATED_PATH, and each mirrored
    directory is created (lazily, as items are produced) before its pair is
    yielded. The test file is named ``test_<example file name>``.
    """
    for source in EXAMPLES_PATH.glob("**/*.py"):
        if source.name != "__init__.py":
            rel = source.relative_to(EXAMPLES_PATH)
            target_dir = GENERATED_PATH / rel.parent
            target_dir.mkdir(parents=True, exist_ok=True)
            yield source, target_dir / ("test_" + rel.name)


# Indentation prepended to statements moved into a generated test function.
INDENTATION = 4 * " "
# Matches a PEP 604 union such as ``A | B[...] | C[...]`` together with the two
# characters preceding it (e.g. ": " in an annotation, "= " in an assignment);
# replace_union inspects that two-character prefix. Every member is a word
# optionally followed by one ``[...]`` subscript -- including the last member,
# so that ``int | list[int]`` is captured whole instead of being rewritten to
# the invalid ``Union[int, list][int]``.
union_regex = re.compile(r"..(\w+(\[.+?\])? \| )+(\w+(\[.+?\])?)")
# regex is not recursive and thus cannot catch things like Connection[Ship | None] | None

try:
    from re import Match
except ImportError:
    # Fallback placeholder so the ``Match`` annotation below still evaluates.
    Match = ...  # type: ignore


def replace_union(match: Match) -> str:
    """Turn one ``union_regex`` match into its ``typing.Union`` spelling.

    The first two matched characters are context and are kept verbatim; the
    rest is split on ``|`` into stripped members. When the context starts with
    ``=`` and the last member is not ``None``, the text is returned unchanged
    (graphql types); otherwise ``<prefix>Union[m1, m2, ...]`` is returned.
    """
    whole = match.group(0)
    prefix, body = whole[:2], whole[2:]
    members = [member.strip() for member in body.split("|")]
    is_assignment = prefix[0] == "="
    if is_assignment and members[-1] != "None":  # graphql types
        return whole
    return prefix + "Union[" + ", ".join(members) + "]"


def handle_union(line: str) -> str:
    """Return *line* with every ``union_regex`` match passed through ``replace_union``."""
    rewritten = union_regex.sub(replace_union, line)
    return rewritten


def _write_test_module(example: Iterator[str], test, stem: str) -> None:
    """Write the test-module version of the *example* lines to the text file *test*.

    Global-scope code is copied unchanged, because classes must be declared in
    the global namespace for ``get_type_hints`` and ``is_method`` to work. A
    test function ``def {stem}{n}():`` starts at the first ``assert `` /
    ``with raises(`` line and absorbs (indented) every following line until a
    line starting with ``class `` or ``@`` returns to global scope.
    The wrapper's ``before_lines``/``after_lines`` surround the whole module.
    """
    # 3.9 compatibility is added after __future__ import, which must remain
    # the module's first statement.
    # However, Annotated/Literal/etc. can be an issue
    first_line = next(example, "")  # "" when the example file is empty
    if first_line.startswith("from __future__ import"):
        test.write(first_line)
        test.writelines(before_lines)
    else:
        test.writelines(before_lines)
        test.write(first_line)
    test_count = 0
    while True:
        # Copy global-scope lines until the first assertion, which opens a
        # new test function.
        for line in example:
            if line.startswith(("assert ", "with raises(")):
                test.write(f"def {stem}{test_count}():\n")
                test.write(INDENTATION + line)
                break
            test.write(line)
        else:
            break  # example exhausted
        # Indent the following lines into the test function. Lines inside a
        # multi-line triple-quoted string are not indented, since that would
        # alter the string's value.
        cur_indent = INDENTATION
        for line in example:
            if line.startswith(("class ", "@")):
                test.write(line)
                test_count += 1
                break
            test.write(cur_indent + line)
            # Only an odd number of delimiters opens/closes a multi-line
            # string; a one-line """...""" literal leaves the state unchanged.
            if line.count('"""') % 2:
                cur_indent = "" if cur_indent else INDENTATION
        else:
            break  # example exhausted
    test.writelines(after_lines)


def _add_init_files() -> None:
    """Create an empty ``__init__.py`` in GENERATED_PATH and all its subdirectories."""
    for path in GENERATED_PATH.glob("**"):
        if path.is_dir():
            (path / "__init__.py").touch()


def main():
    """Regenerate GENERATED_PATH from the example scripts.

    The generated tree is deleted and rebuilt: every example yielded by
    ``iter_paths`` is converted into a test module, then each generated
    directory receives an empty ``__init__.py``.
    """
    if GENERATED_PATH.exists():
        rmtree(GENERATED_PATH)
    GENERATED_PATH.mkdir(parents=True)
    for example_path, test_path in iter_paths():
        with open(example_path, encoding="utf-8") as example_file:
            with open(test_path, "w", encoding="utf-8") as test:
                # PEP 604 unions are always rewritten to typing.Union (the
                # former version/TOXENV check was short-circuited by `or True`).
                lines = map(handle_union, example_file)
                _write_test_module(lines, test, test_path.stem)
    _add_init_files()


if __name__ == "__main__":  # run as a script, not imported
    main()