1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
|
import functools
import os
from flask import Blueprint
from flask import abort
from flask import flash
from flask import g
from flask import redirect
from flask import render_template
from flask import request
from flask import session
from flask import url_for
from peewee import *
from wtforms import Form
from wtforms import PasswordField
from wtforms.fields import StringField
from flask_peewee.utils import check_password
from flask_peewee.utils import get_next
from flask_peewee.utils import make_password
from flask_peewee._wtforms_compat import DataRequired
# Directory holding this module; the auth blueprint resolves its ``static``
# and ``templates`` folders relative to it.
current_dir = os.path.split(__file__)[0]
class LoginForm(Form):
    """Default login form returned by ``Auth.get_login_form``.

    Both the username and the password must be supplied (``DataRequired``).
    """

    # WTForms' second positional Field argument is the validator list.
    username = StringField('Username', [DataRequired()])
    password = PasswordField('Password', [DataRequired()])
class BaseUser(object):
    """Mixin giving a user model password hashing helpers.

    The host class must provide a ``password`` attribute; it holds the value
    produced by ``make_password``.
    """

    def set_password(self, password):
        """Hash *password* with ``make_password`` and store it on ``self.password``."""
        hashed = make_password(password)
        self.password = hashed

    def check_password(self, password):
        """Return ``check_password(password, self.password)``.

        Presumably truthy when *password* matches the stored hash; see
        ``flask_peewee.utils.check_password``.
        """
        stored_hash = self.password
        # Inside this method body the bare name ``check_password`` is the
        # module-level helper, not this method (class scope is not searched).
        return check_password(password, stored_hash)
class Auth(object):
    """Session-based authentication for a Flask app backed by peewee.

    Instantiating ``Auth`` wires it into *app* immediately (see :meth:`setup`):
    * the login/logout views are routed on a blueprint;
    * the blueprint is mounted under *prefix*;
    * the current user is loaded before each request;
    * that user is exposed to templates as ``user``.
    """

    def __init__(self, app, db, user_model=None, prefix='/accounts', name='auth',
                 clear_session=False, default_next_url='/', db_table='user'):
        """
        :param app: Flask application to attach to.
        :param db: database wrapper exposing a ``Model`` base class, used to
            build the default user model.
        :param user_model: model to authenticate against; defaults to the
            model built by :meth:`get_user_model`.
        :param prefix: URL prefix the auth blueprint is registered under.
        :param name: blueprint name (endpoints become ``<name>.login`` etc.).
        :param clear_session: if true, logging out clears the whole session
            instead of only the ``logged_in`` flag.
        :param default_next_url: redirect target after login/logout when no
            safe ``next`` URL is supplied.
        :param db_table: table name for the default user model.
        """
        self.app = app
        self.db = db
        # Must be set before get_user_model(), which reads it.
        self.db_table = db_table
        self.User = user_model or self.get_user_model()
        self.blueprint = self.get_blueprint(name)
        self.url_prefix = prefix
        self.clear_session = clear_session
        self.default_next_url = default_next_url
        self.setup()

    def get_context_user(self):
        """Template context processor: expose the logged-in user as ``user``."""
        return {'user': self.get_logged_in_user()}

    def get_user_model(self):
        """Build the default ``User`` model on ``self.db`` using ``self.db_table``."""
        class User(self.db.Model, BaseUser):
            username = CharField(unique=True)
            password = CharField()
            email = CharField(unique=True)
            active = BooleanField()
            admin = BooleanField(default=False)

            def __unicode__(self):
                return self.username

            def __str__(self):
                # Python 3 calls __str__ (e.g. for the "logged in as" flash);
                # without it the inherited model __str__ would be used, which
                # is presumably not the username.
                return self.username

            class Meta:
                table_name = self.db_table

        return User

    def get_model_admin(self, model_admin=None):
        """Return an admin class for the user model.

        :param model_admin: base admin class; defaults to
            ``flask_peewee.admin.ModelAdmin``.
        """
        if model_admin is None:
            from flask_peewee.admin import ModelAdmin
            model_admin = ModelAdmin

        class UserAdmin(model_admin):
            # Tolerate base admin classes that do not define ``columns``.
            columns = getattr(model_admin, 'columns', None) or (
                ['username', 'email', 'active', 'admin'])

            def save_model(self, instance, form, adding=False):
                orig_password = instance.password
                user = super(UserAdmin, self).save_model(instance, form, adding)
                # A submitted password that differs from the previously stored
                # value is hashed via set_password and saved again. (Presumably
                # the edit form is pre-filled with the stored hash, so an
                # untouched field compares equal.)
                if orig_password != form.password.data:
                    user.set_password(form.password.data)
                    user.save()
                return user

        return UserAdmin

    def register_admin(self, admin_site, model_admin=None):
        """Register the user model on *admin_site* using :meth:`get_model_admin`."""
        admin_site.register(self.User, self.get_model_admin(model_admin))

    def get_blueprint(self, blueprint_name):
        """Create the blueprint serving this package's ``static``/``templates`` dirs."""
        return Blueprint(
            blueprint_name,
            __name__,
            static_folder=os.path.join(current_dir, 'static'),
            template_folder=os.path.join(current_dir, 'templates'),
        )

    def get_urls(self):
        """Return ``(rule, view)`` pairs routed by :meth:`configure_routes`."""
        return (
            ('/logout/', self.logout),
            ('/login/', self.login),
        )

    def get_login_form(self):
        """Return the form class instantiated by :meth:`login`."""
        return LoginForm

    def test_user(self, test_fn):
        """Return a view decorator requiring a logged-in user passing *test_fn*.

        If there is no logged-in user, or ``test_fn(user)`` is falsy, the
        request is redirected to this blueprint's login view. The ``next``
        parameter is taken from ``get_next()``.
        """
        def decorator(fn):
            @functools.wraps(fn)
            def inner(*args, **kwargs):
                user = self.get_logged_in_user()
                if not user or not test_fn(user):
                    login_url = url_for('%s.login' % self.blueprint.name, next=get_next())
                    return redirect(login_url)
                return fn(*args, **kwargs)
            return inner
        return decorator

    def login_required(self, func):
        """Decorator: require any logged-in user."""
        return self.test_user(lambda u: True)(func)

    def admin_required(self, func):
        """Decorator: require a logged-in user whose ``admin`` flag is truthy."""
        return self.test_user(lambda u: u.admin)(func)

    def authenticate(self, username, password):
        """Return the active user named *username* if *password* matches, else ``False``."""
        # ``== True`` is deliberate: it builds a query expression, which an
        # ``is True`` test could not do.
        active = self.User.select().where(self.User.active == True)
        try:
            user = active.where(self.User.username == username).get()
        except self.User.DoesNotExist:
            return False
        if not user.check_password(password):
            return False
        return user

    def login_user(self, user):
        """Log *user* in.

        Marks the session as logged in, stores the user's primary key,
        makes the session permanent and caches *user* on ``g.user``.
        """
        session['logged_in'] = True
        session['user_pk'] = user._pk
        session.permanent = True
        g.user = user
        flash('You are logged in as %s' % user, 'success')

    def logout_user(self):
        """Log the current user out.

        Clears the whole session if ``clear_session`` is set, otherwise only
        the ``logged_in`` flag.
        """
        if self.clear_session:
            session.clear()
        else:
            session.pop('logged_in', None)
        g.user = None
        flash('You are now logged out', 'success')

    def get_logged_in_user(self):
        """Return the active user referenced by the session, or ``None``.

        A truthy ``g.user`` is returned as-is. Otherwise the user is fetched
        by the primary key that :meth:`login_user` stored in the session.
        """
        if not session.get('logged_in'):
            return None
        cached = getattr(g, 'user', None)
        if cached:
            return cached
        user_model = self.User
        try:
            # Match on the model's real primary key (login_user stores
            # ``user._pk``) rather than assuming a column named ``id``.
            return user_model.select().where(
                user_model.active == True,
                user_model._meta.primary_key == session.get('user_pk'),
            ).get()
        except user_model.DoesNotExist:
            return None

    @staticmethod
    def _is_safe_redirect(url, host):
        """Return True if *url* is a safe post-login/logout redirect target.

        Accepted:
        * relative URLs;
        * absolute ``http``/``https`` URLs whose network location equals *host*.

        Rejected:
        * empty values and URLs containing control characters;
        * other schemes (``javascript:``, ``data:``, ...);
        * scheme-relative or backslash-disguised URLs pointing elsewhere.
        """
        from urllib.parse import urlparse

        if not url:
            return False
        url = url.strip()
        if not url or any(ord(ch) < 32 or ord(ch) == 127 for ch in url):
            return False
        # Browsers treat "\" like "/", so validate both spellings.
        for candidate in (url, url.replace('\\', '/')):
            # Browsers may read three or more leading slashes as an absolute
            # URL that urlparse does not recognise as having a host.
            if candidate.startswith('///'):
                return False
            parsed = urlparse(candidate)
            if parsed.scheme and parsed.scheme not in ('http', 'https'):
                return False
            if parsed.scheme and not parsed.netloc:
                return False
            if parsed.netloc and parsed.netloc != host:
                return False
        return True

    def _resolve_next_url(self, url):
        """Return *url* (stripped) if it is safe to redirect to, else ``default_next_url``."""
        if self._is_safe_redirect(url, request.host):
            return url.strip()
        return self.default_next_url

    def login(self):
        """Login view.

        On GET, render ``auth/login.html`` with an empty form.

        On POST, validate the form and try to authenticate. On success,
        redirect to the submitted ``next`` URL when it is a safe same-site
        target, otherwise to ``default_next_url``.
        """
        error = None
        form_class = self.get_login_form()
        if request.method == 'POST':
            form = form_class(request.form)
            # ``next`` is client-controlled: never redirect off-site with it.
            next_url = self._resolve_next_url(request.form.get('next'))
            if form.validate():
                authenticated_user = self.authenticate(
                    form.username.data,
                    form.password.data,
                )
                if authenticated_user:
                    self.login_user(authenticated_user)
                    return redirect(next_url)
                flash('Incorrect username or password')
        else:
            form = form_class()
            next_url = request.args.get('next')
        return render_template(
            'auth/login.html',
            error=error,
            form=form,
            login_url=url_for('%s.login' % self.blueprint.name),
            next=next_url)

    def logout(self):
        """Logout view.

        Logs out, then redirects to a safe ``next`` query argument or to
        ``default_next_url``.
        """
        self.logout_user()
        return redirect(self._resolve_next_url(request.args.get('next')))

    def configure_routes(self):
        """Route every :meth:`get_urls` entry on the blueprint for GET and POST."""
        for url, callback in self.get_urls():
            self.blueprint.route(url, methods=['GET', 'POST'])(callback)

    def register_blueprint(self, **kwargs):
        """Register the blueprint on the app under ``self.url_prefix``."""
        self.app.register_blueprint(self.blueprint, url_prefix=self.url_prefix, **kwargs)

    def load_user(self):
        """Before-request hook: cache the current user (or ``None``) on ``g.user``."""
        g.user = self.get_logged_in_user()

    def register_handlers(self):
        """Run :meth:`load_user` before every request of the app."""
        self.app.before_request(self.load_user)

    def register_context_processors(self):
        """Make :meth:`get_context_user` an app-wide template context processor."""
        self.app.context_processor(self.get_context_user)

    def setup(self):
        """Wire routes, the blueprint, request hooks and context processors into the app."""
        self.configure_routes()
        self.register_blueprint()
        self.register_handlers()
        self.register_context_processors()
|