1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
|
import os
import sys
from io import StringIO
from tempfile import TemporaryFile
from testfixtures.comparison import compare
class OutputCapture:
    """
    A context manager for capturing output to the
    :any:`sys.stdout` and :any:`sys.stderr` streams.

    :param separate: If ``True``, ``stdout`` and ``stderr`` will be captured
                     separately and their expected values must be passed to
                     :meth:`~OutputCapture.compare`.

    :param fd: If ``True``, the underlying file descriptors will be captured,
               rather than just the attributes on :mod:`sys`. This allows
               you to capture things like subprocesses that write directly
               to the file descriptors, but is more invasive, so only use it
               when you need it.

    :param strip_whitespace:
        When ``True``, which is the default, leading and trailing whitespace
        is trimmed from both the expected and actual values when comparing.

    .. note:: If ``separate`` is passed as ``True``,
              :attr:`OutputCapture.captured` will be an empty string.
    """

    # What to restore when capture is disabled: the original stream objects,
    # or, in ``fd`` mode, duplicates of the original file descriptors.
    # ``None`` means there is currently nothing saved to restore.
    original_stdout = None
    original_stderr = None

    def __init__(self, separate: bool = False, fd: bool = False, strip_whitespace: bool = True):
        self.separate = separate
        self.fd = fd
        self.strip_whitespace = strip_whitespace

    def __enter__(self):
        """Create fresh capture buffers, enable capture and return ``self``."""
        if self.fd:
            # Real files are required so that they have descriptors that
            # can be passed to os.dup2().
            self.output = TemporaryFile()
            self.stdout = TemporaryFile()
            self.stderr = TemporaryFile()
        else:
            self.output = StringIO()
            self.stdout = StringIO()
            self.stderr = StringIO()
        self.enable()
        return self

    def __exit__(self, *args):
        """Disable capture; any exception from the ``with`` body propagates."""
        self.disable()

    def disable(self):
        "Disable the output capture if it is enabled."
        if self.original_stdout is None:
            # Nothing has been saved, so there is nothing to restore. Without
            # this guard the streams would be replaced with ``None`` (or
            # ``os.dup2`` would be handed ``None`` in ``fd`` mode).
            return
        if self.fd:
            for original, current in (
                (self.original_stdout, sys.stdout),
                (self.original_stderr, sys.stderr),
            ):
                # Push any text still buffered in the Python stream into the
                # capture file before the descriptor is pointed elsewhere.
                current.flush()
                os.dup2(original, current.fileno())
                os.close(original)
            # The saved descriptors are closed now; forget them so that a
            # later enable() saves fresh ones instead of reusing dead fds.
            self.original_stdout = None
            self.original_stderr = None
        else:
            sys.stdout = self.original_stdout
            sys.stderr = self.original_stderr

    def enable(self):
        "Enable the output capture if it is disabled."
        if self.fd:
            # Anything already buffered belongs to the current destination,
            # so write it out before the descriptors are redirected.
            sys.stdout.flush()
            sys.stderr.flush()
        if self.original_stdout is None:
            if self.fd:
                self.original_stdout = os.dup(sys.stdout.fileno())
                self.original_stderr = os.dup(sys.stderr.fileno())
            else:
                self.original_stdout = sys.stdout
                self.original_stderr = sys.stderr
        if self.separate:
            if self.fd:
                os.dup2(self.stdout.fileno(), sys.stdout.fileno())
                os.dup2(self.stderr.fileno(), sys.stderr.fileno())
            else:
                sys.stdout = self.stdout
                sys.stderr = self.stderr
        else:
            if self.fd:
                os.dup2(self.output.fileno(), sys.stdout.fileno())
                os.dup2(self.output.fileno(), sys.stderr.fileno())
            else:
                sys.stdout = sys.stderr = self.output

    def _read(self, stream):
        """Return everything written to one of the capture buffers as a string."""
        if self.fd:
            # Temporary files are binary: rewind, read and decode.
            stream.seek(0)
            return stream.read().decode()
        else:
            return stream.getvalue()

    @property
    def captured(self) -> str:
        "A property containing any output that has been captured so far."
        return self._read(self.output)

    def compare(self, expected: str = '', stdout: str = '', stderr: str = ''):
        """
        Compare the captured output to that expected. If the output is
        not the same, an :class:`AssertionError` will be raised.

        :param expected: A string containing the expected combined output
                         of ``stdout`` and ``stderr``.

        :param stdout: A string containing the expected output to ``stdout``.

        :param stderr: A string containing the expected output to ``stderr``.
        """
        # Only the streams that differ are collected, so the failure message
        # is limited to what actually went wrong.
        expected_mapping = {}
        actual_mapping = {}
        for prefix, _expected, captured in (
            ('captured', expected, self.captured),
            ('stdout', stdout, self._read(self.stdout)),
            ('stderr', stderr, self._read(self.stderr)),
        ):
            if self.strip_whitespace:
                _expected = _expected.strip()
                captured = captured.strip()
            if _expected != captured:
                expected_mapping[prefix] = _expected
                actual_mapping[prefix] = captured
        if len(expected_mapping) == 1:
            # A single mismatch is compared as plain strings rather than as
            # one-item mappings.
            compare(expected=next(iter(expected_mapping.values())),
                    actual=next(iter(actual_mapping.values())))
        compare(expected=expected_mapping, actual=actual_mapping)
|