File: test_regression.py

package info (click to toggle)
flask 1.0.2-3
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 2,224 kB
  • sloc: python: 8,975; makefile: 55; pascal: 51; sql: 22
file content (99 lines) | stat: -rw-r--r-- 2,358 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# -*- coding: utf-8 -*-
"""
    tests.regression
    ~~~~~~~~~~~~~~~~~~~~~~~~~~

    Tests regressions.

    :copyright: © 2010 by the Pallets team.
    :license: BSD, see LICENSE for more details.
"""

import gc
import sys
import threading

import pytest
from werkzeug.exceptions import NotFound

import flask

# Held for the whole body of an ``assert_no_leak`` block.  That context
# manager switches the process-wide garbage collector off and on, so only
# one such block may be active at a time.
_gc_lock = threading.Lock()


class assert_no_leak(object):
    """Context manager that fails the test if the body leaks objects.

    On entry the number of objects tracked by the garbage collector is
    recorded; on exit the test is failed via ``pytest.fail`` when that
    number has grown.  Automatic garbage collection is disabled and
    ``_gc_lock`` is held for the duration of the block; both are always
    restored on exit, even when the leak check itself fails.
    """

    def __enter__(self):
        # Take the lock *before* touching the collector state: otherwise
        # the ``gc.enable()`` of the block currently holding the lock
        # would undo the ``gc.disable()`` of a thread waiting here.
        _gc_lock.acquire()
        try:
            gc.disable()
            loc = flask._request_ctx_stack._local

            # Force Python to track this dictionary at all times.
            # This is necessary since Python only starts tracking
            # dicts if they contain mutable objects.  It's a horrible,
            # horrible hack but makes this kinda testable.
            loc.__storage__['FOOO'] = [1, 2, 3]

            gc.collect()
            self.old_objects = len(gc.get_objects())
        except BaseException:
            # ``__exit__`` is not called when ``__enter__`` raises, so
            # undo the setup here before propagating the error.
            gc.enable()
            _gc_lock.release()
            raise

    def __exit__(self, exc_type, exc_value, tb):
        try:
            gc.collect()
            new_objects = len(gc.get_objects())
            if new_objects > self.old_objects:
                pytest.fail('Example code leaked')
        finally:
            # ``pytest.fail`` raises, so the cleanup has to live in a
            # ``finally`` clause; otherwise a detected leak would leave
            # the lock held and the collector disabled for later tests.
            _gc_lock.release()
            gc.enable()


def test_memory_consumption():
    """Serve a templated page repeatedly and check that nothing leaks.

    One warm-up request is made first so that caches are populated, then
    ten more requests are issued inside ``assert_no_leak``.
    """
    app = flask.Flask(__name__)

    @app.route('/')
    def index():
        return flask.render_template('simple_template.html', whiskey=42)

    def request_index_once():
        with app.test_client() as client:
            response = client.get('/')
            assert response.status_code == 200
            assert response.data == b'<h1>42</h1>'

    # Warm-up request: trigger caches before objects are counted.
    request_index_once()

    # The leak check is skipped on PyPy (detected through
    # ``sys.pypy_translation_info``) and on interpreters older than 2.7.
    running_on_pypy = hasattr(sys, 'pypy_translation_info')
    if running_on_pypy or sys.version_info < (2, 7):
        return

    with assert_no_leak():
        for _ in range(10):
            request_index_once()


def test_safe_join_toplevel_pardir():
    """``safe_join('/foo', '..')`` must raise ``NotFound``."""
    from flask import helpers

    with pytest.raises(NotFound):
        helpers.safe_join('/foo', '..')


def test_aborting(app):
    """Check aborting with a redirect and a custom exception handler.

    ``/`` aborts with a redirect response pointing at ``/test``, and
    ``/test`` raises ``Foo``, whose registered error handler returns the
    exception's ``whatever`` attribute as a string.
    """
    class Foo(Exception):
        whatever = 42

    @app.errorhandler(Foo)
    def handle_foo(error):
        return str(error.whatever)

    @app.route('/')
    def index():
        raise flask.abort(flask.redirect(flask.url_for('test')))

    @app.route('/test')
    def test():
        raise Foo()

    with app.test_client() as client:
        # The aborted view must surface the redirect target.
        redirected = client.get('/')
        location = redirected.headers['Location']
        assert location == 'http://localhost/test'

        # The custom handler must turn ``Foo`` into a response body.
        handled = client.get('/test')
        assert handled.data == b'42'