File: test_mqtt.py

package info (click to toggle)
ospd-openvas 22.9.1-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,648 kB
  • sloc: python: 14,197; xml: 1,913; makefile: 45; sh: 29
file content (102 lines) | stat: -rw-r--r-- 3,027 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2021-2023 Greenbone AG
#
# SPDX-License-Identifier: AGPL-3.0-or-later

import time
from datetime import datetime
from uuid import UUID

from unittest import TestCase, mock

from ospd_openvas.messages.result import ResultMessage
from ospd_openvas.messaging.mqtt import (
    MQTTDaemon,
    MQTTPublisher,
    MQTTSubscriber,
)


class MQTTPublisherTestCase(TestCase):
    """Tests for publishing messages through ``MQTTPublisher``."""

    def test_publish(self):
        """A published result message reaches the client as topic + JSON.

        The mocked client's ``publish`` must be called with the topic, the
        serialized message and ``qos=1``.
        """
        mqtt_client = mock.MagicMock()

        result = ResultMessage(
            created=datetime.fromtimestamp(1628512774),
            message_id=UUID('63026767-029d-417e-9148-77f4da49f49a'),
            group_id=UUID('866350e8-1492-497e-b12b-c079287d51dd'),
            scan_id='scan_1',
            host_ip='1.1.1.1',
            host_name='foo',
            oid='1.2.3.4.5',
            value='A Vulnerability has been found',
            port='42',
            uri='file://foo/bar',
        )
        # Exact payload expected on the wire, split per JSON key for
        # readability.
        expected_payload = (
            '{"message_id": "63026767-029d-417e-9148-77f4da49f49a", '
            '"message_type": "result.scan", '
            '"group_id": "866350e8-1492-497e-b12b-c079287d51dd", '
            '"created": 1628512774.0, '
            '"scan_id": "scan_1", '
            '"host_ip": "1.1.1.1", '
            '"host_name": "foo", '
            '"oid": "1.2.3.4.5", '
            '"value": "A Vulnerability has been found", '
            '"port": "42", '
            '"uri": "file://foo/bar", '
            '"result_type": "ALARM"}'
        )

        MQTTPublisher(mqtt_client).publish(result)

        mqtt_client.publish.assert_called_with(
            'scanner/scan/info',
            expected_payload,
            qos=1,
        )


class MQTTSubscriberTestCase(TestCase):
    """Tests for subscribing to messages through ``MQTTSubscriber``."""

    def test_subscribe(self):
        """Subscribing must call the client's ``subscribe`` with ``qos=1``."""
        mqtt_client = mock.MagicMock()

        on_message = mock.MagicMock()
        # Give the mocked callback an explicit name.
        # NOTE(review): presumably subscribe() reads callback.__name__ --
        # confirm against MQTTSubscriber.
        on_message.__name__ = "callback_name"

        result = ResultMessage(
            scan_id='scan_1',
            host_ip='1.1.1.1',
            host_name='foo',
            oid='1.2.3.4.5',
            value='A Vulnerability has been found',
            uri='file://foo/bar',
        )

        MQTTSubscriber(mqtt_client).subscribe(result, on_message)

        mqtt_client.subscribe.assert_called_with('scanner/scan/info', qos=1)


class MQTTDaemonTestCase(TestCase):
    """Tests for constructing and running ``MQTTDaemon``."""

    def test_connect(self):
        """Constructing the daemon around a mocked client must not raise."""
        client = mock.MagicMock()

        MQTTDaemon(client)

    def test_run(self):
        """``run()`` must lead to ``connect()`` and ``loop_start()`` calls.

        The calls may be recorded with a delay, so the test polls the mocked
        client for up to ten seconds before asserting.
        """
        client = mock.MagicMock(side_effect=1)
        daemon = MQTTDaemon(client)
        # Use the monotonic clock so wall-clock adjustments cannot shorten or
        # stretch the timeout.
        deadline = time.monotonic() + 10

        daemon.run()
        # In some systems the spawn of the thread can take longer than expected.
        # Therefore, we wait until both expected calls were recorded or the
        # timeout expires. Waiting for just *any* call would be racy: the
        # assertions below could run after the first call but before the
        # second one.
        while (
            not (client.connect.called and client.loop_start.called)
            and time.monotonic() < deadline
        ):
            time.sleep(0.1)

        client.connect.assert_called_with()
        client.loop_start.assert_called_with()