1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
|
# Copyright 2020, TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0
"""
<Program Name>
utils.py
<Started>
August 3, 2020.
<Author>
Jussi Kukkonen
<Copyright>
See LICENSE-MIT OR LICENSE for licensing information.
<Purpose>
Provide common utilities for TUF tests
"""
from __future__ import annotations

import argparse
import errno
import functools
import logging
import os
import queue
import socket
import subprocess
import sys
import threading
import time
import warnings
from contextlib import contextmanager
from typing import IO, TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
    import unittest
    from collections.abc import Iterator
# Address of the test server, used when forming URLs on the client side
TEST_HOST_ADDRESS = "127.0.0.1"

# Absolute directory holding this file: lets tests reliably read other files
# in the tests dir regardless of the current working directory
TESTS_DIR = os.path.split(os.path.realpath(__file__))[0]

logger = logging.getLogger(__name__)
def run_sub_tests_with_dataset(
    dataset: dict[str, Any],
) -> Callable[[Callable], Callable]:
    """Decorator starting a unittest.TestCase.subTest() for each of the
    cases in dataset.

    The decorated test runs as a set of N subtests (where N is the number of
    items in ``dataset``). The test function is called once per case with
    that case's value as its second argument. Before each call, the case name
    (with spaces replaced by underscores) is stored on the test instance as
    ``case_name``.

    Args:
        dataset: Mapping of case name to the data fed to the test function.
    """

    def real_decorator(
        function: Callable[[unittest.TestCase, Any], None],
    ) -> Callable[[unittest.TestCase], None]:
        # functools.wraps keeps the test's __name__ and __doc__; unittest
        # reads __doc__ e.g. for shortDescription() in verbose output.
        @functools.wraps(function)
        def wrapper(test_cls: unittest.TestCase) -> None:
            for case, data in dataset.items():
                with test_cls.subTest(case=case):
                    # Save case name for future reference
                    test_cls.case_name = case.replace(" ", "_")
                    function(test_cls, data)

        return wrapper

    return real_decorator
class TestServerProcessError(Exception):
    """Error reported by the test server helpers.

    The message is kept in ``self.value``; ``str()`` of the exception is the
    repr() of that value (i.e. it is shown quoted).
    """

    def __init__(self, value: str = "TestServerProcess") -> None:
        super().__init__()
        self.value = value

    def __str__(self) -> str:
        return f"{self.value!r}"
@contextmanager
def ignore_deprecation_warnings(module: str) -> Iterator[None]:
    """Silence DeprecationWarnings attributed to ``module`` inside the block.

    The warnings filter state is saved on entry and restored on exit, so the
    added "ignore" filter does not outlive the ``with`` statement.

    Args:
        module: Regular expression that the start of the warning's module
            name must match (as in ``warnings.filterwarnings``).
    """
    saved_filters = warnings.catch_warnings()
    with saved_filters:
        warnings.filterwarnings(
            action="ignore", module=module, category=DeprecationWarning
        )
        yield
def wait_for_server(
    host: str, server: str, port: int, timeout: int = 10
) -> None:
    """Wait for server start until timeout is reached or server has started.

    Repeatedly attempts a blocking IPv4 TCP connect() to host:port until one
    succeeds or ``timeout`` seconds have passed. There are major differences
    between operating systems on how this works, but the blocking connect()
    seems to work fast on Linux and at least works on Windows (where
    ECONNREFUSED unfortunately has a 2 second timeout).

    Args:
        host: Address to connect to.
        server: Server name; only used in the error message.
        port: TCP port to connect to.
        timeout: Total number of seconds to keep trying.

    Raises:
        TimeoutError: No connection succeeded within ``timeout`` seconds.
    """
    # monotonic() is immune to wall-clock adjustments while we wait.
    deadline = time.monotonic() + timeout
    remaining_timeout: float = timeout
    while remaining_timeout > 0:
        try:
            # The with statement closes the socket on every path, and unlike
            # a try/finally around a name bound inside the try, it cannot hit
            # an unbound variable if socket creation itself fails.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(remaining_timeout)
                sock.connect((host, port))
            return
        except socket.timeout:
            pass
        except OSError as e:
            # ECONNREFUSED is expected while the server is not started
            if e.errno != errno.ECONNREFUSED:
                logger.warning(
                    "Unexpected error while waiting for server: %s", str(e)
                )
            # Avoid pegging a core just for this
            time.sleep(0.01)
        # Keep the fractional part: truncating to int would stop waiting up
        # to one second before the requested timeout.
        remaining_timeout = deadline - time.monotonic()

    raise TimeoutError(
        "Could not connect to the " + server + " on port " + str(port) + "!"
    )
def configure_test_logging(argv: list[str]) -> None:
    """Call logging.basicConfig() with a level picked from the -v count.

    Only '-v'/'--verbose' is interpreted; argv may contain other things
    meant for the unittest argument parser, and those are ignored here.
    Zero or one -v selects ERROR (so a single '-v' only makes unittest print
    test names), two select WARNING, three INFO, and four or more DEBUG.
    """
    verbosity_parser = argparse.ArgumentParser(add_help=False)
    verbosity_parser.add_argument(
        "-v", "--verbose", action="count", default=0
    )
    known, _unknown = verbosity_parser.parse_known_args(argv)

    # Indexed by the number of -v flags, clamped to the last entry.
    levels_by_count = (
        logging.ERROR,
        logging.ERROR,
        logging.WARNING,
        logging.INFO,
        logging.DEBUG,
    )
    level = levels_by_count[min(known.verbose, len(levels_by_count) - 1)]
    logging.basicConfig(level=level)
def cleanup_metadata_dir(path: str) -> None:
    """Remove the local metadata files found directly under ``path``.

    Every ``*.json`` entry is deleted, and an entry named ``root_history`` is
    cleaned the same way recursively. The directories themselves are kept.

    Raises:
        ValueError: An entry is neither ``root_history`` nor a ``.json``
            file. Entries visited before it have already been removed.
    """
    with os.scandir(path) as entries:
        for entry in entries:
            if entry.name == "root_history":
                cleanup_metadata_dir(entry.path)
                continue
            if not entry.name.endswith(".json"):
                raise ValueError(f"Unexpected local metadata file {entry.path}")
            os.remove(entry.path)
class TestServerProcess:
    """Helper class used to create a child process with the subprocess.Popen
    object and use a thread-safe Queue structure for logging.

    The server is started by the constructor, which returns once the server
    has printed its port message and accepts connections on that port.
    Call clean() to log any pending server output and kill the process.

    Args:
        log: Logger which will be used for logging.
        server: Path to the server to run in the subprocess.
        timeout: Time in seconds in which the server should start or otherwise
            TimeoutError error will be raised.
        popen_cwd: Current working directory used when instancing a
            subprocess.Popen object.
        extra_cmd_args: Additional arguments for the command which will start
            the subprocess. More precisely:
            "python -u <path_to_server> <extra_cmd_args>".
            If no list is provided, an empty list ("[]") will be assigned to it.

    Raises:
        TestServerProcessError: The server exited, or its first output line
            was not the expected port message.
        TimeoutError: The server did not report its port, or did not accept
            connections, within ``timeout`` seconds.
        Any exception raised while starting the server is re-raised after
        clean() has released whatever was already set up.
    """

    def __init__(
        self,
        log: logging.Logger,
        server: str = os.path.join(TESTS_DIR, "simple_server.py"),
        timeout: int = 10,
        popen_cwd: str = ".",
        extra_cmd_args: list[str] | None = None,
    ):
        self.server = server
        self.__logger = log
        # Stores popped messages from the queue.
        self.__logged_messages: list[str] = []
        self.__server_process: subprocess.Popen | None = None
        self._log_queue: queue.Queue | None = None
        self.port = -1

        if extra_cmd_args is None:
            extra_cmd_args = []

        try:
            self._start_server(timeout, extra_cmd_args, popen_cwd)
            wait_for_server("localhost", self.server, self.port, timeout)
        except Exception:
            # Clean the resources and log the server errors if any exists.
            # clean() copes with a partially started server (e.g. Popen itself
            # failed), so the original exception is the one re-raised.
            self.clean()
            raise

    def _start_server(
        self, timeout: int, extra_cmd_args: list[str], popen_cwd: str
    ) -> None:
        """
        Start the server subprocess and a thread
        responsible to redirect stdout/stderr to the Queue.
        Waits for the port message maximum timeout seconds.
        """
        self._start_process(extra_cmd_args, popen_cwd)
        self._start_redirect_thread()
        self._wait_for_port(timeout)

        self.__logger.info("%s serving on %d", self.server, self.port)

    def _start_process(self, extra_cmd_args: list[str], popen_cwd: str) -> None:
        """Starts the process running the server."""
        # The "-u" option forces stdin, stdout and stderr to be unbuffered.
        command = [sys.executable, "-u", self.server, *extra_cmd_args]

        # Reusing one subprocess in multiple tests, but split up the logs
        # for each.
        self.__server_process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=popen_cwd,
        )

    def _start_redirect_thread(self) -> None:
        """Starts a thread redirecting the stdout/stderr to the Queue."""
        assert isinstance(self.__server_process, subprocess.Popen)
        # Run log_queue_worker() in a thread.
        # The thread will exit when the child process dies.
        self._log_queue = queue.Queue()
        log_thread = threading.Thread(
            target=self._log_queue_worker,
            args=(self.__server_process.stdout, self._log_queue),
        )

        # "daemon = True" means the thread won't interfere with the
        # process exit.
        log_thread.daemon = True
        log_thread.start()

    @staticmethod
    def _log_queue_worker(stream: IO, line_queue: queue.Queue) -> None:
        """
        Worker function to run in a separate thread.
        Reads from 'stream', puts lines in a Queue (Queue is thread-safe).
        An empty string is queued when the stream reaches EOF.
        """
        while True:
            # readline() is a blocking operation.
            # decode to push a string in the queue instead of 8-bit bytes.
            # errors="replace": a non-UTF-8 byte in the server output must not
            # kill this thread before the EOF marker is queued; otherwise a
            # consumer waiting on the queue could block until its timeout.
            log_line = stream.readline().decode("utf-8", errors="replace")
            line_queue.put(log_line)

            if len(log_line) == 0:
                # This is the end of the stream meaning the server process
                # has exited.
                stream.close()
                break

    def _wait_for_port(self, timeout: int) -> None:
        """
        Validates the first item from the Queue against the port message.
        If validation is successful, self.port is set.
        Raises TestServerProcessError if the process has exited or
        TimeoutError if no message was found within timeout seconds.
        """
        assert isinstance(self.__server_process, subprocess.Popen)
        assert isinstance(self._log_queue, queue.Queue)
        # We have hardcoded the message we expect on a successful server
        # startup. This message should be the first message sent by the server!
        expected_msg = "bind succeeded, server port is: "
        try:
            line = self._log_queue.get(timeout=timeout)
            if len(line) == 0:
                # The process has exited.
                raise TestServerProcessError(
                    self.server
                    + " exited unexpectedly "
                    + "with code "
                    + str(self.__server_process.poll())
                    + "!"
                )

            if line.startswith(expected_msg):
                # int() tolerates the trailing newline of the line.
                self.port = int(line[len(expected_msg) :])
            else:
                # An exception or some other message is printed from the server.
                self.__logged_messages.append(line)
                # Check if more lines are logged.
                self.flush_log()
                raise TestServerProcessError(
                    self.server
                    + " did not print port "
                    + "message as first stdout line as expected!"
                )
        except queue.Empty as e:
            raise TimeoutError(
                "Failure during " + self.server + " startup!"
            ) from e

    def _kill_server_process(self) -> None:
        """Kills the server subprocess if it's running.

        Does nothing if the subprocess was never created.
        """
        if self.__server_process is None:
            return

        if self.is_process_running():
            self.__logger.info(
                "Server process %d terminated", self.__server_process.pid
            )
            self.__server_process.kill()
            self.__server_process.wait()

    def flush_log(self) -> None:
        """Flushes the log lines from the logging queue.

        Drains every line currently in the queue and logs all collected
        lines as a single INFO message. If no queue was created yet, there is
        nothing to drain, but lines already collected are still logged.
        """
        if self._log_queue is not None:
            while True:
                # Get lines from log_queue
                try:
                    line = self._log_queue.get(block=False)
                except queue.Empty:
                    # No more lines are logged in the queue.
                    break
                # Skip the empty end-of-stream marker.
                if len(line) > 0:
                    self.__logged_messages.append(line)

        if len(self.__logged_messages) > 0:
            title = "Test server (" + self.server + ") output:\n"
            message = [title, *self.__logged_messages]
            self.__logger.info("| ".join(message))
            self.__logged_messages = []

    def clean(self) -> None:
        """
        Logs any server output not yet flushed, then kills the subprocess.
        Safe to call on a partially started server.
        """
        # If there is anything logged, flush it before closing the resources.
        self.flush_log()

        self._kill_server_process()

    def is_process_running(self) -> bool:
        """Return True while the server subprocess has not exited."""
        assert isinstance(self.__server_process, subprocess.Popen)
        return self.__server_process.poll() is None
|