File: csrf.py

package info (click to toggle)
wtforms 2.0.1-1
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 964 kB
  • ctags: 1,360
  • sloc: python: 5,163; makefile: 73
file content (165 lines) | stat: -rw-r--r-- 5,239 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from __future__ import unicode_literals

from contextlib import contextmanager
from functools import partial
from unittest import TestCase

from wtforms.fields import TextField
from wtforms.form import Form
from wtforms.csrf.core import CSRF
from wtforms.csrf.session import SessionCSRF
from tests.common import DummyPostData

import datetime
import hashlib
import hmac


class DummyCSRF(CSRF):
    """
    CSRF implementation whose generated token is a fixed, predictable string,
    so tests can post the matching value without any session state.
    """

    def generate_csrf_token(self, csrf_token_field):
        # The field argument is deliberately ignored: the token never varies.
        token = 'dummytoken'
        return token


class FakeSessionRequest(object):
    """
    Bare-bones object exposing nothing but a ``session`` attribute.

    NOTE(review): presumably meant to imitate a request-like CSRF context
    that carries its session as an attribute -- confirm against callers.
    """

    def __init__(self, session):
        # Store the very object that was passed in (no copy is made).
        self.session = session


class TimePin(SessionCSRF):
    """
    CSRF with ability to pin times so that we can do a thorough test
    of expected values and keys.
    """
    # Class-level "current time" returned by now(); None means nothing pinned.
    pinned_time = None

    @classmethod
    @contextmanager
    def pin_time(cls, value):
        """
        Context manager that sets ``pinned_time`` to `value` for the duration
        of the ``with`` block and puts the previous value back afterwards.

        The previous value is restored even when the block raises (for
        example on a failing assertion), so a pinned time can never leak
        into tests that run later.
        """
        original = cls.pinned_time
        cls.pinned_time = value
        try:
            yield
        finally:
            # Always undo the pin; without this an exception inside the
            # ``with`` body would leave the class permanently pinned.
            cls.pinned_time = original

    def now(self):
        """
        Return the currently pinned time instead of a real clock reading.
        """
        return self.pinned_time


class SimplePopulateObject(object):
    """
    Plain attribute holder with ``a`` and ``csrf_token`` both defaulting to
    None at class level, so a test can see which ones were assigned later.
    """
    a = csrf_token = None


class DummyCSRFTest(TestCase):
    """
    Tests of the form-level CSRF plumbing, using ``DummyCSRF`` so the
    expected token is always the fixed string 'dummytoken'.

    All checks use ``TestCase`` assertion methods rather than bare ``assert``
    statements, so they still run when Python is started with ``-O``.
    """

    class F(Form):
        class Meta:
            csrf = True
            csrf_class = DummyCSRF
        a = TextField()

    def test_base_class(self):
        """Constructing a form with the bare ``CSRF`` class must raise."""
        self.assertRaises(NotImplementedError, self.F, meta={'csrf_class': CSRF})

    def test_basic_impl(self):
        """A form without posted token fails; posting 'dummytoken' passes."""
        form = self.F()
        self.assertTrue('csrf_token' in form)
        self.assertFalse(form.validate())
        self.assertEqual(form.csrf_token._value(), 'dummytoken')
        form = self.F(DummyPostData(csrf_token='dummytoken'))
        self.assertTrue(form.validate())

    def test_csrf_off(self):
        """Overriding ``csrf`` to False removes the token field."""
        form = self.F(meta={'csrf': False})
        self.assertFalse('csrf_token' in form)

    def test_rename(self):
        """``csrf_field_name`` changes the name the token field is bound as."""
        form = self.F(meta={'csrf_field_name': 'mycsrf'})
        self.assertTrue('mycsrf' in form)
        self.assertFalse('csrf_token' in form)

    def test_no_populate(self):
        """``populate_obj`` copies ``a`` but leaves ``csrf_token`` untouched."""
        obj = SimplePopulateObject()
        form = self.F(a='test', csrf_token='dummytoken')
        form.populate_obj(obj)
        self.assertTrue(obj.csrf_token is None)
        self.assertEqual(obj.a, 'test')


class SessionCSRFTest(TestCase):
    """
    Tests of session-backed CSRF tokens, with and without a time limit.

    NOTE(review): ``F`` does not set ``csrf_class``; presumably the default
    is ``SessionCSRF`` -- confirm against the wtforms default meta.

    All checks use ``TestCase`` assertion methods rather than bare ``assert``
    statements, so they still run when Python is started with ``-O``.
    """

    class F(Form):
        class Meta:
            csrf = True
            csrf_secret = b'foobar'

        a = TextField()

    class NoTimeLimit(F):
        class Meta:
            csrf_time_limit = None

    class Pinned(F):
        class Meta:
            csrf_class = TimePin

    def test_various_failures(self):
        """Missing CSRF context or a missing secret must raise."""
        self.assertRaises(TypeError, self.F)
        self.assertRaises(Exception, self.F, meta={'csrf_secret': None})

    def test_no_time_limit(self):
        """Without a time limit the token is '##' + HMAC-SHA1 of the session value."""
        session = {}
        form = self._test_phase1(self.NoTimeLimit, session)
        expected_csrf = hmac.new(b'foobar', session['csrf'].encode('ascii'), digestmod=hashlib.sha1).hexdigest()
        self.assertEqual(form.csrf_token.current_token, '##' + expected_csrf)
        self._test_phase2(self.NoTimeLimit, session, form.csrf_token.current_token)

    def test_with_time_limit(self):
        """A freshly issued token validates when posted straight back."""
        session = {}
        form = self._test_phase1(self.F, session)
        self._test_phase2(self.F, session, form.csrf_token.current_token)

    def test_detailed_expected_values(self):
        """
        A full test with the date and time pinned so we get deterministic output.
        """
        session = {'csrf': '93fed52fa69a2b2b0bf9c350c8aeeb408b6b6dfa'}
        # dt(hour, minute[, second]) builds a datetime on 2013-01-15.
        dt = partial(datetime.datetime, 2013, 1, 15)
        with TimePin.pin_time(dt(8, 11, 12)):
            form = self._test_phase1(self.Pinned, session)
            token = form.csrf_token.current_token
            self.assertEqual(token, '20130115084112##53812764d65abb8fa88384551a751ca590dff5fb')

        # Make sure that CSRF validates in a normal case.
        with TimePin.pin_time(dt(8, 18)):
            form = self._test_phase2(self.Pinned, session, token)
            new_token = form.csrf_token.current_token
            self.assertNotEqual(new_token, token)
            self.assertEqual(new_token, '20130115084800##e399e3a6a84860762723672b694134507ba21b58')

        # Make sure that CSRF fails when we're past time
        with TimePin.pin_time(dt(8, 43)):
            form = self._test_phase2(self.Pinned, session, token, False)
            self.assertFalse(form.validate())
            self.assertEqual(form.csrf_token.errors, ['CSRF token expired'])

            # We can succeed with a slightly newer token
            self._test_phase2(self.Pinned, session, new_token)

        # A token whose digest differs in its last character must be rejected.
        with TimePin.pin_time(dt(8, 44)):
            bad_token = '20130115084800##e399e3a6a84860762723672b694134507ba21b59'
            form = self._test_phase2(self.Pinned, session, bad_token, False)
            self.assertFalse(form.validate())

    def _test_phase1(self, form_class, session):
        """
        Build `form_class` with no form data, using `session` as CSRF context.

        Asserts that validation fails with an error on the token field and
        that a 'csrf' key is now present in `session`; returns the form.
        """
        form = form_class(meta={'csrf_context': session})
        self.assertFalse(form.validate())
        self.assertTrue(form.csrf_token.errors)
        self.assertTrue('csrf' in session)
        return form

    def _test_phase2(self, form_class, session, token, must_validate=True):
        """
        Build `form_class` with `token` posted back as ``csrf_token``.

        When `must_validate` is true the form is required to validate;
        otherwise validation is left to the caller. Returns the form.
        """
        form = form_class(
            formdata=DummyPostData(csrf_token=token),
            meta={'csrf_context': session}
        )
        if must_validate:
            self.assertTrue(form.validate())
        return form