File: utils.py

package info (click to toggle)
python-tuf 6.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,300 kB
  • sloc: python: 7,738; makefile: 8
file content (366 lines) | stat: -rw-r--r-- 12,591 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
# Copyright 2020, TUF contributors
# SPDX-License-Identifier: MIT OR Apache-2.0

"""
<Program Name>
  utils.py

<Started>
  August 3, 2020.

<Author>
  Jussi Kukkonen

<Copyright>
  See LICENSE-MIT OR LICENSE for licensing information.

<Purpose>
  Provide common utilities for TUF tests
"""

from __future__ import annotations

import argparse
import errno
import logging
import os
import queue
import socket
import subprocess
import sys
import threading
import time
import warnings
from contextlib import contextmanager
from typing import IO, TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
    import unittest
    from collections.abc import Iterator

# Logger for this helper module; named after the module itself.
logger = logging.getLogger(__name__)

# Absolute, symlink-resolved directory holding this file. Lets tests open
# sibling files in the tests dir regardless of the current working directory.
TESTS_DIR = os.path.split(os.path.realpath(__file__))[0]

# Loopback address used when forming URLs on the client side
TEST_HOST_ADDRESS = "127.0.0.1"


# Test runner decorator: one subtest per dataset entry.
def run_sub_tests_with_dataset(
    dataset: dict[str, Any],
) -> Callable[[Callable], Callable]:
    """Build a decorator that runs a test once per entry of ``dataset``.

    For every ``(case, data)`` pair, in dataset order, the decorated function
    is called as ``function(test_cls, data)`` inside
    ``test_cls.subTest(case=case)``, so each case is reported as its own
    subtest. Before the call, the case name with spaces replaced by
    underscores is stored on the test instance as ``case_name``.
    """

    def real_decorator(
        function: Callable[[unittest.TestCase, Any], None],
    ) -> Callable[[unittest.TestCase], None]:
        def wrapper(test_cls: unittest.TestCase) -> None:
            for case_title, case_data in dataset.items():
                with test_cls.subTest(case=case_title):
                    # Remember an identifier-friendly name for this case
                    test_cls.case_name = "_".join(case_title.split(" "))
                    function(test_cls, case_data)

        return wrapper

    return real_decorator


class TestServerProcessError(Exception):
    """Error reported when the test server subprocess misbehaves at startup.

    The message is kept in ``value`` (the base-class ``args`` stay empty),
    and ``str()`` of the exception is the repr of that message.
    """

    def __init__(self, value: str = "TestServerProcess") -> None:
        super().__init__()
        # Human-readable description of the failure
        self.value = value

    def __str__(self) -> str:
        return f"{self.value!r}"


@contextmanager
def ignore_deprecation_warnings(module: str) -> Iterator[None]:
    """Silence DeprecationWarning issued from ``module`` inside the block.

    ``module`` is handed to :func:`warnings.filterwarnings` as its module
    pattern. The filter is installed inside ``warnings.catch_warnings()``, so
    the previous filter state is restored when the block exits.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore", category=DeprecationWarning, module=module
        )
        yield


# Wait until host:port accepts connections.
# Raises TimeoutError if this does not happen within timeout seconds
# There are major differences between operating systems on how this works
# but the current blocking connect() seems to work fast on Linux and seems
# to at least work on Windows (ECONNREFUSED unfortunately has a 2 second
# timeout on Windows)
def wait_for_server(
    host: str, server: str, port: int, timeout: float = 10
) -> None:
    """Wait for server start until timeout is reached or server has started.

    Repeatedly opens an IPv4 (AF_INET) TCP connection to ``host:port`` until
    one succeeds or ``timeout`` seconds have elapsed.

    Args:
        host: Address of the server to connect to.
        server: Name of the server, used only in the error message.
        port: TCP port to connect to.
        timeout: Total number of seconds to keep trying. Fractional values
            are honoured.

    Raises:
        TimeoutError: If no connection succeeded within ``timeout`` seconds.
    """
    # monotonic() cannot jump if the wall clock is adjusted mid-wait.
    deadline = time.monotonic() + timeout
    remaining_timeout = float(timeout)
    while remaining_timeout > 0:
        try:
            # The context manager closes the socket on every path, and no
            # cleanup is attempted when socket() itself fails.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(remaining_timeout)
                sock.connect((host, port))
            return
        except socket.timeout:
            # Must precede OSError: socket.timeout is an OSError subclass.
            pass
        except OSError as e:
            # ECONNREFUSED is expected while the server is not started
            if e.errno != errno.ECONNREFUSED:
                logger.warning(
                    "Unexpected error while waiting for server: %s", str(e)
                )
            # Avoid pegging a core just for this
            time.sleep(0.01)
        # Keep fractional seconds: truncating to int would stop waiting up
        # to one second early (immediately for sub-second timeouts).
        remaining_timeout = deadline - time.monotonic()

    raise TimeoutError(
        "Could not connect to the " + server + " on port " + str(port) + "!"
    )


def configure_test_logging(argv: list[str]) -> None:
    """Configure the root log level from the ``-v`` flags found in ``argv``.

    Only ``-v``/``--verbose`` is interpreted; any other arguments (meant for
    the unittest argument parser) are ignored. Verbosity maps to levels as:
    0 or 1 -> ERROR, 2 -> WARNING, 3 -> INFO, 4 or more -> DEBUG.
    """
    arg_parser = argparse.ArgumentParser(add_help=False)
    arg_parser.add_argument("-v", "--verbose", action="count", default=0)
    known, _unknown = arg_parser.parse_known_args(argv)
    verbosity = known.verbose

    # A single '-v' still means ERROR so that unittest can print test
    # names without making the logs noisier.
    levels_by_verbosity = {2: logging.WARNING, 3: logging.INFO}
    if verbosity <= 1:
        chosen_level = logging.ERROR
    else:
        chosen_level = levels_by_verbosity.get(verbosity, logging.DEBUG)

    logging.basicConfig(level=chosen_level)


def cleanup_metadata_dir(path: str) -> None:
    """Delete the local metadata files in ``path``.

    Every ``*.json`` entry is removed, and an entry named ``root_history`` is
    cleaned recursively (its directory itself is kept). Any other entry is
    rejected with ValueError.
    """
    with os.scandir(path) as entries:
        for item in entries:
            if item.name == "root_history":
                cleanup_metadata_dir(item.path)
                continue
            if not item.name.endswith(".json"):
                raise ValueError(f"Unexpected local metadata file {item.path}")
            os.remove(item.path)


class TestServerProcess:
    """Helper class used to create a child process with the subprocess.Popen
    object and use a thread-safe Queue structure for logging.

    The child's stdout and stderr are merged and copied line by line by a
    daemon thread into a Queue. They are written to ``log`` when flush_log()
    runs.

    Args:
        log: Logger which will be used for logging.
        server: Path to the server to run in the subprocess.
        timeout: Time in seconds in which the server should start or otherwise
            TimeoutError error will be raised.
        popen_cwd: Current working directory used when instancing a
            subprocess.Popen object.
        extra_cmd_args: Additional arguments for the command which will start
            the subprocess. More precisely:
            "python -u <path_to_server> <extra_cmd_args>".
            If no list is provided, an empty list ("[]") will be assigned to it.

    Raises:
        TestServerProcessError: The server exited before reporting its port,
            or its first output line was not the expected port message.
        TimeoutError: The port message did not arrive, or the port did not
            accept connections, within ``timeout`` seconds.
    """

    def __init__(
        self,
        log: logging.Logger,
        server: str = os.path.join(TESTS_DIR, "simple_server.py"),
        timeout: int = 10,
        popen_cwd: str = ".",
        extra_cmd_args: list[str] | None = None,
    ):
        self.server = server
        self.__logger = log
        # Stores popped messages from the queue.
        self.__logged_messages: list[str] = []
        # Both stay None until _start_server() creates them; clean() checks
        # for None to know which resources actually need tearing down.
        self.__server_process: subprocess.Popen | None = None
        self._log_queue: queue.Queue | None = None
        # Set by _wait_for_port() once the server has reported its port.
        self.port = -1
        if extra_cmd_args is None:
            extra_cmd_args = []

        try:
            self._start_server(timeout, extra_cmd_args, popen_cwd)
            wait_for_server("localhost", self.server, self.port, timeout)
        except Exception:
            # Clean the resources and log the server errors if any exist,
            # then re-raise the original exception unchanged.
            self.clean()
            raise

    def _start_server(
        self, timeout: int, extra_cmd_args: list[str], popen_cwd: str
    ) -> None:
        """
        Start the server subprocess and a thread
        responsible to redirect stdout/stderr to the Queue.
        Waits for the port message maximum timeout seconds.
        """

        self._start_process(extra_cmd_args, popen_cwd)
        self._start_redirect_thread()

        self._wait_for_port(timeout)

        self.__logger.info("%s serving on %d", self.server, self.port)

    def _start_process(self, extra_cmd_args: list[str], popen_cwd: str) -> None:
        """Starts the process running the server."""

        # The "-u" option forces stdin, stdout and stderr to be unbuffered.
        command = [sys.executable, "-u", self.server, *extra_cmd_args]

        # Reusing one subprocess in multiple tests, but split up the logs
        # for each.
        self.__server_process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=popen_cwd,
        )

    def _start_redirect_thread(self) -> None:
        """Starts a thread redirecting the stdout/stderr to the Queue."""

        assert isinstance(self.__server_process, subprocess.Popen)
        # Run log_queue_worker() in a thread.
        # The thread will exit when the child process dies.
        self._log_queue = queue.Queue()
        log_thread = threading.Thread(
            target=self._log_queue_worker,
            args=(self.__server_process.stdout, self._log_queue),
        )

        # "daemon = True" means the thread won't interfere with the
        # process exit.
        log_thread.daemon = True
        log_thread.start()

    @staticmethod
    def _log_queue_worker(stream: IO, line_queue: queue.Queue) -> None:
        """
        Worker function to run in a separate thread.
        Reads lines from 'stream', decodes them as UTF-8 and puts them in
        'line_queue' (Queue is thread-safe). At end of stream an empty string
        is queued, the stream is closed and the worker returns.
        """

        while True:
            # readline() is a blocking operation.
            # Decode to push a string in the queue instead of 8-bit bytes.
            # Undecodable bytes are replaced rather than raising: an exception
            # here would kill the thread before the end-of-stream marker is
            # queued.
            log_line = stream.readline().decode("utf-8", errors="replace")
            line_queue.put(log_line)

            if len(log_line) == 0:
                # This is the end of the stream meaning the server process
                # has exited.
                stream.close()
                break

    def _wait_for_port(self, timeout: int) -> None:
        """
        Validates the first item from the Queue against the port message.
        If validation is successful, self.port is set.
        Raises TestServerProcessError if the process has exited or
        TimeoutError if no message was found within timeout seconds.
        """

        assert isinstance(self.__server_process, subprocess.Popen)
        assert isinstance(self._log_queue, queue.Queue)
        # We have hardcoded the message we expect on a successful server
        # startup. This message should be the first message sent by the server!
        expected_msg = "bind succeeded, server port is: "
        try:
            line = self._log_queue.get(timeout=timeout)
            if len(line) == 0:
                # The process has exited.
                raise TestServerProcessError(
                    self.server
                    + " exited unexpectedly "
                    + "with code "
                    + str(self.__server_process.poll())
                    + "!"
                )

            if line.startswith(expected_msg):
                # int() tolerates the trailing newline.
                self.port = int(line[len(expected_msg) :])
            else:
                # An exception or some other message is printed from the server.
                self.__logged_messages.append(line)
                # Check if more lines are logged.
                self.flush_log()
                raise TestServerProcessError(
                    self.server
                    + " did not print port "
                    + "message as first stdout line as expected!"
                )
        except queue.Empty as e:
            raise TimeoutError(
                "Failure during " + self.server + " startup!"
            ) from e

    def _kill_server_process(self) -> None:
        """Kills the server subprocess if it's running."""

        assert isinstance(self.__server_process, subprocess.Popen)
        if self.is_process_running():
            self.__logger.info(
                "Server process %d terminated", self.__server_process.pid
            )
            self.__server_process.kill()
            self.__server_process.wait()

    def flush_log(self) -> None:
        """Flushes the log lines from the logging queue."""

        assert isinstance(self._log_queue, queue.Queue)
        while True:
            # Get lines from log_queue
            try:
                line = self._log_queue.get(block=False)
                # Skip the empty end-of-stream marker.
                if len(line) > 0:
                    self.__logged_messages.append(line)
            except queue.Empty:
                # No more lines are logged in the queue.
                break

        if len(self.__logged_messages) > 0:
            title = "Test server (" + self.server + ") output:\n"
            message = [title, *self.__logged_messages]
            self.__logger.info("| ".join(message))
            self.__logged_messages = []

    def clean(self) -> None:
        """
        Flushes any server output not yet logged, then kills the subprocess.

        Safe to call when startup failed before the process or the log queue
        were created (e.g. when subprocess.Popen itself raised). The missing
        parts are skipped, so the assertions in flush_log() and
        _kill_server_process() do not replace the original startup error.
        """

        # If there is anything logged, flush it before closing the resources.
        if self._log_queue is not None:
            self.flush_log()

        if self.__server_process is not None:
            self._kill_server_process()

    def is_process_running(self) -> bool:
        """Return True while the server subprocess has not exited."""
        assert isinstance(self.__server_process, subprocess.Popen)
        return self.__server_process.poll() is None