File: test_csrf.py

package info (click to toggle)
wtforms 3.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,064 kB
  • sloc: python: 5,264; makefile: 27; sh: 17
file content (178 lines) | stat: -rw-r--r-- 4,650 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import datetime
import hashlib
import hmac
from contextlib import contextmanager
from functools import partial

import pytest

from tests.common import DummyPostData
from wtforms.csrf.core import CSRF
from wtforms.csrf.session import SessionCSRF
from wtforms.fields import StringField
from wtforms.form import Form


class DummyCSRF(CSRF):
    """CSRF implementation for tests whose token is always a fixed string."""

    def generate_csrf_token(self, csrf_token_field):
        """Return the constant token ``"dummytoken"``; the field is not used."""
        return "dummytoken"


class TimePin(SessionCSRF):
    """
    CSRF with ability to pin times so that we can do a thorough test
    of expected values and keys.

    The pinned value lives on the class, so it is shared by every instance
    created while a ``pin_time`` block is active.
    """

    # Value handed back by ``now()``; ``None`` means no time is pinned.
    pinned_time = None

    @classmethod
    @contextmanager
    def pin_time(cls, value):
        """
        Temporarily set the pinned time to ``value`` for a ``with`` block.

        The previous value is restored when the block exits, including when
        the body raises (for example on a failed assertion), so one failing
        test cannot leak its pinned time into later ones.
        """
        original = cls.pinned_time
        cls.pinned_time = value
        try:
            yield
        finally:
            cls.pinned_time = original

    def now(self):
        """Return the currently pinned time instead of reading a real clock."""
        return self.pinned_time


class SimplePopulateObject:
    """Bare attribute holder used as a ``populate_obj`` target in the tests."""

    # Both attributes start unset so a test can tell which ones got written.
    a = csrf_token = None


class F(Form):
    """Form with CSRF enabled through ``DummyCSRF`` and one string field."""

    class Meta:
        # Enable CSRF and use the fixed-token implementation defined above.
        csrf = True
        csrf_class = DummyCSRF

    a = StringField()


def test_dummy_csrf_base_class():
    """Building a form on the bare ``CSRF`` class raises NotImplementedError."""
    base_class_meta = {"csrf_class": CSRF}
    with pytest.raises(NotImplementedError):
        F(meta=base_class_meta)


def test_dummy_csrf_basic_impl():
    """A form fails validation without the token and passes once it is posted."""
    expected_token = "dummytoken"

    # Without submitted data the field exists and renders the generated token,
    # but validation fails.
    unsubmitted = F()
    assert "csrf_token" in unsubmitted
    assert not unsubmitted.validate()
    assert unsubmitted.csrf_token._value() == expected_token

    # Posting the same token back makes the form validate.
    submitted = F(DummyPostData(csrf_token=expected_token))
    assert submitted.validate()


def test_dummy_csrf_off():
    """Setting ``csrf`` to False through meta leaves the form without a token field."""
    csrf_disabled = F(meta={"csrf": False})
    assert "csrf_token" not in csrf_disabled


def test_dummy_csrf_rename():
    """``csrf_field_name`` moves the token field away from its default name."""
    custom_name = "mycsrf"
    renamed = F(meta={"csrf_field_name": custom_name})
    assert custom_name in renamed
    assert "csrf_token" not in renamed


def test_dummy_csrf_no_populate():
    """``populate_obj`` writes the regular field but leaves ``csrf_token`` alone."""
    target = SimplePopulateObject()
    F(a="test", csrf_token="dummytoken").populate_obj(target)

    # The token must not be copied onto the object; the normal field must be.
    assert target.csrf_token is None
    assert target.a == "test"


class H(Form):
    """
    Form with CSRF enabled and a secret configured, but no ``csrf_class``.

    NOTE(review): no ``csrf_class`` is set, so the library default applies;
    the tests below expect SessionCSRF behavior from it -- confirm.
    """

    class Meta:
        csrf = True
        # Same bytes the no-time-limit test uses as its HMAC key.
        csrf_secret = b"foobar"

    a = StringField()


class NoTimeLimit(H):
    """
    Variant of ``H`` with ``csrf_time_limit`` set to ``None``.

    NOTE(review): this ``Meta`` does not subclass ``H.Meta``; the tests rely
    on ``H``'s ``csrf`` and ``csrf_secret`` still applying, so wtforms
    presumably merges ``Meta`` classes across the form hierarchy -- confirm.
    """

    class Meta:
        # With no limit, the test below expects a token with nothing before "##".
        csrf_time_limit = None


class Pinned(H):
    """Variant of ``H`` that uses ``TimePin`` as its CSRF implementation."""

    class Meta:
        # TimePin returns a pinned time from now(), making tokens deterministic.
        csrf_class = TimePin


def test_session_csrf_various_failures():
    """Misconfigured session-CSRF forms fail at construction time."""
    # NOTE(review): presumably a TypeError because no csrf_context is given.
    with pytest.raises(TypeError):
        H()

    # An explicitly unset secret is rejected with a descriptive message.
    missing_secret = "must set `csrf_secret` on class Meta for SessionCSRF to work"
    with pytest.raises(Exception, match=missing_secret):
        H(meta={"csrf_secret": None})


def test_session_csrf_no_time_limit():
    """Without a time limit the token is ``##`` plus an HMAC-SHA1 of the session seed."""
    csrf_context = {}
    first_form = _test_phase1(NoTimeLimit, csrf_context)
    issued_token = first_form.csrf_token.current_token

    # Recompute the digest independently from the seed stored in the session.
    seed = csrf_context["csrf"].encode("ascii")
    digest = hmac.new(b"foobar", seed, digestmod=hashlib.sha1).hexdigest()
    assert issued_token == "##" + digest

    # The issued token must validate when posted back against the same session.
    _test_phase2(NoTimeLimit, csrf_context, issued_token)


def test_session_csrf_with_time_limit():
    """A token issued by ``H`` validates when posted back with the same session."""
    csrf_context = {}
    issued_token = _test_phase1(H, csrf_context).csrf_token.current_token
    _test_phase2(H, csrf_context, issued_token)


def test_session_csrf_detailed_expected_values():
    """
    Walk one token through its lifetime with the clock pinned, checking the
    exact token strings so the output is fully deterministic.
    """
    csrf_context = {"csrf": "93fed52fa69a2b2b0bf9c350c8aeeb408b6b6dfa"}
    on_day = partial(datetime.datetime, 2013, 1, 15)

    # Issued at 08:11:12; the expected token carries the stamp 08:41:12.
    with TimePin.pin_time(on_day(8, 11, 12)):
        issued = _test_phase1(Pinned, csrf_context).csrf_token.current_token
        assert issued == "20130115084112##53812764d65abb8fa88384551a751ca590dff5fb"

    # Posting the token back at 08:18 validates and yields a fresh token.
    with TimePin.pin_time(on_day(8, 18)):
        revalidated = _test_phase2(Pinned, csrf_context, issued)
        refreshed = revalidated.csrf_token.current_token
        assert refreshed != issued
        assert refreshed == "20130115084800##e399e3a6a84860762723672b694134507ba21b58"

    # At 08:43 the first token is past its stamp and is reported as expired.
    with TimePin.pin_time(on_day(8, 43)):
        expired_form = _test_phase2(Pinned, csrf_context, issued, must_validate=False)
        assert not expired_form.validate()
        assert expired_form.csrf_token.errors == ["CSRF token expired."]

        # The newer token is still accepted at this time.
        _test_phase2(Pinned, csrf_context, refreshed)

    # Same stamp as the refreshed token, but the last digest character differs.
    with TimePin.pin_time(on_day(8, 44)):
        tampered = "20130115084800##e399e3a6a84860762723672b694134507ba21b59"
        rejected_form = _test_phase2(
            Pinned, csrf_context, tampered, must_validate=False
        )
        assert not rejected_form.validate()


def _test_phase1(form_class, session):
    """
    Build ``form_class`` with no submitted data and check the initial state.

    The form must fail validation with errors on its ``csrf_token`` field,
    and building it must have stored a ``"csrf"`` entry in ``session``.
    Returns the form so callers can read the issued token.
    """
    meta = {"csrf_context": session}
    unsubmitted = form_class(meta=meta)
    assert not unsubmitted.validate()
    assert unsubmitted.csrf_token.errors
    assert "csrf" in session
    return unsubmitted


def _test_phase2(form_class, session, token, must_validate=True):
    """
    Build ``form_class`` with ``token`` posted as ``csrf_token`` and return it.

    When ``must_validate`` is true the form is asserted to validate; otherwise
    ``validate()`` is not called here and the caller inspects the form itself.
    """
    submission = DummyPostData(csrf_token=token)
    submitted = form_class(formdata=submission, meta={"csrf_context": session})
    if not must_validate:
        return submitted
    assert submitted.validate()
    return submitted