File: test_protected_endpoints.py

package info (click to toggle)
flask-jwt-simple 0.0.3-14
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 288 kB
  • sloc: python: 772; makefile: 194; sh: 6
file content (229 lines) | stat: -rw-r--r-- 7,217 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import pytest
import datetime
from flask import Flask, jsonify, json

from flask_jwt_simple.utils import get_jwt_identity, create_jwt
from flask_jwt_simple import JWTManager, jwt_required, jwt_optional


# RSA private key used (with RSA_PUBLIC) by the RS256 configurations built in
# cartesian_product_configs(). Test-only key material: never reuse it outside
# this test suite.
RSA_PRIVATE = """
-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQDN+p9a9oMyqRzkae8yLdJcEK0O0WesH6JiMz+KDrpUwAoAM/KP
DnxFnROJDSBHyHEmPVn5x8GqV5lQ9+6l97jdEEcPo6wkshycM82fgcxOmvtAy4Uo
xq/AeplYqplhcUTGVuo4ZldOLmN8ksGmzhWpsOdT0bkYipHCn5sWZxd21QIDAQAB
AoGBAMJ0++KVXXEDZMpjFDWsOq898xNNMHG3/8ZzmWXN161RC1/7qt/RjhLuYtX9
NV9vZRrzyrDcHAKj5pMhLgUzpColKzvdG2vKCldUs2b0c8HEGmjsmpmgoI1Tdf9D
G1QK+q9pKHlbj/MLr4vZPX6xEwAFeqRKlzL30JPD+O6mOXs1AkEA8UDzfadH1Y+H
bcNN2COvCqzqJMwLNRMXHDmUsjHfR2gtzk6D5dDyEaL+O4FLiQCaNXGWWoDTy/HJ
Clh1Z0+KYwJBANqRtJ+RvdgHMq0Yd45MMyy0ODGr1B3PoRbUK8EdXpyUNMi1g3iJ
tXMbLywNkTfcEXZTlbbkVYwrEl6P2N1r42cCQQDb9UQLBEFSTRJE2RRYQ/CL4yt3
cTGmqkkfyr/v19ii2jEpMBzBo8eQnPL+fdvIhWwT3gQfb+WqxD9v10bzcmnRAkEA
mzTgeHd7wg3KdJRtQYTmyhXn2Y3VAJ5SG+3qbCW466NqoCQVCeFwEh75rmSr/Giv
lcDhDZCzFuf3EWNAcmuMfQJARsWfM6q7v2p6vkYLLJ7+VvIwookkr6wymF5Zgb9d
E6oTM2EeUPSyyrj5IdsU2JCNBH1m3JnUflz8p8/NYCoOZg==
-----END RSA PRIVATE KEY-----
"""


# Public key expected to be the counterpart of RSA_PRIVATE: the RS256
# round-trip tests only pass if tokens signed with RSA_PRIVATE verify here.
RSA_PUBLIC = """
-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAM36n1r2gzKpHORp7zIt0lwQrQ7RZ6wfomIzP4oOulTACgAz8o8OfEWd
E4kNIEfIcSY9WfnHwapXmVD37qX3uN0QRw+jrCSyHJwzzZ+BzE6a+0DLhSjGr8B6
mViqmWFxRMZW6jhmV04uY3ySwabOFamw51PRuRiKkcKfmxZnF3bVAgMBAAE=
-----END RSA PUBLIC KEY-----
"""

# Slightly modified copy of RSA_PUBLIC (a single base64 character on the first
# line differs), used to test invalid JWTs: test_with_bad_token swaps it in so
# that RS256 signature verification fails.
BAD_RSA_PUBLIC = """
-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAM36n1r2gzKpHORp8zIt0lwQrQ7RZ6wfomIzP4oOulTACgAz8o8OfEWd
E4kNIEfIcSY9WfnHwapXmVD37qX3uN0QRw+jrCSyHJwzzZ+BzE6a+0DLhSjGr8B6
mViqmWFxRMZW6jhmV04uY3ySwabOFamw51PRuRiKkcKfmxZnF3bVAgMBAAE=
-----END RSA PUBLIC KEY-----
"""


def cartesian_product_configs():
    """Return one Flask config dict per (identity claim, signing scheme) pair.

    For each identity claim name ('identity', then 'sub') an HS256 config
    using a shared secret is produced, immediately followed by an RS256
    config using the RSA_PRIVATE / RSA_PUBLIC key pair, giving four dicts.
    """
    def _hs256(claim):
        # Symmetric signing: one shared secret both signs and verifies.
        return {
            'JWT_SECRET_KEY': 'testing_secret_key',
            'JWT_ALGORITHM': 'HS256',
            'JWT_IDENTITY_CLAIM': claim,
        }

    def _rs256(claim):
        # Asymmetric signing: sign with the private key, verify with the public.
        return {
            'JWT_PUBLIC_KEY': RSA_PUBLIC,
            'JWT_PRIVATE_KEY': RSA_PRIVATE,
            'JWT_ALGORITHM': 'RS256',
            'JWT_IDENTITY_CLAIM': claim,
        }

    return [
        make_config(claim)
        for claim in ('identity', 'sub')
        for make_config in (_hs256, _rs256)
    ]


# Computed once at import time; each dict parametrizes one instance of the
# ``app`` fixture, so every test using ``app`` runs once per configuration.
CONFIG_COMBINATIONS = cartesian_product_configs()


@pytest.fixture(scope='function', params=CONFIG_COMBINATIONS)
def app(request):
    """Build a fresh Flask app configured from one CONFIG_COMBINATIONS entry.

    Routes registered:
      * ``POST /jwt``      -- returns ``{"jwt": <token>}`` for identity 'username'.
      * ``GET /protected`` -- behind ``jwt_required``; returns ``{"foo": "bar"}``.
      * ``GET /optional``  -- behind ``jwt_optional``; returns ``{"foo": "baz"}``
        when a JWT identity is available, otherwise ``{"foo": "bar"}``.
    """
    flask_app = Flask(__name__)
    flask_app.config.update(request.param)
    JWTManager(flask_app)

    # View function names double as Flask endpoint names, so keep them stable.
    @flask_app.route('/jwt', methods=['POST'])
    def fresh_access_jwt():
        token = create_jwt('username')
        return jsonify(jwt=token)

    @flask_app.route('/protected')
    @jwt_required
    def protected():
        return jsonify(foo='bar')

    @flask_app.route('/optional')
    @jwt_optional
    def optional():
        foo_value = 'baz' if get_jwt_identity() else 'bar'
        return jsonify(foo=foo_value)

    return flask_app


def _make_jwt_request(test_client, jwt, request_url):
    """Send a GET to ``request_url`` carrying ``jwt`` in the configured header.

    The header name and its prefix are read from the application's
    ``JWT_HEADER_NAME`` and ``JWT_HEADER_TYPE`` settings. The assembled
    value is stripped, so an empty prefix yields just the bare token.
    """
    settings = test_client.application.config
    name = settings['JWT_HEADER_NAME']
    prefix = settings['JWT_HEADER_TYPE']
    header_value = '{} {}'.format(prefix, jwt).strip()
    return test_client.get(
        request_url,
        content_type='application/json',
        headers={name: header_value},
    )


def _get_jwt(test_client):
    """Obtain a fresh token from the app's ``POST /jwt`` endpoint.

    Args:
        test_client: Flask test client for the app under test.

    Returns:
        The value of the ``jwt`` key in the JSON response body.

    Raises:
        AssertionError: if the endpoint does not answer 200 (the response
            body is included in the message) or the JSON lacks a ``jwt`` key.
    """
    response = test_client.post('/jwt')
    body = response.get_data(as_text=True)
    # Check the status before parsing: an error page is usually not JSON, and
    # a decode error would hide the real failure status.
    assert response.status_code == 200, body
    json_data = json.loads(body)
    assert 'jwt' in json_data
    return json_data['jwt']


def test_protected_without_jwt(app):
    """A jwt_required endpoint rejects a request with no auth header (401)."""
    resp = app.test_client().get('/protected')
    payload = json.loads(resp.get_data(as_text=True))

    assert resp.status_code == 401
    assert payload == {'msg': 'Missing Authorization Header'}


def test_protected_with_jwt(app):
    """A freshly issued token grants access to the jwt_required endpoint."""
    client = app.test_client()
    token = _get_jwt(client)
    resp = _make_jwt_request(client, token, '/protected')
    payload = json.loads(resp.get_data(as_text=True))

    assert resp.status_code == 200
    assert payload == {'foo': 'bar'}


def test_optional_without_jwt(app):
    """Without a token, the jwt_optional endpoint serves its anonymous reply."""
    resp = app.test_client().get('/optional')
    payload = json.loads(resp.get_data(as_text=True))

    assert resp.status_code == 200
    assert payload == {'foo': 'bar'}


def test_optional_with_jwt(app):
    """With a valid token, the jwt_optional endpoint sees an identity."""
    client = app.test_client()
    resp = _make_jwt_request(client, _get_jwt(client), '/optional')
    payload = json.loads(resp.get_data(as_text=True))

    assert resp.status_code == 200
    assert payload == {'foo': 'baz'}


@pytest.mark.parametrize("header_name", ['Authorization', 'Foo'])
@pytest.mark.parametrize("header_type", ['Bearer', 'JWT', ''])
def test_with_custom_headers(app, header_name, header_type):
    """Tokens are accepted under any configured header name and prefix."""
    app.config.update(JWT_HEADER_NAME=header_name, JWT_HEADER_TYPE=header_type)

    client = app.test_client()
    token = _get_jwt(client)
    resp = _make_jwt_request(client, token, '/protected')
    payload = json.loads(resp.get_data(as_text=True))

    assert resp.status_code == 200
    assert payload == {'foo': 'bar'}


@pytest.mark.parametrize("endpoint", [
    '/protected',
    '/optional',
])
@pytest.mark.parametrize("header_type", ['Foo', ''])
def test_with_bad_header(app, endpoint, header_type):
    """A header whose prefix does not match JWT_HEADER_TYPE is rejected.

    The client always sends ``Authorization: Bearer <jwt>`` while the app
    expects the prefix ``header_type`` ('Foo', or no prefix at all). The
    jwt_required endpoint must answer 422 with a message naming the expected
    format; the jwt_optional endpoint falls back to its anonymous response.
    """
    app.config['JWT_HEADER_TYPE'] = header_type

    test_client = app.test_client()
    jwt = _get_jwt(test_client)

    headers = {'Authorization': 'Bearer {}'.format(jwt)}
    response = test_client.get(
        endpoint,
        content_type='application/json',
        headers=headers
    )
    json_data = json.loads(response.get_data(as_text=True))

    # Pin the exact outcome for this (endpoint, header_type) pair. Checking
    # membership in the set of all possible outcomes would also pass if
    # /protected wrongly answered 200 to a malformed header, or if the 422
    # message named the wrong expected format.
    if endpoint == '/optional':
        # Returned when the jwt_optional test endpoint has no usable JWT.
        expected = (200, {'foo': 'bar'})
    else:
        expected_value = '{} <JWT>'.format(header_type) if header_type else '<JWT>'
        msg = "Bad Authorization header. Expected value '{}'".format(expected_value)
        expected = (422, {'msg': msg})
    assert (response.status_code, json_data) == expected


@pytest.mark.parametrize("endpoint", [
    '/protected',
    '/optional',
])
def test_with_bad_token(app, endpoint):
    """A token signed under keys the app no longer trusts fails with 422."""
    client = app.test_client()
    token = _get_jwt(client)

    # Replace both verification keys after issuing the token, so its
    # signature no longer verifies under either the HS256 or RS256 config.
    app.config.update({
        'JWT_SECRET_KEY': 'something_different',
        'JWT_PUBLIC_KEY': BAD_RSA_PUBLIC,
    })

    resp = _make_jwt_request(client, token, endpoint)
    payload = json.loads(resp.get_data(as_text=True))

    assert payload == {'msg': 'Signature verification failed'}
    assert resp.status_code == 422


@pytest.mark.parametrize("endpoint", [
    '/protected',
    '/optional',
])
def test_expired_token(app, endpoint):
    """A token that is already past its expiry is rejected with 401."""
    # A negative lifetime makes every token issued from now on already expired.
    app.config['JWT_EXPIRES'] = datetime.timedelta(hours=-1)

    client = app.test_client()
    token = _get_jwt(client)
    resp = _make_jwt_request(client, token, endpoint)
    payload = json.loads(resp.get_data(as_text=True))

    assert payload == {'msg': 'Token has expired'}
    assert resp.status_code == 401