File: test_websockets.py

package info (click to toggle)
python-paho-mqtt 2.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,248 kB
  • sloc: python: 8,765; sh: 48; makefile: 40
file content (142 lines) | stat: -rw-r--r-- 4,384 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import socket
from unittest.mock import Mock

import pytest
from paho.mqtt.client import WebsocketConnectionError, _WebsocketWrapper


class TestHeaders:
    """Check the HTTP upgrade request that _WebsocketWrapper sends.

    Each case hands the wrapper a mock socket whose canned reply carries a
    deliberately wrong Sec-WebSocket-Accept value, so the handshake fails.
    The test then inspects the request the wrapper wrote to the socket.
    """

    @pytest.mark.parametrize("wargs,expected_sent", [
        (
            # HTTPS on non-default port
            {
                "host": "testhost.com",
                "port": 1234,
                "path": "/mqtt",
                "extra_headers": None,
                "is_ssl": True,
            },
            [
                "GET /mqtt HTTP/1.1",
                "Host: testhost.com:1234",
                "Upgrade: websocket",
                "Connection: Upgrade",
                "Sec-Websocket-Protocol: mqtt",
                "Sec-Websocket-Version: 13",
                "Origin: https://testhost.com:1234",
            ],
        ),
        (
            # HTTPS on default port
            {
                "host": "testhost.com",
                "port": 443,
                "path": "/mqtt",
                "extra_headers": None,
                "is_ssl": True,
            },
            [
                "GET /mqtt HTTP/1.1",
                "Host: testhost.com",
                "Upgrade: websocket",
                "Connection: Upgrade",
                "Sec-Websocket-Protocol: mqtt",
                "Sec-Websocket-Version: 13",
                "Origin: https://testhost.com",
            ],
        ),
        (
            # HTTP on default port
            {
                "host": "testhost.com",
                "port": 80,
                "path": "/mqtt",
                "extra_headers": None,
                "is_ssl": False,
            },
            [
                "GET /mqtt HTTP/1.1",
                "Host: testhost.com",
                "Upgrade: websocket",
                "Connection: Upgrade",
                "Sec-Websocket-Protocol: mqtt",
                "Sec-Websocket-Version: 13",
                "Origin: http://testhost.com",
            ],
        ),
        (
            # HTTP on non-default port
            {
                "host": "testhost.com",
                "port": 443,  # Not the default *HTTP* port; the HTTPS port is used on purpose
                "path": "/mqtt",
                "extra_headers": None,
                "is_ssl": False,
            },
            [
                "GET /mqtt HTTP/1.1",
                "Host: testhost.com:443",
                "Upgrade: websocket",
                "Connection: Upgrade",
                "Sec-Websocket-Protocol: mqtt",
                "Sec-Websocket-Version: 13",
                "Origin: http://testhost.com:443",
            ],
        ),
    ])
    def test_normal_headers(self, wargs, expected_sent):
        """Normal headers as specified in RFC 6455.

        wargs: keyword arguments for _WebsocketWrapper; the mock socket is
            added to a copy of this dict below.
        expected_sent: request lines the wrapper must send. The first two
            (request line, then Host) must appear in that exact order; the
            rest may appear in any order.
        """

        response = [
            "HTTP/1.1 101 Switching Protocols",
            "Upgrade: websocket",
            "Connection: Upgrade",
            "Sec-WebSocket-Accept: badreturnvalue=",
            "Sec-WebSocket-Protocol: chat",
            "\r\n",
        ]

        # The full canned server reply: the CRLF-joined lines plus one more
        # trailing CRLF.
        raw_response = "\r\n".join(response).encode("utf8") + b"\r\n"
        remaining = iter(raw_response)

        def fakerecv(*args):
            # Behave like socket.recv(): hand out one byte per call and
            # return b"" once the reply is exhausted (peer closed) instead of
            # letting StopIteration escape into the code under test.
            value = next(remaining, None)
            return b"" if value is None else bytes([value])

        mocksock = Mock(
            spec_set=socket.socket,
            recv=fakerecv,
            send=Mock(),
        )

        # Do a copy to avoid modifying the parametrized input
        wargs_with_socket = dict(wargs)
        wargs_with_socket["socket"] = mocksock

        with pytest.raises(WebsocketConnectionError) as exc:
            _WebsocketWrapper(**wargs_with_socket)

        # The canned Sec-WebSocket-Accept value is not a valid hash of the
        # client key, so the handshake must be rejected with this error
        assert str(exc.value) == "WebSocket handshake error, invalid secret key"

        # The whole upgrade request is written with a single send() call
        assert mocksock.send.call_count == 1

        got_lines = mocksock.send.call_args[0][0].decode("utf8").splitlines()

        # The first line must be the GET request line.
        # The second line must be Host (RFC 9110 says it SHOULD be the first header).
        assert expected_sent[0] == got_lines[0]
        assert expected_sent[1] == got_lines[1]

        # The order of the remaining lines doesn't matter
        for line in expected_sent:
            assert line in got_lines