# Test utilities: protoc invocation, test-case JSON discovery and module lookup helpers.
import asyncio
import atexit
import importlib
import os
import platform
import sys
import tempfile
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import (
Callable,
Dict,
Generator,
List,
Optional,
Tuple,
Union,
)
# Select the "python" protobuf implementation through the environment variable
# that the protobuf runtime reads.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Directory containing this file; every other path below is a direct child of it.
root_path = Path(__file__).resolve().parent

# Directory holding the test inputs.
inputs_path = root_path / "inputs"

# Output directories, one per generator flavour.
output_path_reference = root_path / "output_reference"
output_path_aristaproto = root_path / "output_aristaproto"
output_path_aristaproto_pydantic = root_path / "output_aristaproto_pydantic"
def get_files(path, suffix: str) -> Generator[str, None, None]:
    """Yield the path of every file below *path* whose name ends with *suffix*.

    The tree is traversed with ``os.walk``; each yielded value is the walked
    directory joined with the matching file name.
    """
    for current_dir, _subdirs, filenames in os.walk(path):
        for name in filenames:
            if name.endswith(suffix):
                yield os.path.join(current_dir, name)
def get_directories(path):
    """Yield the name of every directory found anywhere below *path*.

    Only the bare directory names reported by ``os.walk`` are yielded, not
    their full paths.
    """
    for _current_dir, subdir_names, _filenames in os.walk(path):
        for subdir_name in subdir_names:
            yield subdir_name
async def protoc(
    path: Union[str, Path],
    output_dir: Union[str, Path],
    reference: bool = False,
    pydantic_dataclasses: bool = False,
):
    """Run ``grpc.tools.protoc`` on every ``*.proto`` file directly inside *path*.

    :param path: Directory containing the ``.proto`` files; also used as the
        ``--proto_path``.
    :param output_dir: Directory passed to the selected ``--*_out`` option.
    :param reference: When true (and *pydantic_dataclasses* is false) the
        ``python_out`` option is used instead of ``python_aristaproto_out``.
    :param pydantic_dataclasses: When true the plugin script is passed
        explicitly as ``protoc-gen-custom`` with the ``pydantic_dataclasses``
        option; *reference* is ignored in that case.
    :return: A ``(stdout, stderr, returncode)`` tuple of the finished process.
    """
    proto_dir = Path(path).resolve()
    target_dir = Path(output_dir).resolve()
    protoc_prefix = [sys.executable, "-m", "grpc.tools.protoc"]

    if not pydantic_dataclasses:
        out_flag = "python_out" if reference else "python_aristaproto_out"
        command = protoc_prefix + [
            f"--proto_path={proto_dir.as_posix()}",
            f"--{out_flag}={target_dir.as_posix()}",
            *(proto.as_posix() for proto in proto_dir.glob("*.proto")),
        ]
    else:
        plugin_path = Path("src/aristaproto/plugin/main.py")

        if "Win" in platform.system():
            # On Windows the plugin is wrapped in a temporary batch file that
            # changes to the current working directory and runs the plugin
            # script with the current interpreter.
            # See https://stackoverflow.com/a/42622705
            with tempfile.NamedTemporaryFile(
                "w", encoding="UTF-8", suffix=".bat", delete=False
            ) as tf:
                launcher = (
                    "@echo off"
                    f"\nchdir {os.getcwd()}"
                    f"\n{sys.executable} -u {plugin_path.as_posix()}"
                )
                tf.write(launcher)
                tf.flush()

                plugin_path = Path(tf.name)
                # The file is created with delete=False, so remove it when the
                # interpreter exits.
                atexit.register(os.remove, plugin_path)

        command = protoc_prefix + [
            f"--plugin=protoc-gen-custom={plugin_path.as_posix()}",
            "--experimental_allow_proto3_optional",
            "--custom_opt=pydantic_dataclasses",
            f"--proto_path={proto_dir.as_posix()}",
            f"--custom_out={target_dir.as_posix()}",
            *(proto.as_posix() for proto in proto_dir.glob("*.proto")),
        ]

    process = await asyncio.create_subprocess_exec(
        *command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    out, err = await process.communicate()
    return out, err, process.returncode
@dataclass
class TestCaseJsonFile:
    """Contents of one JSON test-data file together with where it came from."""

    # Raw text read from the JSON file.
    json: str
    # Name of the test case the file belongs to.
    test_name: str
    # File name with everything from the first "." onwards removed.
    file_name: str

    def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
        """Return whether this file is listed under its test case in the mapping.

        *non_symmetrical_json* maps test-case names to file names; a test case
        missing from the mapping is treated as having no files listed.
        """
        listed_files = non_symmetrical_json.get(self.test_name, ())
        return self.file_name in listed_files
def get_test_case_json_data(
    test_case_name: str, *json_file_names: str
) -> List[TestCaseJsonFile]:
    """Load the JSON test data belonging to *test_case_name*.

    Candidate files are looked up inside ``inputs_path / test_case_name`` in
    this order: each name given in *json_file_names*, then
    ``f"{test_case_name}.json"``, then everything matching
    ``f"{test_case_name}_*.json"``. Candidates that do not exist are skipped.

    :return:
        One ``TestCaseJsonFile`` per existing candidate, holding the file's
        text, the test case name, and the file name up to its first ".".
    """
    case_dir = inputs_path.joinpath(test_case_name)

    candidates = [case_dir.joinpath(name) for name in json_file_names]
    candidates.append(case_dir.joinpath(f"{test_case_name}.json"))
    candidates.extend(case_dir.glob(f"{test_case_name}_*.json"))

    loaded = []
    for candidate in candidates:
        if candidate.exists():
            with candidate.open("r") as fh:
                contents = fh.read()
            short_name = candidate.name.split(".")[0]
            loaded.append(TestCaseJsonFile(contents, test_case_name, short_name))
    return loaded
def find_module(
    module: ModuleType, predicate: Callable[[ModuleType], bool]
) -> Optional[ModuleType]:
    """
    Return the first module in *module*'s package tree accepted by *predicate*.

    *module* itself is checked first. Otherwise every directory below the
    package path that contains an ``__init__.py`` is imported as a submodule
    and checked in turn. ``None`` is returned when nothing matches.

    Example:

        # find module inside foo that contains Test
        import foo
        test_module = find_module(foo, lambda m: hasattr(m, 'Test'))
    """
    if predicate(module):
        return module

    package_root = Path(*module.__path__)
    # Collect all package markers up front, before any submodule is imported.
    init_files = list(package_root.glob("**/__init__.py"))

    for init_file in init_files:
        package_dir = init_file.parent
        if package_dir != package_root:
            dotted_name = ".".join(package_dir.relative_to(package_root).parts)
            candidate = importlib.import_module(f".{dotted_name}", module.__name__)
            if predicate(candidate):
                return candidate

    return None
|