File: test_websocket_handshake.py

package info (click to toggle)
python-aiohttp 1.2.0-1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 2,288 kB
  • ctags: 4,380
  • sloc: python: 27,221; makefile: 236
file content (149 lines) | stat: -rw-r--r-- 4,944 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Tests for http/websocket.py"""

import base64
import hashlib
import os
from unittest import mock

import multidict
import pytest

from aiohttp import errors, protocol
from aiohttp._ws_impl import WS_KEY, do_handshake


@pytest.fixture()
def transport():
    """Return a fresh ``mock.Mock`` used as the transport argument of
    ``do_handshake`` in every test below."""
    fake_transport = mock.Mock()
    return fake_transport


@pytest.fixture()
def message():
    """Return a ``GET /path`` request message whose header set is an empty,
    mutable ``multidict.MultiDict`` that tests extend as needed."""
    method, path, http_version = 'GET', '/path', (1, 0)
    empty_headers = multidict.MultiDict()
    return protocol.RawRequestMessage(
        method, path, http_version, empty_headers, [], True, None)


def gen_ws_headers(protocols=''):
    """Build a well-formed client handshake header list.

    :param protocols: value for ``Sec-Websocket-Protocol``; the header is
        left out entirely when this is empty.
    :return: ``(headers, key)`` -- *headers* is a list of ``(name, value)``
        pairs and *key* is the base64 text of 16 random bytes, also used as
        the ``Sec-Websocket-Key`` value.
    """
    sec_key = base64.b64encode(os.urandom(16)).decode()
    headers = [
        ('Upgrade', 'websocket'),
        ('Connection', 'upgrade'),
        ('Sec-Websocket-Version', '13'),
        ('Sec-Websocket-Key', sec_key),
    ]
    if protocols:
        headers.append(('Sec-Websocket-Protocol', protocols))
    return headers, sec_key


def test_not_get(message, transport):
    """A non-GET method (here POST) is refused with HttpProcessingError."""
    bad_method = 'POST'
    headers = message.headers
    with pytest.raises(errors.HttpProcessingError):
        do_handshake(bad_method, headers, transport)


def test_no_upgrade(message, transport):
    """A request with no ``Upgrade`` header is rejected as a bad request."""
    # the ``message`` fixture starts with an empty header set
    method, headers = message.method, message.headers
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(method, headers, transport)


def test_no_connection(message, transport):
    """``Connection: keep-alive`` (no ``upgrade`` token) is a bad request,
    even with a correct ``Upgrade: websocket`` header."""
    request_hdrs = [
        ('Upgrade', 'websocket'),
        ('Connection', 'keep-alive'),
    ]
    message.headers.extend(request_hdrs)
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, message.headers, transport)


def test_protocol_version(message, transport):
    """Reject a handshake whose ``Sec-Websocket-Version`` is missing or '1'.

    Both cases start from a complete client header set, including a
    well-formed key, so the version header is the only thing wrong:

    * first with ``Sec-Websocket-Version`` left out entirely;
    * then with it set to ``'1'``.

    The version is *assigned* (``headers[name] = value``) instead of
    appending another copy of ``Upgrade``/``Connection`` next to it. The
    request therefore never carries duplicated headers, and the expected
    ``HttpBadRequest`` cannot be triggered by a missing key instead.
    """
    valid_hdrs, _ = gen_ws_headers()
    message.headers.extend(
        [(name, value) for name, value in valid_hdrs
         if name != 'Sec-Websocket-Version'])
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, message.headers, transport)

    message.headers['Sec-Websocket-Version'] = '1'
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, message.headers, transport)


def test_protocol_key(message, transport):
    """Reject a handshake whose ``Sec-Websocket-Key`` is missing or invalid.

    Each case uses otherwise valid ``Upgrade``/``Connection``/version
    headers:

    * no key at all;
    * a malformed key (``'123'``);
    * a valid base64 key that decodes to 2 bytes rather than the 16 random
      bytes ``gen_ws_headers`` uses for a good key.

    The key is *assigned* (``headers[name] = value``, which replaces any
    earlier value) rather than appended. A ``MultiDict`` lookup returns the
    first value stored for a name, so appending a second key would leave
    ``'123'`` in effect and silently re-test it instead of the short key.
    """
    message.headers.extend([('Upgrade', 'websocket'),
                            ('Connection', 'upgrade'),
                            ('Sec-Websocket-Version', '13')])
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, message.headers, transport)

    message.headers['Sec-Websocket-Key'] = '123'
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, message.headers, transport)

    short_key = base64.b64encode(os.urandom(2)).decode()
    message.headers['Sec-Websocket-Key'] = short_key
    with pytest.raises(errors.HttpBadRequest):
        do_handshake(message.method, message.headers, transport)


def test_handshake(message, transport):
    """A well-formed request yields status 101, no sub-protocol, and a
    ``Sec-Websocket-Accept`` equal to base64(sha1(key + WS_KEY))."""
    request_hdrs, client_key = gen_ws_headers()
    message.headers.extend(request_hdrs)

    result = do_handshake(message.method, message.headers, transport)
    status, resp_headers, _parser, _writer, chosen = result

    assert status == 101
    assert chosen is None

    digest = hashlib.sha1(client_key.encode() + WS_KEY).digest()
    expected_accept = base64.b64encode(digest).decode()
    assert dict(resp_headers)['Sec-Websocket-Accept'] == expected_accept


def test_handshake_protocol(message, transport):
    """A single protocol that both sides know is selected by do_handshake
    and echoed back in the ``Sec-Websocket-Protocol`` response header."""
    wanted = 'chat'
    request_hdrs, _ = gen_ws_headers(wanted)
    message.headers.extend(request_hdrs)

    _, resp_headers, _, _, chosen = do_handshake(
        message.method, message.headers, transport, protocols=[wanted])

    assert chosen == wanted
    # the negotiated protocol must also be sent back to the client
    echoed = dict(resp_headers)
    assert echoed['Sec-Websocket-Protocol'] == wanted


def test_handshake_protocol_agreement(message, transport):
    '''Tests if the right protocol is selected given multiple.

    The request header offers ``worse_proto,chat``, and the server side is
    given ``['best', 'chat', 'worse_proto']`` as ``protocols``. Both
    ``chat`` and ``worse_proto`` overlap. The expected pick is
    ``worse_proto``: first in the request header but last in the server
    list, i.e. the request's ordering decides.
    '''
    # sent by the client as its Sec-Websocket-Protocol request header
    offered_protos = 'worse_proto,chat'
    # protocols the server side is willing to speak
    supported_protos = ['best', 'chat', 'worse_proto']
    expected_proto = 'worse_proto'

    message.headers.extend(gen_ws_headers(offered_protos)[0])
    _, _, _, _, chosen = do_handshake(
        message.method, message.headers, transport,
        protocols=supported_protos)

    assert chosen == expected_proto


def test_handshake_protocol_unsupported(log, message, transport):
    """When the offered and supported protocols do not overlap, no protocol
    is selected and the mismatch is logged on ``aiohttp.websocket``.

    ``log`` is assumed to be a conftest fixture that captures records of
    the named logger -- TODO confirm.
    """
    supported = ['chat']
    message.headers.extend(gen_ws_headers('test')[0])

    with log('aiohttp.websocket') as ctx:
        _, _, _, _, chosen = do_handshake(
            message.method, message.headers, transport,
            protocols=supported)

        assert chosen is None

    last_msg = ctx.records[-1].msg
    assert (last_msg ==
            'Client protocols %r don’t overlap server-known ones %r')