1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
"""Example implementation of using a marshmallow Schema for both request input
and output with a `use_schema` decorator.
Run the app:
$ python examples/schema_example.py
Try the following with httpie (a cURL-like utility, http://httpie.org):
$ pip install httpie
$ http GET :5001/users/
$ http GET :5001/users/42
$ http POST :5001/users/ username=brian first_name=Brian last_name=May
$ http PATCH :5001/users/42 username=freddie
$ http GET :5001/users/ limit==1
"""
import functools
import random
from flask import Flask, request
from marshmallow import Schema, fields, post_dump
from webargs.flaskparser import parser, use_kwargs
# Flask application object; the routes and the error handler further down
# register themselves on it.
app = Flask(__name__)
##### Fake database and model #####
class Model:
    """Tiny attribute-bag record type for the example's fake database.

    ``insert`` reads a ``collection`` class attribute and uses it as the key
    into the ``db`` mapping, so concrete subclasses must define it.
    """

    def __init__(self, **kwargs):
        """Store every keyword argument as an instance attribute."""
        vars(self).update(kwargs)

    def update(self, **kwargs):
        """Set or overwrite instance attributes from the keyword arguments."""
        vars(self).update(kwargs)

    @classmethod
    def insert(cls, db, **kwargs):
        """Create a record, store it in ``db[cls.collection]`` and return it.

        An explicit ``id`` keyword is used as-is (handy for fixtures);
        otherwise a random id in ``1..9999`` that is not yet a key of the
        collection is drawn.
        """
        table = db[cls.collection]
        if "id" in kwargs:
            record_id = kwargs.pop("id")
        else:
            # Keep drawing until the candidate id is not already taken.
            while True:
                record_id = random.randint(1, 9999)
                if record_id not in table:
                    break
        record = cls(id=record_id, **kwargs)
        table[record_id] = record
        return record
class User(Model):
    """Model whose records are stored under the ``"users"`` collection."""

    # Key into the db mapping that Model.insert uses to find the record table.
    collection = "users"
# In-memory stand-in for a database: maps a collection name to a dict of
# records keyed by id (this is the structure Model.insert writes into).
db = {"users": {}}
##### use_schema #####
def use_schema(schema_cls, list_view=False, locations=None, location=None):
    """View decorator for using a marshmallow schema to
    (1) parse a request's input and
    (2) serialize the view's output to a JSON response.

    :param schema_cls: Schema class. It is instantiated on every request,
        with ``partial=True`` for every HTTP method except POST.
    :param list_view: Passed as ``many`` when dumping the view's result.
    :param locations: Legacy spelling kept for existing callers. May be a
        single location name or a one-element collection of names; it is
        only consulted when ``location`` is not given.
    :param location: Name of the request location to parse, forwarded to
        ``parser.use_args``. ``None`` leaves the choice to the parser.
    :raises ValueError: if ``locations`` names more than one location.
    """
    # webargs >= 6 parses exactly one location per use_args call and takes it
    # through the singular ``location`` keyword; passing ``locations=`` raises
    # TypeError there. Normalise the legacy argument once, at decoration
    # time, so a bad value fails fast instead of on the first request.
    if location is None and locations:
        if isinstance(locations, str):
            location = locations
        else:
            names = tuple(locations)
            if len(names) != 1:
                raise ValueError(
                    "use_schema parses a single location; got {!r}".format(names)
                )
            location = names[0]

    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # Partial loading for everything but POST, so e.g. a PATCH may
            # send only the fields it wants to change.
            partial = request.method != "POST"
            schema = schema_cls(partial=partial)
            use_args_wrapper = parser.use_args(schema, location=location)
            # Function wrapped with use_args
            func_with_args = use_args_wrapper(func)
            ret = func_with_args(*args, **kwargs)

            # support (json, status) tuples
            if isinstance(ret, tuple) and len(ret) == 2 and isinstance(ret[1], int):
                payload, status = ret
                if status >= 400:
                    # Error payloads (e.g. {"message": ...}) are not schema
                    # objects; dumping them would drop every key the schema
                    # does not declare, so pass them through untouched.
                    return payload, status
                return schema.dump(payload, many=list_view), status

            return schema.dump(ret, many=list_view)

        return wrapped

    return decorator
##### Schemas #####
class UserSchema(Schema):
    """Schema used both to load user input and to dump user records."""

    # dump_only: the id is part of the output but is not accepted as input.
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    first_name = fields.String()
    last_name = fields.String()

    @post_dump(pass_many=True)
    def wrap_with_envelope(self, data, many, **kwargs):
        """Nest the dumped payload (one item or a list) under a ``"data"`` key."""
        envelope = {"data": data}
        return envelope
##### Routes #####
@app.route("/users/<int:user_id>", methods=["GET", "PATCH"])
@use_schema(UserSchema)
def user_detail(reqargs, user_id):
    """Return the user stored under ``user_id``, applying a PATCH first.

    ``reqargs`` holds the request fields parsed by ``use_schema``. When no
    user is stored under ``user_id`` a ``(body, 404)`` tuple is returned
    instead of a record.
    """
    record = db["users"].get(user_id)
    if not record:
        return {"message": "User not found"}, 404

    is_patch = request.method == "PATCH"
    if is_patch and reqargs:
        # Only touch the record when the PATCH actually carried fields.
        record.update(**reqargs)
    return record
# You can add additional arguments with use_kwargs
@app.route("/users/", methods=["GET", "POST"])
# The location belongs on use_kwargs: webargs >= 6 no longer reads a
# per-field ``location``, and marshmallow warns about (and from 4.0 rejects)
# unknown field keyword arguments. Without this, ``limit`` is never read from
# the query string.
@use_kwargs({"limit": fields.Int(load_default=10)}, location="query")
@use_schema(UserSchema, list_view=True)
def user_list(reqargs, limit):
    """List stored users, creating one first when the request is a POST.

    :param reqargs: Request fields parsed by ``use_schema``; on POST they
        become the attributes of the new user.
    :param limit: Maximum number of users to return, taken from the
        ``limit`` query parameter (10 when absent).
    :returns: A list of at most ``limit`` user records, which
        ``use_schema`` then serializes.
    """
    if request.method == "POST":
        User.insert(db=db, **reqargs)
    # Read the collection after the insert so a freshly created user is
    # eligible for the listing.
    return list(db["users"].values())[:limit]
# Return validation errors as JSON
@app.errorhandler(422)
@app.errorhandler(400)
def handle_validation_error(err):
    """Build the JSON error response for 400 and 422 errors.

    When ``err`` carries a truthy ``exc`` attribute (presumably the
    validation error attached by webargs), its ``messages`` are reported and
    ``err.data["headers"]`` is forwarded if non-empty. Otherwise a generic
    message is returned with no extra headers.
    """
    validation_exc = getattr(err, "exc", None)
    if validation_exc:
        extra_headers = err.data["headers"]
        body = {"errors": validation_exc.messages}
    else:
        extra_headers = None
        body = {"errors": ["Invalid request."]}

    response = (body, err.code)
    if extra_headers:
        # Flask accepts an optional third tuple item with response headers.
        response += (extra_headers,)
    return response
if __name__ == "__main__":
    # Seed the fake database with one known user so that requests for id 42
    # (as shown in the module docstring) have something to return.
    fixture = {
        "id": 42,
        "username": "fred",
        "first_name": "Freddie",
        "last_name": "Mercury",
    }
    User.insert(db=db, **fixture)
    app.run(port=5001, debug=True)
|