1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
import pytest
from flask import Blueprint
from flask import g
from flask import render_template_string
from flask_wtf import FlaskForm
from flask_wtf.csrf import CSRFError
from flask_wtf.csrf import CSRFProtect
from flask_wtf.csrf import generate_csrf
@pytest.fixture
def app(app):
    """Extend the base ``app`` fixture with CSRF protection.

    Registers ``CSRFProtect`` on the app, adds an ``index`` view on ``/``
    that accepts GET and POST, and installs an ``after_request`` hook that
    puts a freshly generated token into the ``X-CSRF-Token`` response header
    so tests can read it back from any response.
    """
    CSRFProtect(app)

    def index():
        # Intentionally empty: the tests only care whether a request gets
        # past the CSRF check. Presumably the base ``app`` fixture turns a
        # ``None`` return into an empty 200 response -- confirm in conftest.
        # The name ``index`` is referenced by test_exempt_view; keep it.
        pass

    def add_csrf_header(response):
        fresh_token = generate_csrf()
        response.headers.set("X-CSRF-Token", fresh_token)
        return response

    app.add_url_rule("/", view_func=index, methods=["GET", "POST"])
    app.after_request(add_csrf_header)
    return app
@pytest.fixture
def csrf(app):
    """Return the CSRF extension object stored in ``app.extensions["csrf"]``."""
    registered = app.extensions
    return registered["csrf"]
def test_render_token(req_ctx):
    """``csrf_token()`` in a template renders the same value as ``generate_csrf``."""
    # req_ctx is requested only for its side effect (presumably an active
    # request context, which token generation needs).
    expected = generate_csrf()
    rendered = render_template_string("{{ csrf_token() }}")
    assert rendered == expected
def test_protect(app, client, app_ctx):
    """CSRFProtect rejects a POST that carries no valid token.

    Also checks that:

    * each of two config switches turns the check off;
    * OPTIONS requests and unknown routes are not turned into CSRF errors;
    * the token is accepted from a form field, from a prefixed field ending
      in ``csrf_token``, and from the ``X-CSRF-Token`` header;
    * an empty prefixed field is rejected.
    """
    missing = client.post("/")
    assert missing.status_code == 400
    assert "The CSRF token is missing." in missing.get_data(as_text=True)

    # Disable the check with each switch in turn, then re-enable it so the
    # remaining assertions run with protection active.
    for switch in ("WTF_CSRF_ENABLED", "WTF_CSRF_CHECK_DEFAULT"):
        app.config[switch] = False
        assert client.post("/").get_data() == b""
        app.config[switch] = True

    assert client.options("/").status_code == 200
    # An unknown route yields 404, not a CSRF 400.
    assert client.post("/not-found").status_code == 404

    page = client.get("/")
    assert page.status_code == 200
    token = page.headers["X-CSRF-Token"]

    def status_of(**request_kwargs):
        return client.post("/", **request_kwargs).status_code

    assert status_of(data={"csrf_token": token}) == 200
    assert status_of(data={"prefix-csrf_token": token}) == 200
    assert status_of(data={"prefix-csrf_token": ""}) == 400
    assert status_of(headers={"X-CSRF-Token": token}) == 200
def test_same_origin(client):
    """Over HTTPS, a valid token also needs a Referer that matches the host.

    A missing Referer is rejected. A Referer that differs in scheme, host,
    or port is also rejected. Only an exact ``https://localhost/`` origin
    is accepted.
    """
    token = client.get("/").headers["X-CSRF-Token"]

    def secure_post(referrer=None):
        headers = {"X-CSRF-Token": token}
        if referrer is not None:
            headers["Referer"] = referrer
        return client.post("/", base_url="https://localhost", headers=headers)

    no_referrer = secure_post().get_data(as_text=True)
    assert "The referrer header is missing." in no_referrer

    # Wrong scheme, wrong host, wrong port -- in that order.
    for mismatched in (
        "http://localhost/",
        "https://other/",
        "https://localhost:8080/",
    ):
        body = secure_post(mismatched).get_data(as_text=True)
        assert "The referrer does not match the host." in body

    assert secure_post("https://localhost/").status_code == 200
def test_form_csrf_short_circuit(app, client):
    """A FlaskForm trusts a request that CSRFProtect has already validated."""

    @app.route("/skip", methods=["POST"])
    def skip():
        # Set once CSRFProtect has accepted the header token for this request.
        assert g.get("csrf_valid")
        # The form gets no formdata and therefore no token of its own. It
        # can only validate if it honours the earlier CSRFProtect result.
        form = FlaskForm(None)
        assert form.validate()

    header = {"X-CSRF-Token": client.get("/").headers["X-CSRF-Token"]}
    assert client.post("/skip", headers=header).status_code == 200
def test_exempt_view(app, csrf, client):
    """A view can be exempted with the decorator or by its dotted name."""

    # The view's name matters: exemption is keyed on "<module>.<name>".
    @app.route("/exempt", methods=["POST"])
    @csrf.exempt
    def exempt():
        pass

    assert client.post("/exempt").status_code == 200

    # Exempt the fixture's ``index`` view after the fact, by string.
    csrf.exempt("test_csrf_extension.index")
    assert client.post("/").status_code == 200
def test_manual_protect(app, csrf, client):
    """An exempt view can still enforce CSRF by calling ``csrf.protect()``."""

    @app.route("/manual", methods=["GET", "POST"])
    @csrf.exempt
    def manual():
        csrf.protect()

    # GET is allowed through; a token-less POST is rejected by the explicit call.
    assert client.get("/manual").status_code == 200
    assert client.post("/manual").status_code == 400
def test_exempt_blueprint(app, csrf, client):
    """A view on an exempted blueprint accepts a POST without a token."""
    blueprint = Blueprint("exempt", __name__, url_prefix="/exempt")
    # Exempted before any route is added to it.
    csrf.exempt(blueprint)

    @blueprint.route("/", methods=["POST"])
    def index():
        pass

    app.register_blueprint(blueprint)
    assert client.post("/exempt/").status_code == 200
def test_exempt_nested_blueprint(app, csrf, client):
    """Exemption still applies when the blueprint is nested in a parent."""
    parent = Blueprint("exempt1", __name__, url_prefix="/")
    child = Blueprint("exempt2", __name__, url_prefix="/exempt")
    # Only the inner blueprint is exempted.
    csrf.exempt(child)

    @child.route("/", methods=["POST"])
    def index():
        pass

    parent.register_blueprint(child)
    app.register_blueprint(parent)
    assert client.post("/exempt/").status_code == 200
def test_error_handler(app, client):
    """A user-registered ``CSRFError`` handler produces the failure response."""

    def handle_csrf_error(e):
        return e.description.lower()

    app.register_error_handler(CSRFError, handle_csrf_error)
    body = client.post("/").get_data(as_text=True)
    assert body == "the csrf token is missing."
def test_validate_error_logged(client, monkeypatch):
    """A rejected request logs the failure reason exactly once via ``logger.info``."""
    from flask_wtf.csrf import logger

    # Replace logger.info with a recorder that keeps each message it receives.
    logged = []
    monkeypatch.setattr(logger, "info", logged.append)

    client.post("/")
    assert logged == ["The CSRF token is missing."]
|