File: annotations_example.py

package info (click to toggle)
python-webargs 8.7.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 724 kB
  • sloc: python: 4,906; makefile: 149
file content (142 lines) | stat: -rw-r--r-- 3,806 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Example of using Python 3 function annotations to define
request arguments and output schemas.

Run the app:

    $ python examples/annotations_example.py

Try the following with httpie (a cURL-like utility, http://httpie.org):

    $ pip install httpie
    $ http GET :5001/
    $ http GET :5001/ name==Ada
    $ http POST :5001/add x=40 y=2
    $ http GET :5001/users/42
"""

import functools
import random

from flask import Flask, request
from marshmallow import Schema

from webargs import fields
from webargs.flaskparser import parser

# Flask application; `route` registers views on it and `handle_error` is
# attached to it as the 400/422 error handler.
app = Flask(__name__)

##### Routing wrapper #####


def route(*args, **kwargs):
    """Combines `Flask.route` and webargs parsing. Allows arguments to be specified
    as function annotations. An output schema can optionally be specified by a
    return annotation.

    Positional and keyword arguments are forwarded unchanged to ``app.route``.

    Request arguments are the view's parameters annotated with a webargs field.
    They are read from the query string for GET/HEAD requests and from the JSON
    body otherwise. For POST every ``required`` field is enforced; for other
    methods a missing required field is skipped (so the view's own Python
    default is used), while ``load_default`` values are still filled in.

    If the return annotation is a ``Schema`` instance, the view's result is
    serialized with it. A view may return a tuple (e.g. ``(body, status)``) to
    send an explicit Flask response; tuples bypass the output schema.
    """

    def decorator(func):
        # Annotations are fixed when the view is defined, so the request schema
        # is built once here rather than on every request.
        annotations = getattr(func, "__annotations__", {})
        reqargs = {
            name: value
            for name, value in annotations.items()
            if isinstance(value, fields.Field) and name != "return"
        }
        response_schema = annotations.get("return")
        if not isinstance(response_schema, Schema):
            # Ignore return annotations that are not schema instances.
            response_schema = None
        schema_cls = Schema.from_dict(reqargs)
        # A blanket ``partial=True`` makes marshmallow skip *every* missing field,
        # including ones with ``load_default``. Listing only the required fields
        # relaxes "required" without losing defaults for optional ones.
        required_names = tuple(
            name for name, field in reqargs.items() if field.required
        )

        @app.route(*args, **kwargs)
        @functools.wraps(func)
        def wrapped_view(*a, **kw):
            if request.method == "POST":
                partial = False
            else:
                partial = required_names
            # webargs parses the JSON body by default; GET arguments arrive in
            # the query string.
            location = "query" if request.method in ("GET", "HEAD") else "json"
            parsed = parser.parse(
                schema_cls(partial=partial), request, location=location
            )
            kw.update(parsed)
            # Call the view exactly once and reuse its result.
            response_data = func(*a, **kw)
            if response_schema is None or isinstance(response_data, tuple):
                return response_data
            return response_schema.dump(response_data)

        return wrapped_view

    return decorator


##### Fake database and model #####


class Model:
    """Bare-bones in-memory record whose attributes come from keyword arguments.

    Subclasses set a ``collection`` class attribute naming their table in the
    ``db`` mapping passed to :meth:`insert`.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def update(self, **kwargs):
        """Set (or overwrite) attributes from the given keyword arguments."""
        vars(self).update(kwargs)

    @classmethod
    def insert(cls, db, **kwargs):
        """Create a record, store it in ``db[cls.collection]`` and return it.

        An explicit ``id`` keyword is used as-is (handy for fixtures); otherwise
        random ids in 1..9999 are drawn until an unused one is found.
        """
        table = db[cls.collection]
        if "id" in kwargs:
            record_id = kwargs.pop("id")
        else:
            record_id = random.randint(1, 9999)
            while record_id in table:
                record_id = random.randint(1, 9999)
        record = cls(id=record_id, **kwargs)
        table[record_id] = record
        return record


class User(Model):
    """User record; instances are stored in ``db["users"]`` by ``Model.insert``."""

    collection = "users"


# Fake in-memory database: maps a model's ``collection`` name to a dict of
# records keyed by id (see ``Model.insert``).
db = {"users": {}}

##### Views #####


@route("/", methods=["GET"])
def index(name: fields.Str(load_default="Friend") = "Friend"):  # noqa: F821
    """Greet the caller by ``name``, defaulting to "Friend".

    The Python default mirrors ``load_default``. Non-POST requests are parsed
    with ``partial``, which can make marshmallow skip a missing field entirely,
    so the view must not rely on the parser always supplying ``name``.
    """
    return {"message": f"Hello, {name}!"}


@route("/add", methods=["POST"])
def add(x: fields.Float(required=True), y: fields.Float(required=True)):
    """Return the sum of the two required numbers ``x`` and ``y``."""
    total = x + y
    return {"result": total}


class UserSchema(Schema):
    """Output schema used to serialize ``User`` records."""

    # The id is generated by ``Model.insert``, so it is output-only.
    id = fields.Integer(dump_only=True)
    username = fields.String(required=True)
    first_name = fields.String()
    last_name = fields.String()


@route("/users/<int:user_id>", methods=["GET", "PATCH"])
def user_detail(user_id, username: fields.Str(required=True) = None) -> UserSchema():
    """Fetch a user by id (GET) or change its username (PATCH).

    Returns the ``User`` record for the ``UserSchema`` return annotation, or a
    ``({"message": ...}, 404)`` tuple when the id is unknown.
    """
    user = db["users"].get(user_id)
    if user is None:
        return {"message": "User not found"}, 404
    # Non-POST requests are parsed partially, so a PATCH that omits
    # ``username`` arrives with the ``None`` default; only overwrite the stored
    # value when one was actually sent.
    if request.method == "PATCH" and username is not None:
        user.update(username=username)
    return user


# Return validation errors as JSON
# Return validation errors as JSON
@app.errorhandler(422)
@app.errorhandler(400)
def handle_error(err):
    """Render 400/422 errors as ``{"errors": messages}`` JSON.

    ``err.data`` (when present) may carry ``messages`` and ``headers``. Errors
    raised outside webargs (e.g. a plain werkzeug ``BadRequest``) have no
    ``data`` attribute, so an empty mapping is used instead of crashing.
    """
    data = getattr(err, "data", None) or {}
    messages = data.get("messages", ["Invalid request."])
    headers = data.get("headers", None)
    if headers:
        return {"errors": messages}, err.code, headers
    return {"errors": messages}, err.code


if __name__ == "__main__":
    # Seed the fake database with one known user, then start the dev server.
    fixture = {
        "id": 42,
        "username": "fred",
        "first_name": "Freddie",
        "last_name": "Mercury",
    }
    User.insert(db=db, **fixture)
    app.run(port=5001, debug=True)