File: test_unix_socket_connection.py

package info (click to toggle)
python-gvm 26.5.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 5,132 kB
  • sloc: python: 44,662; makefile: 18
file content (114 lines) | stat: -rw-r--r-- 3,888 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# SPDX-FileCopyrightText: 2018-2024 Greenbone AG
#
# SPDX-License-Identifier: GPL-3.0-or-later
#

import socketserver
import tempfile
import threading
import unittest
import uuid
from pathlib import Path
from unittest.mock import patch

from gvm.connections import (
    DEFAULT_TIMEOUT,
    DEFAULT_UNIX_SOCKET_PATH,
    GvmConnection,
    UnixSocketConnection,
)
from gvm.errors import GvmError


class DummyRequestHandler(socketserver.BaseRequestHandler):
    """Answer every client with one fixed, successful GMP response."""

    def handle(self):
        # The incoming request is never read; the reply is always the same.
        self.request.sendall(b'<gmp_response status="200" status_text="OK"/>')


class ThreadedUnixStreamServer(
    socketserver.ThreadingMixIn, socketserver.UnixStreamServer
):
    """Unix-domain stream server handling each request in its own thread."""


class UnixSocketConnectionTestCase(unittest.TestCase):
    """Tests for ``UnixSocketConnection`` against a throwaway local server.

    Each test gets its own uniquely named Unix socket in the system temp
    directory. A ``ThreadedUnixStreamServer`` running ``DummyRequestHandler``
    serves that socket from a daemon thread.
    """

    # pylint: disable=protected-access, invalid-name

    # The canned reply that DummyRequestHandler sends on every connection.
    EXPECTED_RESPONSE = b'<gmp_response status="200" status_text="OK"/>'

    def setUp(self):
        # A fresh, uniquely named socket per test avoids clashes between
        # tests and with concurrent test runs.
        socket_file = f"{uuid.uuid4()}.sock"
        self.socket_path = Path(tempfile.gettempdir()) / socket_file
        self.socket_name = str(self.socket_path)
        self.socket_server = ThreadedUnixStreamServer(
            self.socket_name, DummyRequestHandler
        )
        # Daemon thread so a hung server cannot keep the interpreter alive.
        self.server_thread = threading.Thread(
            target=self.socket_server.serve_forever, daemon=True
        )
        self.server_thread.start()

    def tearDown(self):
        # shutdown() stops the serve_forever() loop and waits for it to
        # exit. Only after that is it safe to close the listening socket.
        # Closing first would pull the socket out from under the running
        # loop.
        self.socket_server.shutdown()
        self.socket_server.server_close()
        self.server_thread.join(60.0)
        # The socket file is not removed by server_close(), so delete it here.
        self.socket_path.unlink(missing_ok=True)

    def _connect(self) -> UnixSocketConnection:
        """Return a connection to the test server that is closed on cleanup."""
        connection = UnixSocketConnection(
            path=self.socket_name, timeout=DEFAULT_TIMEOUT
        )
        connection.connect()
        # Cleanups run even when an assertion fails, so the client socket is
        # never leaked.
        self.addCleanup(connection.disconnect)
        return connection

    def test_unix_socket_connection_connect_read(self):
        connection = self._connect()
        self.assertEqual(connection.read(), self.EXPECTED_RESPONSE)

    def test_unix_socket_connection_connect_send_bytes_read(self):
        connection = self._connect()
        connection.send(b"<gmp/>")
        self.assertEqual(connection.read(), self.EXPECTED_RESPONSE)

    def test_unix_socket_connect_file_not_found(self):
        connection = UnixSocketConnection(path="foo", timeout=DEFAULT_TIMEOUT)
        # assertRaises(msg=...) only customises the failure text when nothing
        # is raised. The exception message has to be checked explicitly.
        with self.assertRaises(GvmError) as cm:
            connection.connect()
        connection.disconnect()
        self.assertIn("Socket foo does not exist", str(cm.exception))

    def test_unix_socket_connect_could_not_connect(self):
        connection = UnixSocketConnection(
            path=self.socket_name, timeout=DEFAULT_TIMEOUT
        )
        # Patch the class attribute so any socket's connect() fails.
        with patch("socket.socket.connect", side_effect=ConnectionError):
            with self.assertRaises(GvmError) as cm:
                connection.connect()
            connection.disconnect()
        self.assertIn(
            f"Could not connect to socket {self.socket_name}",
            str(cm.exception),
        )

    def test_unix_socket_send_unconnected_socket(self):
        connection = UnixSocketConnection(
            path=self.socket_name, timeout=DEFAULT_TIMEOUT
        )
        with self.assertRaises(GvmError) as cm:
            connection.send("<gmp/>")
        self.assertIn("Socket is not connected", str(cm.exception))

    def test_init_no_args(self):
        connection = UnixSocketConnection()
        self.check_default_values(connection)

    def test_init_with_none(self):
        connection = UnixSocketConnection(path=None, timeout=None)
        self.check_default_values(connection)

    def check_default_values(self, connection: UnixSocketConnection):
        """Assert that *connection* uses the library's default path/timeout."""
        self.assertEqual(connection._timeout, DEFAULT_TIMEOUT)
        self.assertEqual(connection.path, DEFAULT_UNIX_SOCKET_PATH)

    def test_is_gvm_connection(self):
        connection = UnixSocketConnection()
        self.assertIsInstance(connection, GvmConnection)