1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
|
from __future__ import absolute_import, unicode_literals
import hashlib
import hmac
import json
from django import forms
from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import connections
from django.utils.crypto import constant_time_compare
from django.utils.encoding import force_bytes
from django.utils.functional import cached_property
from debug_toolbar.panels.sql.utils import reformat_sql
class SQLSelectForm(forms.Form):
    """
    Validate params

        sql: The sql statement with interpolated params
        raw_sql: The sql statement with placeholders
        params: JSON encoded parameter values
        alias: Database connection alias (must exist in ``connections``)
        duration: time for SQL to execute passed in from toolbar just for redisplay
        hash: HMAC keyed with SECRET_KEY over sql, raw_sql, params and alias,
              used for tamper checking

    Every field is rendered as a hidden input, so all values round-trip
    through the client. ``raw_sql``/``params``/``alias`` describe the
    statement a caller would run on ``cursor`` (the executing view is not in
    this module -- assumed), which is why they must be covered by the
    signature rather than only the display-oriented ``sql``.
    """

    sql = forms.CharField()
    raw_sql = forms.CharField()
    params = forms.CharField()
    alias = forms.CharField(required=False, initial='default')
    duration = forms.FloatField()
    hash = forms.CharField()

    # Fields whose values are covered by the HMAC computed in ``make_hash``.
    # ``duration`` is only redisplayed, so it is deliberately left unsigned.
    _signed_fields = ('sql', 'raw_sql', 'params', 'alias')

    def __init__(self, *args, **kwargs):
        initial = kwargs.get('initial')
        if initial is not None:
            # Work on a copy so adding 'hash' does not mutate the caller's dict.
            initial = initial.copy()
            initial['hash'] = self.make_hash(initial)
            kwargs['initial'] = initial
        super(SQLSelectForm, self).__init__(*args, **kwargs)
        # Everything is round-tripped invisibly through the rendered form.
        for field in self.fields.values():
            field.widget = forms.HiddenInput()

    def clean_raw_sql(self):
        """Reject any statement that does not start with SELECT (case-insensitive)."""
        value = self.cleaned_data['raw_sql']
        if not value.lower().strip().startswith('select'):
            raise ValidationError("Only 'select' queries are allowed.")
        return value

    def clean_params(self):
        """Decode the JSON-encoded parameter values."""
        value = self.cleaned_data['params']
        try:
            return json.loads(value)
        except ValueError:
            raise ValidationError('Is not valid JSON')

    def clean_alias(self):
        """Ensure the alias names a configured database connection."""
        value = self.cleaned_data['alias']
        if value not in connections:
            raise ValidationError("Database alias '%s' not found" % value)
        return value

    def clean_hash(self):
        """
        Recompute the HMAC over the raw submitted data and compare it in
        constant time with the submitted ``hash``.

        Raises ValidationError('Tamper alert') on mismatch.
        """
        submitted = self.cleaned_data['hash']
        expected = self.make_hash(self.data)
        if not constant_time_compare(submitted, expected):
            raise ValidationError('Tamper alert')
        return submitted

    def reformat_sql(self):
        """Return the cleaned ``sql`` passed through the module-level ``reformat_sql`` helper."""
        return reformat_sql(self.cleaned_data['sql'])

    def make_hash(self, data):
        """
        Return the hex HMAC-SHA1 (keyed with SECRET_KEY) of the signed fields
        in ``data`` (a mapping: either the ``initial`` dict or submitted data).

        A missing or None value is replaced by the field's declared initial
        value (or empty), which mirrors what the hidden widget would render
        and submit for it. This way an absent key yields a validation failure
        rather than a KeyError.
        """
        m = hmac.new(key=force_bytes(settings.SECRET_KEY), digestmod=hashlib.sha1)
        for name in self._signed_fields:
            value = data.get(name)
            if value is None:
                value = self.base_fields[name].initial
            item = b'' if value is None else force_bytes(value)
            # Length-prefix each value so bytes cannot be moved across field
            # boundaries (e.g. from sql into raw_sql) without changing the MAC.
            m.update(force_bytes('%d:' % len(item)))
            m.update(item)
        return m.hexdigest()

    @property
    def connection(self):
        """The database connection selected by the cleaned ``alias``."""
        return connections[self.cleaned_data['alias']]

    @cached_property
    def cursor(self):
        """
        A cursor on ``connection``, created on first access and cached on the
        form instance. This form never closes it; the caller is expected to.
        """
        return self.connection.cursor()
|