1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
from io import BytesIO
import pytest
from scrapli.driver.base.async_driver import AsyncDriver
from scrapli.driver.core import AsyncIOSXRDriver
from scrapli.exceptions import ScrapliValueError
def test_sync_transport_exception():
    """Check that AsyncDriver refuses a synchronous transport.

    "system" is a sync transport, so building an AsyncDriver on top of it must
    raise ScrapliValueError.
    """
    sync_transport = "system"
    with pytest.raises(ScrapliValueError):
        AsyncDriver(host="localhost", transport=sync_transport)
@pytest.mark.parametrize(
    "test_data",
    (True, False),
    ids=(
        "on_init",
        "no_on_init",
    ),
)
def test_on_init(test_data):
    """Check the on_init hook runs during driver construction only when one is given.

    test_data: True to hand the driver an on_init callable, False to hand it None.
    """
    expect_called = bool(test_data)
    hook_ran = False

    def _record_on_init(conn):
        nonlocal hook_ran
        hook_ran = True

    hook = _record_on_init if expect_called else None
    AsyncDriver(host="localhost", transport="asynctelnet", on_init=hook)

    # the hook must have fired exactly when it was supplied
    assert hook_ran is expect_called
async def test_context_manager(monkeypatch):
    """Check that ``async with`` both opens and closes the driver.

    Transport open and telnet channel authentication are patched out so no real
    connection is attempted. Entering the context must reach telnet auth (the open
    path). Leaving it must run the driver's close path, which is observed through
    the on_close hook.
    """
    channel_telnet_auth_called = False
    on_close_called = False

    async def _open(cls):
        pass

    async def _channel_telnet_auth(cls, auth_username, auth_password):
        nonlocal channel_telnet_auth_called
        channel_telnet_auth_called = True

    async def _on_close(cls):
        nonlocal on_close_called
        on_close_called = True

    monkeypatch.setattr(
        "scrapli.transport.plugins.asynctelnet.transport.AsynctelnetTransport.open", _open
    )
    monkeypatch.setattr(
        "scrapli.channel.async_channel.AsyncChannel.channel_authenticate_telnet",
        _channel_telnet_auth,
    )

    async with AsyncDriver(host="localhost", transport="asynctelnet", on_close=_on_close):
        # inside the block the connection has been opened but not yet closed
        assert channel_telnet_auth_called is True
        assert on_close_called is False

    # exiting the block must have triggered close()
    assert channel_telnet_auth_called is True
    assert on_close_called is True
async def test_open_telnet_channel_auth(monkeypatch, async_driver):
    """Check that open() invokes both telnet channel auth and the on_open hook.

    The transport's open and the channel's telnet authentication are swapped for
    recording stubs so the test never touches the network.
    """
    hook_ran = False
    auth_ran = False

    async def _fake_on_open(conn):
        nonlocal hook_ran
        hook_ran = True

    async def _fake_transport_open(transport):
        pass

    async def _fake_telnet_auth(channel, auth_username, auth_password):
        nonlocal auth_ran
        auth_ran = True

    monkeypatch.setattr(
        "scrapli.transport.plugins.asynctelnet.transport.AsynctelnetTransport.open",
        _fake_transport_open,
    )
    monkeypatch.setattr(
        "scrapli.channel.async_channel.AsyncChannel.channel_authenticate_telnet",
        _fake_telnet_auth,
    )
    async_driver.on_open = _fake_on_open

    await async_driver.open()

    assert hook_ran is True
    assert auth_ran is True
async def test_close(async_driver):
    """
    Check the unit-testable parts of driver close.

    close() must invoke the on_close hook and close the channel log.
    """
    hook_ran = False

    async def _mark_closed(conn):
        nonlocal hook_ran
        hook_ran = True

    channel = async_driver.channel
    async_driver.on_close = _mark_closed
    channel.channel_log = BytesIO()
    assert channel.channel_log.closed is False

    # no transport has been opened, so close() needs no mocking here
    await async_driver.close()

    assert hook_ran is True
    assert channel.channel_log.closed is True
async def test_commandeer(async_driver):
    """
    Check that commandeer hands an existing connection's state to a new driver.

    With execute_on_open=True the new driver's on_open hook must run. The new
    driver must then share the original's transport, loggers and channel log.
    """
    hook_ran = False

    async def _record_on_open(conn):
        nonlocal hook_ran
        hook_ran = True

    original_log = BytesIO()
    async_driver.channel.channel_log = original_log

    new_conn = AsyncIOSXRDriver(host="tacocat", on_open=_record_on_open, transport="asyncssh")
    await new_conn.commandeer(async_driver, execute_on_open=True)

    old = async_driver
    assert hook_ran is True
    # the new driver must reuse the old driver's objects, not copies of them
    assert new_conn.transport is old.transport
    assert new_conn.channel.transport is old.transport
    assert new_conn.logger is old.logger
    assert new_conn.transport.logger is old.transport.logger
    assert new_conn.channel.logger is old.channel.logger
    assert new_conn.channel.channel_log is original_log
|