File: outputcapture.py

package info (click to toggle)
python-testfixtures 8.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 1,064 kB
  • sloc: python: 10,208; makefile: 76; sh: 9
file content (132 lines) | stat: -rw-r--r-- 4,793 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import sys
from io import StringIO
from tempfile import TemporaryFile

from testfixtures.comparison import compare


class OutputCapture:
    """
    A context manager for capturing output to the
    :any:`sys.stdout` and :any:`sys.stderr` streams.

    :param separate: If ``True``, ``stdout`` and ``stderr`` will be captured
                     separately and their expected values must be passed to
                     :meth:`~OutputCapture.compare`.

    :param fd: If ``True``, the underlying file descriptors will be captured,
               rather than just the attributes on :mod:`sys`. This allows
               you to capture things like subprocesses that write directly
               to the file descriptors, but is more invasive, so only use it
               when you need it.

    :param strip_whitespace:
        When ``True``, which is the default, leading and trailing whitespace
        is trimmed from both the expected and actual values when comparing.

    .. note:: If ``separate`` is passed as ``True``,
              :attr:`OutputCapture.captured` will be an empty string.
    """

    # With fd=False these hold the original sys.stdout / sys.stderr objects;
    # with fd=True they hold os.dup() copies of the original descriptors.
    original_stdout = None
    original_stderr = None
    # True while capture is active, so disable() can be a safe no-op when
    # capture was never enabled or has already been disabled.
    _enabled = False

    def __init__(self, separate: bool = False, fd: bool = False, strip_whitespace: bool = True):
        self.separate = separate
        self.fd = fd
        self.strip_whitespace = strip_whitespace

    def __enter__(self):
        # fd capture needs real files whose descriptors can be dup2()'d onto
        # the stdout/stderr descriptors; attribute capture only needs
        # in-memory text buffers.
        if self.fd:
            self.output = TemporaryFile()
            self.stdout = TemporaryFile()
            self.stderr = TemporaryFile()
        else:
            self.output = StringIO()
            self.stdout = StringIO()
            self.stderr = StringIO()
        self.enable()
        return self

    def __exit__(self, *args):
        self.disable()

    def disable(self) -> None:
        "Disable the output capture if it is enabled."
        if not self._enabled:
            # Nothing to restore: doing so would either assign None to
            # sys.stdout/sys.stderr or dup2() an already-closed descriptor.
            return
        self._enabled = False
        if self.fd:
            for original, current in (
                (self.original_stdout, sys.stdout),
                (self.original_stderr, sys.stderr),
            ):
                # Flush first so anything Python buffered while capturing is
                # written to the capture file, not the restored destination.
                current.flush()
                os.dup2(original, current.fileno())
                os.close(original)
            # The saved duplicates are closed now; forget them so a later
            # enable() takes fresh duplicates instead of reusing dead fds.
            self.original_stdout = self.original_stderr = None
        else:
            sys.stdout = self.original_stdout
            sys.stderr = self.original_stderr

    def enable(self) -> None:
        "Enable the output capture if it is disabled."
        if self.fd:
            # Flush first so output buffered before capture starts goes to
            # the original destination rather than into the capture file.
            sys.stdout.flush()
            sys.stderr.flush()
        if self.original_stdout is None:
            if self.fd:
                self.original_stdout = os.dup(sys.stdout.fileno())
                self.original_stderr = os.dup(sys.stderr.fileno())
            else:
                self.original_stdout = sys.stdout
                self.original_stderr = sys.stderr
        if self.separate:
            if self.fd:
                os.dup2(self.stdout.fileno(), sys.stdout.fileno())
                os.dup2(self.stderr.fileno(), sys.stderr.fileno())
            else:
                sys.stdout = self.stdout
                sys.stderr = self.stderr
        else:
            # Combined mode: both streams feed the single ``output`` buffer.
            if self.fd:
                os.dup2(self.output.fileno(), sys.stdout.fileno())
                os.dup2(self.output.fileno(), sys.stderr.fileno())
            else:
                sys.stdout = sys.stderr = self.output
        self._enabled = True

    def _read(self, stream) -> str:
        # Temporary files (fd mode) hold raw bytes written at the descriptor
        # level: rewind and decode with bytes.decode()'s default (UTF-8).
        # StringIO buffers already hold text.
        if self.fd:
            stream.seek(0)
            return stream.read().decode()
        else:
            return stream.getvalue()

    @property
    def captured(self) -> str:
        "A property containing any output that has been captured so far."
        return self._read(self.output)

    def compare(self, expected: str = '', stdout: str = '', stderr: str = ''):
        """
        Compare the captured output to that expected. If the output is
        not the same, an :class:`AssertionError` will be raised.

        :param expected: A string containing the expected combined output
                         of ``stdout`` and ``stderr``.

        :param stdout: A string containing the expected output to ``stdout``.

        :param stderr: A string containing the expected output to ``stderr``.
        """
        expected_mapping = {}
        actual_mapping = {}
        for prefix, _expected, captured in (
                ('captured', expected, self.captured),
                ('stdout', stdout, self._read(self.stdout)),
                ('stderr', stderr, self._read(self.stderr)),
        ):
            if self.strip_whitespace:
                _expected = _expected.strip()
                captured = captured.strip()
            if _expected != captured:
                expected_mapping[prefix] = _expected
                actual_mapping[prefix] = captured
        if len(expected_mapping) == 1:
            # A single mismatch reads better as a direct string comparison
            # than as a one-entry dict diff.
            compare(expected=tuple(expected_mapping.values())[0],
                    actual=tuple(actual_mapping.values())[0])
        compare(expected=expected_mapping, actual=actual_mapping)