1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
from __future__ import unicode_literals
from unittest import TestCase
from wtforms.fields import TextField
from wtforms.ext.csrf import SecureForm
from wtforms.ext.csrf.session import SessionSecureForm
from tests.common import DummyPostData
import datetime
import hashlib
import hmac
class InsecureForm(SecureForm):
    """Minimal concrete SecureForm used to exercise the base-class plumbing.

    The "token" it generates is simply the csrf_context it was given, so a
    test can predict the expected token value exactly.
    """

    a = TextField()

    def generate_csrf_token(self, csrf_context):
        # Echo the context back unchanged: here the context *is* the token.
        token = csrf_context
        return token
class FakeSessionRequest(object):
    """Stand-in for a request object that exposes a ``session`` attribute."""

    def __init__(self, session):
        # Keep the very same object (no copy) so a test can inspect what the
        # form wrote into it.
        self.session = session
class StupidObject(object):
    """Plain attribute bag used as a ``populate_obj()`` target."""

    # Both attributes start as None so a test can tell whether
    # populate_obj() wrote to them.
    a = csrf_token = None
class SecureFormTest(TestCase):
    """Exercise SecureForm through the trivial InsecureForm subclass."""

    def test_base_class(self):
        """SecureForm itself cannot be instantiated."""
        self.assertRaises(NotImplementedError, SecureForm)

    def test_basic_impl(self):
        """The generated token is stored, rendered, and kept out of .data."""
        form = InsecureForm(csrf_context=42)
        token_field = form.csrf_token
        self.assertEqual(token_field.current_token, 42)
        # Nothing was posted, so validation fails with one error on the token.
        self.assertFalse(form.validate())
        self.assertEqual(len(token_field.errors), 1)
        self.assertEqual(token_field._value(), 42)
        # The CSRF field is bookkeeping only; it must not leak into .data.
        self.assertEqual(form.data, {'a': None})

    def test_with_data(self):
        """A matching posted token validates; a mismatched one does not."""
        payload = DummyPostData(csrf_token='test', a='hi')

        matching = InsecureForm(payload, csrf_context='test')
        self.assertTrue(matching.validate())
        self.assertEqual(matching.data, {'a': 'hi'})

        mismatched = InsecureForm(payload, csrf_context='something')
        self.assertFalse(mismatched.validate())
        # The rendered value stays the current token, not whatever the
        # client posted.
        self.assertEqual(mismatched.csrf_token._value(), 'something')

        # populate_obj() copies the real field but never the CSRF token.
        target = StupidObject()
        mismatched.populate_obj(target)
        self.assertEqual(target.a, 'hi')
        self.assertEqual(target.csrf_token, None)

    def test_with_missing_token(self):
        """Omitting the token fails validation but still renders the real one."""
        form = InsecureForm(DummyPostData(a='hi'), csrf_context='test')
        self.assertFalse(form.validate())
        # No token was posted, so the field's data is the empty string...
        self.assertEqual(form.csrf_token.data, '')
        # ...while the rendered value is still the expected token.
        self.assertEqual(form.csrf_token._value(), 'test')
class SessionSecureFormTest(TestCase):
    """Tests for SessionSecureForm's session-backed, HMAC-signed tokens.

    All checks use TestCase assertion methods rather than bare ``assert``
    statements: Python removes ``assert`` under ``-O``, which would silently
    turn these tests into no-ops.
    """

    class SSF(SessionSecureForm):
        # Keyed form using the TIME_LIMIT inherited from SessionSecureForm.
        SECRET_KEY = 'abcdefghijklmnop'.encode('ascii')

    class BadTimeSSF(SessionSecureForm):
        SECRET_KEY = 'abcdefghijklmnop'.encode('ascii')
        # Negative limit of 100 seconds (equal to timedelta(-1, 86300)): tokens
        # issued by this form are already past their time limit, so a normal
        # SSF rejects them as expired (see test_timestamped).
        TIME_LIMIT = datetime.timedelta(seconds=-100)

    class NoTimeSSF(SessionSecureForm):
        SECRET_KEY = 'abcdefghijklmnop'.encode('ascii')
        # No time limit: the token's timestamp part is empty (see test_notime).
        TIME_LIMIT = None

    def test_basic(self):
        # The base class (no SECRET_KEY set) refuses to build; its error is a
        # plain Exception, so nothing narrower can be asserted here.
        self.assertRaises(Exception, SessionSecureForm)
        # Without a csrf_context, construction raises TypeError.
        self.assertRaises(TypeError, self.SSF)

        session = {}
        # The context may be a request-like object exposing ``.session``.
        form = self.SSF(csrf_context=FakeSessionRequest(session))
        self.assertTrue('csrf_token' in form)
        # Building the form stores a 'csrf' value in the wrapped session.
        self.assertTrue('csrf' in session)

    def test_timestamped(self):
        session = {}
        postdata = DummyPostData(csrf_token='fake##fake')
        form = self.SSF(postdata, csrf_context=session)
        self.assertTrue('csrf' in session)
        # The rendered token exists and is not just the raw session value.
        self.assertTrue(form.csrf_token._value())
        self.assertNotEqual(form.csrf_token._value(), session['csrf'])
        # A forged token is rejected.
        self.assertFalse(form.validate())
        self.assertEqual(form.csrf_token.errors[0], 'CSRF failed')

        # Now test a valid CSRF token whose time limit has already passed.
        evil_form = self.BadTimeSSF(csrf_context=session)
        bad_token = evil_form.csrf_token._value()
        postdata = DummyPostData(csrf_token=bad_token)
        form = self.SSF(postdata, csrf_context=session)
        self.assertFalse(form.validate())
        self.assertEqual(form.csrf_token.errors[0], 'CSRF token expired')

    def test_notime(self):
        session = {}
        form = self.NoTimeSSF(csrf_context=session)
        hmacced = hmac.new(form.SECRET_KEY, session['csrf'].encode('utf8'),
                           digestmod=hashlib.sha1)
        # With no time limit, nothing precedes the '##' separator.
        self.assertEqual(form.csrf_token._value(),
                         '##%s' % hmacced.hexdigest())
        # Nothing was posted, so validation reports the token as missing.
        self.assertFalse(form.validate())
        self.assertEqual(form.csrf_token.errors[0], 'CSRF token missing')

        # A fixed session secret and its matching pre-computed token validate.
        session = {'csrf': '00e9fa5fe507251ac5f32b1608e9282f75156a05'}
        postdata = DummyPostData(
            csrf_token='##d21f54b7dd2041fab5f8d644d4d3690c77beeb14')
        form = self.NoTimeSSF(postdata, csrf_context=session)
        self.assertTrue(form.validate())
|