1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
|
"""Example of using Python 3 function annotations to define
request arguments and output schemas.
Run the app:
$ python examples/annotations_example.py
Try the following with httpie (a cURL-like utility, http://httpie.org):
$ pip install httpie
$ http GET :5001/
$ http GET :5001/ name==Ada
$ http POST :5001/add x=40 y=2
$ http GET :5001/users/42
"""
import functools
import random
from flask import Flask, request
from marshmallow import Schema
from webargs import fields
from webargs.flaskparser import parser
# The Flask application; `route()` below registers every view on it.
app = Flask(import_name=__name__)
##### Routing wrapper #####
def route(*args, **kwargs):
    """Combine `Flask.route` with webargs parsing.

    Request arguments are declared as function annotations whose values are
    ``fields.Field`` instances. An output schema can optionally be given as
    the return annotation; when present, the view's return value is
    serialized with ``schema.dump``.

    :param args: Positional arguments forwarded to ``app.route``.
    :param kwargs: Keyword arguments forwarded to ``app.route``.
    :return: A decorator that registers the decorated view on ``app``.
    """

    def decorator(func):
        # Annotations are fixed once ``func`` is defined, so build the
        # argument schema once here rather than on every request.
        annotations = getattr(func, "__annotations__", {})
        reqargs = {
            name: value
            for name, value in annotations.items()
            # "return" holds the output schema, not a request argument.
            if isinstance(value, fields.Field) and name != "return"
        }
        response_schema = annotations.get("return")
        schema_cls = Schema.from_dict(reqargs)

        @app.route(*args, **kwargs)
        @functools.wraps(func)
        def wrapped_view(*a, **kw):
            # Only POST enforces required fields; every other method is
            # loaded with partial=True, so fields may be omitted.
            partial = request.method != "POST"
            # NOTE(review): without an explicit ``location=``, webargs >= 6
            # parses only its default location ("json"); query-string
            # arguments on GET may need ``location="query"`` -- confirm
            # against the installed webargs version.
            parsed = parser.parse(schema_cls(partial=partial), request)
            kw.update(parsed)
            # Invoke the view exactly once so its side effects run once.
            response_data = func(*a, **kw)
            # Pass the result through untouched when there is no output
            # schema, or when the view already built a Flask response tuple
            # (e.g. an error body plus status code) that must not be dumped.
            if response_schema is None or isinstance(response_data, tuple):
                return response_data
            return response_schema.dump(response_data)

        return wrapped_view

    return decorator
##### Fake database and model #####
class Model:
    """Minimal in-memory record whose attributes come from keyword arguments.

    Subclasses must define a ``collection`` class attribute naming the key of
    the ``db`` mapping in which their records are stored.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def update(self, **kwargs):
        """Set or overwrite attributes from the given keyword arguments."""
        vars(self).update(kwargs)

    @classmethod
    def insert(cls, db, **kwargs):
        """Create a record, store it in ``db[cls.collection]`` and return it.

        An explicit ``id`` keyword (used for fixtures) is used as-is and will
        replace any existing record with that id; otherwise a random id in
        1..9999 that is not yet in the collection is chosen.
        """
        records = db[cls.collection]
        if "id" in kwargs:
            record_id = kwargs.pop("id")
        else:
            # Keep drawing until the id is unused in this collection.
            # NOTE(review): this never terminates once all 9999 ids are taken.
            while True:
                record_id = random.randint(1, 9999)
                if record_id not in records:
                    break
        record = cls(id=record_id, **kwargs)
        records[record_id] = record
        return record
class User(Model):
    """Fake user model whose records live under the "users" collection."""

    collection = "users"


# In-memory "database": collection name -> {record id: record}.
db = {User.collection: {}}
##### Views #####
@route("/", methods=["GET"])
def index(name: fields.Str(load_default="Friend")):  # noqa: F821
    """Greet the caller by ``name``, which falls back to "Friend"."""
    greeting = f"Hello, {name}!"
    return {"message": greeting}
@route("/add", methods=["POST"])
def add(x: fields.Float(required=True), y: fields.Float(required=True)):
    """Return the sum of the required float arguments ``x`` and ``y``."""
    total = x + y
    return {"result": total}
class UserSchema(Schema):
    """Schema used as the output schema for ``user_detail`` responses."""

    # dump_only: included when serializing, never accepted as input.
    id = fields.Int(dump_only=True)
    username = fields.Str(required=True)
    first_name = fields.Str()
    last_name = fields.Str()
@route("/users/<int:user_id>", methods=["GET", "PATCH"])
def user_detail(user_id, username: fields.Str(required=True) = None) -> UserSchema():
    """Fetch a user by id, or update its username on PATCH.

    :param user_id: Id captured from the URL path.
    :param username: New username, applied only on PATCH. It is ``None`` when
        the request does not supply it.
    :return: The matching ``User`` record, or an error body with status 404
        when no user has ``user_id``.
    """
    user = db["users"].get(user_id)
    if user is None:
        return {"message": "User not found"}, 404
    # PATCH is parsed with partial=True, so ``username`` may be omitted; in
    # that case keep the stored value rather than overwriting it with None.
    if request.method == "PATCH" and username is not None:
        user.update(username=username)
    return user
# Return validation errors as JSON
@app.errorhandler(422)
@app.errorhandler(400)
def handle_error(err):
    """Render 400/422 errors as JSON of the form ``{"errors": messages}``.

    webargs' Flask parser attaches a ``data`` dict (with ``messages`` and,
    optionally, ``headers``) to the errors it raises. A 400/422 raised
    elsewhere may carry no ``data`` at all, so fall back to a generic
    message instead of failing inside the error handler.

    :param err: The HTTP exception being handled.
    :return: ``(body, status)`` or ``(body, status, headers)``.
    """
    data = getattr(err, "data", None) or {}
    messages = data.get("messages", ["Invalid request."])
    headers = data.get("headers")
    if headers:
        return {"errors": messages}, err.code, headers
    return {"errors": messages}, err.code
if __name__ == "__main__":
    # Seed one fixture user so that GET /users/42 has a record to return.
    fixture = dict(
        id=42, username="fred", first_name="Freddie", last_name="Mercury"
    )
    User.insert(db=db, **fixture)
    app.run(port=5001, debug=True)
|