File: ext_csrf.py

package info (click to toggle)
wtforms 2.0.1-1
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 964 kB
  • ctags: 1,360
  • sloc: python: 5,163; makefile: 73
file content (126 lines) | stat: -rw-r--r-- 4,336 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from __future__ import unicode_literals

from unittest import TestCase

from wtforms.fields import TextField
from wtforms.ext.csrf import SecureForm
from wtforms.ext.csrf.session import SessionSecureForm
from tests.common import DummyPostData

import datetime
import hashlib
import hmac


class InsecureForm(SecureForm):
    """Test form whose CSRF token is simply the ``csrf_context`` it is given."""

    a = TextField()

    def generate_csrf_token(self, csrf_context):
        """Return ``csrf_context`` unchanged so tests can predict the token."""
        return csrf_context


class FakeSessionRequest(object):
    """Bare-bones request stand-in that exposes only a ``session`` attribute."""

    def __init__(self, session):
        # Keep the caller's object by reference (no copy), so the caller can
        # inspect whatever gets written into it afterwards.
        self.session = session


class StupidObject(object):
    """Plain attribute holder used as a ``populate_obj`` target in tests."""

    # Both attributes start out as None at class level.
    a = csrf_token = None


class SecureFormTest(TestCase):
    """Checks the generic ``SecureForm`` behaviour through ``InsecureForm``."""

    def test_base_class(self):
        """Instantiating ``SecureForm`` directly raises NotImplementedError."""
        self.assertRaises(NotImplementedError, SecureForm)

    def test_basic_impl(self):
        """Without posted data the token is exposed but validation fails."""
        form = InsecureForm(csrf_context=42)
        token_field = form.csrf_token
        self.assertEqual(token_field.current_token, 42)
        self.assertFalse(form.validate())
        self.assertEqual(len(token_field.errors), 1)
        self.assertEqual(token_field._value(), 42)
        # csrf_token must not leak into the form's .data mapping
        self.assertEqual(form.data, {'a': None})

    def test_with_data(self):
        """A matching posted token validates; a mismatching one does not."""
        submitted = DummyPostData(csrf_token='test', a='hi')

        matching = InsecureForm(submitted, csrf_context='test')
        self.assertTrue(matching.validate())
        self.assertEqual(matching.data, {'a': 'hi'})

        mismatched = InsecureForm(submitted, csrf_context='something')
        self.assertFalse(mismatched.validate())

        # The rendered value stays the current token even though a
        # different value was posted.
        self.assertEqual(mismatched.csrf_token._value(), 'something')

        # populate_obj copies regular fields but leaves csrf_token alone.
        target = StupidObject()
        mismatched.populate_obj(target)
        self.assertEqual(target.a, 'hi')
        self.assertEqual(target.csrf_token, None)

    def test_with_missing_token(self):
        """Posting no csrf_token at all fails validation."""
        submitted = DummyPostData(a='hi')
        form = InsecureForm(submitted, csrf_context='test')
        self.assertFalse(form.validate())

        token_field = form.csrf_token
        self.assertEqual(token_field.data, '')
        self.assertEqual(token_field._value(), 'test')


class SessionSecureFormTest(TestCase):
    """Tests for ``SessionSecureForm`` under three ``TIME_LIMIT`` settings.

    ``SSF`` leaves ``TIME_LIMIT`` untouched, ``BadTimeSSF`` uses a negative
    limit, and ``NoTimeSSF`` disables it with ``None``.

    All checks use ``self.assert*`` methods rather than bare ``assert``
    statements so they are not stripped when Python runs with ``-O``.
    """

    class SSF(SessionSecureForm):
        SECRET_KEY = 'abcdefghijklmnop'.encode('ascii')

    class BadTimeSSF(SessionSecureForm):
        SECRET_KEY = 'abcdefghijklmnop'.encode('ascii')
        # -1 day + 86300 seconds == -100 seconds, i.e. a negative time limit.
        TIME_LIMIT = datetime.timedelta(-1, 86300)

    class NoTimeSSF(SessionSecureForm):
        SECRET_KEY = 'abcdefghijklmnop'.encode('ascii')
        TIME_LIMIT = None

    def test_basic(self):
        """Bad construction raises; good construction seeds the session."""
        # The bare base class refuses to instantiate
        # (presumably because no SECRET_KEY is set -- confirm against wtforms).
        self.assertRaises(Exception, SessionSecureForm)
        # A subclass with a key still raises TypeError without csrf_context.
        self.assertRaises(TypeError, self.SSF)
        session = {}
        # Here the context is an object exposing ``.session`` rather than
        # the session dict itself.
        form = self.SSF(csrf_context=FakeSessionRequest(session))
        self.assertTrue('csrf_token' in form)
        self.assertTrue('csrf' in session)

    def test_timestamped(self):
        """Forged tokens fail; tokens from a negative time limit are expired."""
        session = {}
        postdata = DummyPostData(csrf_token='fake##fake')
        form = self.SSF(postdata, csrf_context=session)
        self.assertTrue('csrf' in session)
        self.assertTrue(form.csrf_token._value())
        # The rendered token must differ from the raw value kept in the session.
        self.assertNotEqual(form.csrf_token._value(), session['csrf'])
        self.assertFalse(form.validate())
        self.assertEqual(form.csrf_token.errors[0], 'CSRF failed')

        # Now test a valid CSRF with invalid timestamp
        evil_form = self.BadTimeSSF(csrf_context=session)
        bad_token = evil_form.csrf_token._value()

        postdata = DummyPostData(csrf_token=bad_token)
        form = self.SSF(postdata, csrf_context=session)
        self.assertFalse(form.validate())
        self.assertEqual(form.csrf_token.errors[0], 'CSRF token expired')

    def test_notime(self):
        """With ``TIME_LIMIT = None`` the token is ``'##'`` plus an HMAC-SHA1 hex digest."""
        session = {}
        form = self.NoTimeSSF(csrf_context=session)
        # Recompute the expected digest from the value the form put in the session.
        hmacced = hmac.new(form.SECRET_KEY, session['csrf'].encode('utf8'), digestmod=hashlib.sha1)
        self.assertEqual(form.csrf_token._value(), '##%s' % hmacced.hexdigest())
        # No token was posted, so validation must fail.
        self.assertFalse(form.validate())
        self.assertEqual(form.csrf_token.errors[0], 'CSRF token missing')

        # Test with pre-made values
        session = {'csrf': '00e9fa5fe507251ac5f32b1608e9282f75156a05'}
        postdata = DummyPostData(csrf_token='##d21f54b7dd2041fab5f8d644d4d3690c77beeb14')

        form = self.NoTimeSSF(postdata, csrf_context=session)
        self.assertTrue(form.validate())