File: schema_example.py

package info (click to toggle)
python-webargs 8.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 696 kB
  • sloc: python: 4,907; makefile: 149
file content (151 lines) | stat: -rw-r--r-- 4,067 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Example of using a marshmallow Schema for both request input and response
output with a `use_schema` decorator.

Run the app:

    $ python examples/schema_example.py

Try the following with httpie (a cURL-like utility, http://httpie.org):

    $ pip install httpie
    $ http GET :5001/users/
    $ http GET :5001/users/42
    $ http POST :5001/users/ username=brian first_name=Brian last_name=May
    $ http PATCH :5001/users/42 username=freddie
    $ http GET :5001/users/ limit==1
"""

import functools
import random

from flask import Flask, request
from marshmallow import Schema, fields, post_dump, validate

from webargs.flaskparser import parser, use_kwargs

# The Flask application object; the routes and error handlers below register on it.
app = Flask(import_name=__name__)

##### Fake database and model #####


class Model:
    """Minimal in-memory record: keyword arguments become instance attributes.

    Subclasses set ``collection`` to the key of the ``db`` dict that holds
    their records, stored as a mapping of id -> record.
    """

    # Inclusive upper bound for randomly generated ids (the lower bound is 1).
    max_id = 9999

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def update(self, **kwargs):
        """Set (or overwrite) attributes from keyword arguments."""
        self.__dict__.update(kwargs)

    @classmethod
    def insert(cls, db, **kwargs):
        """Create a record, store it in ``db[cls.collection]`` and return it.

        An explicit ``id`` keyword (used for setting up fixtures) is taken
        as-is and replaces any existing record with that id. Otherwise a random
        unused id in ``1..cls.max_id`` is assigned.

        Raises:
            RuntimeError: if no ``id`` is given and every id in
                ``1..cls.max_id`` is already taken, instead of retrying forever.
        """
        collection = db[cls.collection]
        if "id" in kwargs:  # for setting up fixtures
            new_id = kwargs.pop("id")
        else:
            # Count ids already used inside the random range; if it is full,
            # the retry loop below could never terminate.
            taken = sum(
                1
                for key in collection
                if isinstance(key, int) and 1 <= key <= cls.max_id
            )
            if taken >= cls.max_id:
                raise RuntimeError(
                    f"no free id left in 1..{cls.max_id} for {cls.__name__}"
                )
            new_id = random.randint(1, cls.max_id)
            while new_id in collection:
                new_id = random.randint(1, cls.max_id)
        new_record = cls(id=new_id, **kwargs)
        collection[new_id] = new_record
        return new_record


class User(Model):
    """A user record; ``Model.insert`` stores instances in ``db["users"]``."""

    collection = "users"  # key of this model's records in the ``db`` dict


# In-memory stand-in for a database: maps a collection name (a model's
# ``collection`` attribute) to a dict of {id: record}.
db = {"users": {}}


##### use_schema #####


def use_schema(schema_cls, list_view=False, locations=None, location=None):
    """View decorator that uses a marshmallow schema to
    (1) parse a request's input and
    (2) serialize the view's output to a JSON response.

    The view is wrapped with ``parser.use_args``, so it receives the parsed
    request arguments.

    Args:
        schema_cls: Schema class instantiated per request. It is built with
            ``partial=True`` for every method except POST, so required fields
            are only enforced when creating.
        list_view: If true, dump the view's return value with ``many=True``.
        locations: Legacy spelling of ``location`` (webargs < 6). Accepts a
            string or a one-element sequence; current webargs parses a single
            location per ``use_args`` call.
        location: webargs location to parse (e.g. ``"json"``, ``"query"``).
            ``None`` lets webargs use the parser's default location.

    A view may return ``(obj, status)``; the status is preserved. When the
    status is >= 400, ``obj`` is returned as-is instead of being dumped
    through the schema.

    Raises:
        ValueError: if ``locations`` names more than one location.
    """
    if location is None and locations is not None:
        if isinstance(locations, str):
            location = locations
        else:
            locations = tuple(locations)
            if len(locations) > 1:
                raise ValueError(
                    "webargs parses one location per use_args call; "
                    f"got locations={locations!r}"
                )
            location = locations[0] if locations else None

    def decorator(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            # Only creation (POST) enforces required fields; other methods
            # (GET, PATCH) load a partial payload.
            partial = request.method != "POST"
            schema = schema_cls(partial=partial)
            # The schema depends on the request method, so the use_args
            # wrapper is built per request rather than once at decoration.
            func_with_args = parser.use_args(schema, location=location)(func)
            ret = func_with_args(*args, **kwargs)

            # support (payload, status) tuples
            if isinstance(ret, tuple) and len(ret) == 2 and isinstance(ret[1], int):
                payload, status = ret
                if status >= 400:
                    # Error bodies such as {"message": ...} are not schema
                    # fields; dumping them would silently discard their keys.
                    return payload, status
                return schema.dump(payload, many=list_view), status

            return schema.dump(ret, many=list_view)

        return wrapped

    return decorator


##### Schemas #####


class UserSchema(Schema):
    """Serialization schema for ``User`` records.

    Dumped output is nested under a top-level ``"data"`` key.
    """

    id = fields.Int(dump_only=True)  # output-only; not accepted as input
    username = fields.Str(required=True)
    first_name = fields.Str()
    last_name = fields.Str()

    @post_dump(pass_many=True)
    def wrap_with_envelope(self, payload, many, **kwargs):
        """Place the serialized object (or list of objects) under ``"data"``."""
        envelope = {"data": payload}
        return envelope


##### Routes #####


@app.route("/users/<int:user_id>", methods=["GET", "PATCH"])
@use_schema(UserSchema)
def user_detail(reqargs, user_id):
    """Return one user by id; on PATCH, apply the parsed fields to it first."""
    record = db["users"].get(user_id)
    if not record:
        return {"message": "User not found"}, 404

    should_patch = request.method == "PATCH" and bool(reqargs)
    if should_patch:
        record.update(**reqargs)
    return record


# You can add additional arguments with use_kwargs
@app.route("/users/", methods=["GET", "POST"])
# The location is given to use_kwargs itself: webargs >= 6 does not honour a
# per-field ``location`` argument, so ``limit`` would never be read from the
# query string. Range(min=0) rejects negative limits, which would otherwise
# slice users off the end of the list instead of limiting it.
@use_kwargs(
    {"limit": fields.Int(load_default=10, validate=validate.Range(min=0))},
    location="query",
)
@use_schema(UserSchema, list_view=True)
def user_list(reqargs, limit):
    """List users; on POST, first create a user from the parsed arguments.

    ``limit`` comes from the query string (default 10) and caps how many users
    are returned.
    """
    users = db["users"].values()
    if request.method == "POST":
        User.insert(db=db, **reqargs)
    # ``users`` is a live dict view, so a user inserted above is included.
    return list(users)[:limit]


# Return validation errors as JSON
@app.errorhandler(422)
@app.errorhandler(400)
def handle_validation_error(err):
    """Render 400/422 errors as an ``{"errors": ...}`` JSON body.

    When the error carries a validation exception on ``err.exc``, its
    ``messages`` are returned, along with any ``headers`` found in
    ``err.data``. Other 400/422 errors get a generic message.
    """
    exc = getattr(err, "exc", None)
    if exc:
        # Use .get so a missing ``data`` attribute or "headers" entry means
        # "no extra headers" rather than a KeyError inside the error handler.
        data = getattr(err, "data", None) or {}
        headers = data.get("headers")
        messages = exc.messages
    else:
        headers = None
        messages = ["Invalid request."]
    body = {"errors": messages}
    if headers:
        return body, err.code, headers
    return body, err.code


if __name__ == "__main__":
    # Seed one fixture user (id 42) so the example requests have data to return.
    fixture = dict(id=42, username="fred", first_name="Freddie", last_name="Mercury")
    User.insert(db=db, **fixture)
    app.run(port=5001, debug=True)