1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
|
import os
import subprocess
from typing import List, Union
import sys
import pytest
def pytest_addoption(parser):
    """Register the command-line options understood by the shell test-suite.

    ``--shell-binary`` selects the CLI executable under test and
    ``--start-offset`` (an int) skips that many collected tests.
    """
    option_specs = {
        "--shell-binary": dict(
            action="store",
            default=None,
            help="Provide the shell binary to use for the tests",
        ),
        "--start-offset": dict(
            action="store",
            type=int,
            help="Skip the first 'n' tests",
        ),
    }
    # Dicts preserve insertion order, so options are registered in the same
    # order as they are listed above.
    for flag, settings in option_specs.items():
        parser.addoption(flag, **settings)
def pytest_collection_modifyitems(config, items):
    """Mark the first ``--start-offset`` collected tests as skipped.

    Does nothing when the option is absent, zero, or negative (a negative
    slice bound would otherwise skip everything *except* the last tests).
    """
    start_offset = config.getoption("--start-offset")
    if start_offset is None or start_offset <= 0:
        # --start-offset not given (or nothing to skip), therefore move on
        return
    skipped = pytest.mark.skip(reason=f"skipped by --start-offset={start_offset}")
    for item in items[:start_offset]:
        item.add_marker(skipped)
class TestResult:
    """Captured outcome of a single shell invocation, with assertion helpers.

    ``stdout`` and ``stderr`` hold the captured output streams and
    ``status_code`` the process exit status. The ``check_*`` methods raise
    ``AssertionError`` on mismatch, echoing the relevant stream to
    ``sys.stderr`` first to ease debugging.
    """

    # The name starts with "Test", so without this pytest tries to collect the
    # class as a test suite wherever it is imported into a test module (and
    # warns because it defines __init__).
    __test__ = False

    def __init__(self, stdout, stderr, status_code):
        self.stdout: Union[str, bytes] = stdout
        self.stderr: Union[str, bytes] = stderr
        self.status_code: int = status_code

    def _assert_success(self):
        """Assert a zero exit status, printing stderr first when it is not."""
        if self.status_code != 0:
            print(self.stderr, file=sys.stderr)
        assert self.status_code == 0

    def check_stdout(self, expected: Union[str, List[str], bytes, None]):
        """Assert that *expected* occurs in stdout.

        ``None`` means stdout must be exactly empty (the exit status is not
        checked in that case). A list is joined with newlines and matched as
        one contiguous block. Otherwise the exit status must also be 0.
        """
        if expected is None:
            if self.stdout != "":
                print(self.stdout, file=sys.stderr)
            assert self.stdout == ""
            return
        if isinstance(expected, list):
            expected = "\n".join(expected)
        self._assert_success()
        found = is_needle_in_haystack(expected, self.stdout)
        if not found:
            print(self.stdout, file=sys.stderr)
        assert found

    def check_not_exist(self, not_exist: Union[str, List[str], bytes]):
        """Assert a zero exit status and that *not_exist* does NOT occur in stdout.

        A list is joined with newlines and matched as one contiguous block.
        """
        if isinstance(not_exist, list):
            not_exist = "\n".join(not_exist)
        self._assert_success()
        found = is_needle_in_haystack(not_exist, self.stdout)
        if found:
            print(self.stdout, file=sys.stderr)
        assert not found

    def check_stderr(self, expected: Union[str, List[str], bytes, None]):
        """Assert that *expected* occurs in stderr; ``None`` means stderr must be empty.

        A list is joined with newlines, as in ``check_stdout``. The exit status
        is not checked.
        """
        if isinstance(expected, list):
            expected = "\n".join(expected)
        if expected is None:
            matched = self.stderr == ""
        else:
            matched = is_needle_in_haystack(expected, self.stderr)
        if not matched:
            print(self.stderr, file=sys.stderr)
        assert matched
class ShellTest:
    """Builder for one invocation of the shell binary under test.

    Statements, an optional input file, an optional output file and extra
    environment variables are accumulated through the fluent methods (each
    returns ``self``); ``run()`` then executes the binary and returns a
    ``TestResult``.
    """

    def __init__(self, shell, arguments=None):
        """Prepare an invocation of *shell* in ``--batch`` mode.

        Args:
            shell: path to the shell binary; must be non-empty.
            arguments: optional extra command-line arguments. The sequence is
                copied, never mutated. ``--no-init`` is appended unless
                ``-init`` or ``--init`` is already present.

        Raises:
            ValueError: if *shell* is empty or None.
        """
        if not shell:
            raise ValueError("Please provide a shell binary")
        # Copy so neither the caller's list nor a shared default is mutated, and
        # add --no-init *before* building self.arguments so it is actually used.
        extra_arguments = list(arguments) if arguments else []
        if "-init" not in extra_arguments and "--init" not in extra_arguments:
            extra_arguments.append("--no-init")
        self.shell = shell
        self.arguments = [shell, "--batch"] + extra_arguments
        self.statements: List[str] = []
        self.input = None
        self.output = None
        self.environment = {}

    def add_argument(self, *args):
        """Append extra command-line arguments."""
        self.arguments.extend(args)
        return self

    def statement(self, stmt):
        """Queue a single statement (a SQL statement or a dot-command)."""
        self.statements.append(stmt)
        return self

    def env_var(self, key: str, value: str):
        """Set an environment variable for the child process."""
        self.environment[key] = value
        return self

    def query(self, *stmts):
        """Queue several statements at once."""
        self.statements.extend(stmts)
        return self

    def input_file(self, file_path):
        """Feed *file_path* to stdin; statements then go on the command line."""
        self.input = file_path
        return self

    def output_file(self, file_path):
        """Redirect the child's stdout to *file_path* instead of a pipe."""
        self.output = file_path
        return self

    # Test Running methods

    def get_command(self, cmd: str) -> List[str]:
        """Return the argv for this run.

        When an input file feeds stdin, the statements in *cmd* are passed as a
        trailing command-line argument instead. A fresh list is returned so
        repeated calls (or repeated ``run()``s) never grow ``self.arguments``.
        """
        command = list(self.arguments)
        if self.input:
            command.append(cmd)
        return command

    def get_input_data(self, cmd: str):
        """Return the bytes sent to stdin: the input file, or *cmd* as UTF-8."""
        if self.input:
            with open(self.input, "rb") as f:
                return f.read()
        return bytearray(cmd, "utf8")

    def get_statements(self):
        """Join queued statements with newlines, terminating non-dot ones with ';'."""
        return "\n".join(
            stmt if stmt.startswith(".") else stmt + ";" for stmt in self.statements
        )

    def get_output_data(self, res):
        """Return ``(stdout, stderr)`` as text.

        stdout comes from the output file (unstripped) when one was configured,
        otherwise from the captured pipe (stripped). Both sources are decoded as
        UTF-8 so results do not depend on the platform's default encoding.
        """
        if self.output:
            with open(self.output, "r", encoding="utf8") as f:
                stdout = f.read()
        else:
            stdout = res.stdout.decode("utf8").strip()
        stderr = res.stderr.decode("utf8").strip()
        return stdout, stderr

    def run(self):
        """Execute the shell once and return its captured ``TestResult``."""
        statements = self.get_statements()
        command = self.get_command(statements)
        input_data = self.get_input_data(statements)
        child_env = os.environ.copy()
        child_env.update(self.environment)

        def _spawn(stdout_target):
            return subprocess.run(
                command,
                input=input_data,
                stdout=stdout_target,
                stderr=subprocess.PIPE,
                env=child_env,
            )

        if self.output:
            # Opening in "wb" truncates any previous contents; the child writes
            # raw bytes straight to the file descriptor.
            with open(self.output, "wb") as output_pipe:
                res = _spawn(output_pipe)
        else:
            res = _spawn(subprocess.PIPE)
        stdout, stderr = self.get_output_data(res)
        return TestResult(stdout, stderr, res.returncode)
@pytest.fixture()
def shell(request):
    """Return the shell binary path supplied via ``--shell-binary``.

    Raises:
        ValueError: if the option was not given (or is empty).
    """
    binary_path = request.config.getoption("--shell-binary")
    if binary_path:
        return binary_path
    raise ValueError("Please provide a shell binary path to the tester, using '--shell-binary <path_to_cli>'")
@pytest.fixture()
def random_filepath(tmp_path):
    """Return a path named ``random_import_file`` inside pytest's ``tmp_path``.

    The file itself is not created here. Despite the fixture's name, the file
    name is fixed; uniqueness comes from the per-test ``tmp_path`` directory.
    """
    return tmp_path / "random_import_file"
@pytest.fixture()
def generated_file(request, random_filepath):
    """Write the parametrized contents to a temporary file and return its path.

    Expects to be parametrized indirectly so that ``request.param`` holds the
    file contents. ``str`` contents are written as UTF-8 text; ``bytes`` or
    ``bytearray`` contents are written verbatim.
    """
    contents = request.param
    tmp_file = random_filepath
    if isinstance(contents, (bytes, bytearray)):
        with open(tmp_file, "wb") as f:
            f.write(contents)
    else:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
        # otherwise mangle non-ASCII test data.
        with open(tmp_file, "w", encoding="utf8") as f:
            f.write(contents)
    return tmp_file
def check_load_status(shell, extension: str):
    """Return the shell's stdout for ``duckdb_extensions().loaded`` of *extension*."""
    # Double any single quotes so the name always forms a valid SQL literal.
    escaped_name = extension.replace("'", "''")
    binary = ShellTest(shell)
    # No trailing ';': ShellTest terminates non-dot statements itself, and an
    # extra one would produce ";;".
    binary.statement(f"select loaded from duckdb_extensions() where extension_name = '{escaped_name}'")
    result = binary.run()
    return result.stdout
def is_needle_in_haystack(needle, haystack) -> bool:
    """Return True if *needle* occurs in *haystack*.

    Only like-typed ``str``/``str`` or ``bytes``/``bytes`` pairs are compared;
    any other combination (including ``None``) is a non-match. An empty needle
    or haystack never matches, for ``str`` and ``bytes`` alike. On Windows
    (``os.name == 'nt'``), a ``str`` needle containing ``\\n`` is retried with
    ``\\r\\n`` line endings.
    """
    both_str = isinstance(needle, str) and isinstance(haystack, str)
    both_bytes = isinstance(needle, bytes) and isinstance(haystack, bytes)
    if not (both_str or both_bytes):
        return False
    # An empty needle would trivially be "found"; reject it for both types.
    if not needle or not haystack:
        return False
    if needle in haystack:
        return True
    if both_str and os.name == "nt" and "\n" in needle:
        # try with windows-style newlines
        return needle.replace("\n", "\r\n") in haystack
    return False
def assert_loaded(shell, extension: str):
    """Skip the current test unless the shell reports *extension* as loaded.

    "Loaded" means the output of ``check_load_status`` contains ``true``.
    """
    # TODO: add a command line argument to fail instead of skip if the extension is not loaded
    load_status = check_load_status(shell, extension)
    if is_needle_in_haystack("true", load_status):
        return
    pytest.skip(reason=f"'{extension}' extension is not loaded!")
@pytest.fixture()
def autocomplete_extension(shell):
    """Skip the requesting test unless the ``autocomplete`` extension is loaded."""
    required_extension = "autocomplete"
    assert_loaded(shell, required_extension)
@pytest.fixture()
def json_extension(shell):
    """Skip the requesting test unless the ``json`` extension is loaded."""
    required_extension = "json"
    assert_loaded(shell, required_extension)
|