File: test_base_sync_driver.py

package info (click to toggle)
python-scrapli 2023.7.30-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,536 kB
  • sloc: python: 14,459; makefile: 72
file content (163 lines) | stat: -rw-r--r-- 4,754 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from io import BytesIO

import pytest

from scrapli.driver.base.sync_driver import Driver
from scrapli.driver.core import IOSXRDriver
from scrapli.exceptions import ScrapliValueError


def test_async_transport_exception():
    """Ensure the sync Driver rejects an asyncio-only transport by raising ScrapliValueError"""
    async_only_transport = "asynctelnet"
    with pytest.raises(ScrapliValueError):
        Driver(transport=async_only_transport, host="localhost")


@pytest.mark.parametrize(
    "test_data",
    (True, False),
    ids=(
        "on_init",
        "no_on_init",
    ),
)
def test_on_init(test_data):
    """
    Assert on init method is executed at end of driver initialization (if provided)

    Args:
        test_data: True to pass an on_init callable to the driver, False to pass None

    """
    # named distinctly so it does not shadow this test function's own name
    use_on_init = test_data
    on_init_called = False

    def _on_init(conn):
        # receives the driver instance being initialized
        nonlocal on_init_called
        on_init_called = True

    Driver(host="localhost", on_init=_on_init if use_on_init else None)

    # the hook must have fired if and only if it was supplied
    assert on_init_called is use_on_init


def test_context_manager(monkeypatch):
    """
    Asserts context manager properly opens/closes

    Entering the ``with`` block must open the connection (observed via the patched ssh channel
    auth). Leaving the block must close the connection (observed via the patched system
    transport ``close``). The close must not happen before the block exits.
    """
    channel_ssh_auth_called = False
    transport_close_called = False

    def _channel_ssh_auth(cls, auth_password, auth_private_key_passphrase):
        nonlocal channel_ssh_auth_called
        channel_ssh_auth_called = True

    def _transport_close(cls):
        nonlocal transport_close_called
        transport_close_called = True

    # neither open nor close touches a real transport; both are stubbed out
    monkeypatch.setattr(
        "scrapli.transport.plugins.system.transport.SystemTransport.open", lambda x: None
    )
    monkeypatch.setattr(
        "scrapli.transport.plugins.system.transport.SystemTransport.close", _transport_close
    )
    monkeypatch.setattr(
        "scrapli.channel.sync_channel.Channel.channel_authenticate_ssh", _channel_ssh_auth
    )

    with Driver(host="localhost"):
        # inside the block the connection is open and has not been closed yet
        assert channel_ssh_auth_called is True
        assert transport_close_called is False

    # exiting the block must have closed the connection
    assert transport_close_called is True


def test_open_ssh_channel_auth(monkeypatch, sync_driver):
    """
    Open the sync driver with its system transport and ssh auth patched out

    Verifies that opening the connection invokes both the ssh channel authentication and the
    user-supplied on_open hook.
    """
    invoked = set()

    def _fake_ssh_auth(chan, auth_password, auth_private_key_passphrase):
        invoked.add("ssh_auth")

    def _fake_on_open(conn):
        invoked.add("on_open")

    # prevent any real connection attempt from the system transport
    monkeypatch.setattr(
        "scrapli.transport.plugins.system.transport.SystemTransport.open", lambda x: None
    )
    # record the auth request instead of performing it
    monkeypatch.setattr(
        "scrapli.channel.sync_channel.Channel.channel_authenticate_ssh", _fake_ssh_auth
    )
    sync_driver.on_open = _fake_on_open

    sync_driver.open()

    assert "on_open" in invoked
    assert "ssh_auth" in invoked


def test_open_telnet_channel_auth(monkeypatch, sync_driver_telnet):
    """
    Open the telnet-backed sync driver with its transport and telnet auth patched out

    Verifies that opening the connection invokes both the telnet channel authentication and the
    user-supplied on_open hook.
    """
    invoked = set()

    def _fake_telnet_auth(chan, auth_username, auth_password):
        invoked.add("telnet_auth")

    def _fake_on_open(conn):
        invoked.add("on_open")

    # prevent any real connection attempt from the telnet transport
    monkeypatch.setattr(
        "scrapli.transport.plugins.telnet.transport.TelnetTransport.open", lambda x: None
    )
    # record the auth request instead of performing it
    monkeypatch.setattr(
        "scrapli.channel.sync_channel.Channel.channel_authenticate_telnet", _fake_telnet_auth
    )
    sync_driver_telnet.on_open = _fake_on_open

    sync_driver_telnet.open()

    assert "on_open" in invoked
    assert "telnet_auth" in invoked


def test_close(sync_driver):
    """
    Close a sync driver whose transport was never opened

    Checks that the on_close hook is invoked and that the channel log handle is closed afterwards.
    """
    hook_calls = []
    sync_driver.on_close = lambda conn: hook_calls.append(conn)

    log_buffer = BytesIO()
    sync_driver.channel.channel_log = log_buffer
    assert not log_buffer.closed

    # nothing was opened, so close needs no mocking/patching here
    sync_driver.close()

    assert hook_calls
    assert sync_driver.channel.channel_log.closed is True


def test_commandeer(sync_driver):
    """
    Test commandeer works as expected

    Asserts that the commandeering connection adopts the original connection's transport,
    loggers and channel log. Also asserts that, with execute_on_open=True, its on_open hook runs
    exactly once and is passed the commandeering connection itself.
    """
    on_open_calls = []

    def _on_open(conn):
        # keep the argument so we can check which connection the hook ran against
        on_open_calls.append(conn)

    channel_log_dummy = BytesIO()
    sync_driver.channel.channel_log = channel_log_dummy

    new_conn = IOSXRDriver(host="tacocat", on_open=_on_open)
    new_conn.commandeer(sync_driver, execute_on_open=True)

    assert len(on_open_calls) == 1
    assert on_open_calls[0] is new_conn
    assert new_conn.transport is sync_driver.transport
    assert new_conn.channel.transport is sync_driver.transport
    assert new_conn.logger is sync_driver.logger
    assert new_conn.transport.logger is sync_driver.transport.logger
    assert new_conn.channel.logger is sync_driver.channel.logger
    assert new_conn.channel.channel_log is channel_log_dummy