1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
import io
import sys
import socket
import pytest
@pytest.fixture(scope='module')
def ssh_audit():
    """Return the ``ssh_audit.ssh_audit`` module (module-scoped fixture).

    The import happens inside the fixture body, i.e. when the fixture is
    first requested, not when this file itself is imported.
    """
    import ssh_audit.ssh_audit as audit_module
    return audit_module
# pylint: disable=attribute-defined-outside-init
class _OutputSpy(list):
    """Capture everything written to ``sys.stdout`` between two calls.

    Call :meth:`begin` to start capturing and :meth:`flush` to stop; the
    latter hands back the captured text split into lines.
    """

    def begin(self):
        """Swap ``sys.stdout`` for a fresh in-memory text buffer."""
        buffer = io.StringIO()
        # Remember the stream being replaced so flush() can put it back.
        self.__out, self.__old_stdout = buffer, sys.stdout
        sys.stdout = buffer

    def flush(self):
        """Stop capturing and return the captured output as a list of lines.

        ``sys.stdout`` is reset to the stream that was active when
        :meth:`begin` ran, and the internal buffer reference is dropped.
        """
        captured = self.__out.getvalue()
        sys.stdout, self.__out = self.__old_stdout, None
        return captured.splitlines()
@pytest.fixture(scope='module')
def output_spy():
    """Return a new ``_OutputSpy`` instance (module-scoped fixture)."""
    spy = _OutputSpy()
    return spy
class _VirtualGlobalSocket:
    """Stand-ins for the module-level callables of ``socket``.

    Every method mirrors the name of a ``socket`` module function but
    answers from the wrapped virtual socket or from canned address data
    stored in ``addrinfodata`` (keyed as ``'<host>#<port>'``).
    """

    def __init__(self, vsocket):
        self.addrinfodata = {}
        self.vsocket = vsocket

    # pylint: disable=unused-argument
    def create_connection(self, address, timeout=0, source_address=None):
        """Connect the wrapped virtual socket to *address* and return the result.

        ``timeout`` and ``source_address`` are accepted but ignored.
        """
        # pylint: disable=protected-access
        connected = self.vsocket._connect(address, True)
        return connected

    # pylint: disable=unused-argument
    def socket(self,
               family=socket.AF_INET,
               socktype=socket.SOCK_STREAM,
               proto=0,
               fileno=None):
        """Return the wrapped virtual socket; all arguments are ignored."""
        return self.vsocket

    def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0):
        """Resolve *host*/*port* from canned data, with a built-in localhost.

        Lookup order:

        1. An entry in ``addrinfodata`` under ``'<host>#<port>'`` wins; if
           it is an ``Exception`` instance it is raised, otherwise it is
           returned unchanged.
        2. ``'localhost'`` yields loopback entries for IPv4 and/or IPv6,
           filtered by *family* (``0`` means both).
        3. Any other host yields an empty list.

        ``socktype``, ``proto`` and ``flags`` are accepted but ignored.
        """
        lookup_key = '{}#{}'.format(host, port)
        if lookup_key in self.addrinfodata:
            canned = self.addrinfodata[lookup_key]
            if isinstance(canned, Exception):
                raise canned
            return canned
        if host != 'localhost':
            return []
        loopbacks = (
            (socket.AF_INET, '127.0.0.1'),
            (socket.AF_INET6, '::1'),
        )
        # The literal 1 and 6 fill the socktype and proto slots of each entry.
        return [
            (af, 1, 6, '', (ip, port))
            for af, ip in loopbacks
            if family in (0, af)
        ]
class _VirtualSocket:
    """In-memory stand-in for a socket object, scripted through attributes.

    Tests drive the fake by manipulating these public attributes:

    * ``rdata``  -- queue consumed by :meth:`recv`, oldest item first; an
      ``Exception`` instance in the queue is raised instead of returned.
    * ``sdata``  -- every payload handed to :meth:`send`, in call order.
    * ``errors`` -- maps a method name (``'connect'`` or ``'send'``) to an
      exception raised when that method is called.
    * ``gsock``  -- companion ``_VirtualGlobalSocket`` bound to this socket.
    """

    def __init__(self):
        self.sock_address = ('127.0.0.1', 0)
        self.peer_address = None
        self._connected = False
        self.timeout = -1.0
        self.rdata = []
        self.sdata = []
        self.errors = {}
        self.blocking = False
        self.gsock = _VirtualGlobalSocket(self)

    def _check_err(self, method):
        """Raise the exception registered in ``errors`` for *method*, if any."""
        method_error = self.errors.get(method)
        if method_error:
            raise method_error

    def connect(self, address):
        """Connect to *address*; returns ``None`` like ``socket.connect``."""
        return self._connect(address, False)

    def connect_ex(self, address):
        """Connect to *address* and return ``0`` on success.

        The real ``socket.connect_ex`` reports success as ``0``; returning
        that (instead of ``None``) keeps ``== 0`` checks working. An error
        injected via ``errors['connect']`` still propagates as an exception.
        """
        self.connect(address)
        return 0

    def _connect(self, address, ret=True):
        """Record the peer and return ``self`` if *ret* is true, else ``None``.

        The connection state is updated *before* an injected ``'connect'``
        error is raised, so the socket reports itself as connected even
        when this call fails.
        """
        self.peer_address = address
        self._connected = True
        self._check_err('connect')
        return self if ret else None

    def setblocking(self, r: bool):
        """Store the requested blocking mode in ``blocking``."""
        self.blocking = r

    def settimeout(self, timeout):
        """Store *timeout* for later retrieval via :meth:`gettimeout`."""
        self.timeout = timeout

    def gettimeout(self):
        """Return the most recently stored timeout."""
        return self.timeout

    def getpeername(self):
        """Return the peer address, or raise ``OSError`` if not connected."""
        if self.peer_address is None or not self._connected:
            raise OSError(57, 'Socket is not connected')
        return self.peer_address

    def getsockname(self):
        """Return the local address of this virtual socket."""
        return self.sock_address

    def bind(self, address):
        """Adopt *address* as the local address."""
        self.sock_address = address

    def listen(self, backlog):
        """Accept and ignore *backlog*; nothing needs to happen here."""

    def accept(self):
        """Return ``(conn, address)`` for a new, already-connected socket.

        The new socket shares this socket's local address and has the
        peer address ``('127.0.0.1', 0)``.
        """
        # pylint: disable=protected-access
        conn = _VirtualSocket()
        conn.sock_address = self.sock_address
        conn.peer_address = ('127.0.0.1', 0)
        conn._connected = True
        return conn, conn.peer_address

    def recv(self, bufsize, flags=0):
        """Pop and return the next queued item from ``rdata``.

        Returns ``b''`` when the queue is empty. A queued ``Exception``
        instance is raised instead of returned. *bufsize* and *flags* are
        ignored: each call yields one whole queued item.

        Raises:
            OSError: if the socket is not connected.
        """
        # pylint: disable=unused-argument
        if not self._connected:
            raise OSError(54, 'Connection reset by peer')
        if not self.rdata:
            return b''
        data = self.rdata.pop(0)
        if isinstance(data, Exception):
            raise data
        return data

    def send(self, data):
        """Record *data* in ``sdata`` and return the number of bytes accepted.

        Returning ``len(data)`` follows ``socket.send``, so callers that
        track progress through the return value work against this fake.

        Raises:
            OSError: if the socket is not connected.
            Exception: whatever is registered under ``errors['send']``.
        """
        if self.peer_address is None or not self._connected:
            raise OSError(32, 'Broken pipe')
        self._check_err('send')
        self.sdata.append(data)
        return len(data)
@pytest.fixture()
def virtual_socket(monkeypatch):
    """Return a fresh ``_VirtualSocket`` wired in place of the ``socket`` API.

    ``socket.create_connection``, ``socket.socket`` and
    ``socket.getaddrinfo`` are replaced, via ``monkeypatch``, with the
    same-named methods of the virtual socket's ``gsock`` helper.
    """
    vsocket = _VirtualSocket()
    for name in ('create_connection', 'socket', 'getaddrinfo'):
        monkeypatch.setattr(socket, name, getattr(vsocket.gsock, name))
    return vsocket
|