1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
|
import pytest
from flask import Flask, render_template_string
from aws_xray_sdk.ext.flask.middleware import XRayMiddleware
from aws_xray_sdk.core.context import Context
from tests.util import get_new_stubbed_recorder
# Minimal Flask application whose routes are exercised by the tests below
app = Flask(import_name=__name__)
@app.route('/ok')
def ok():
    """Return a fixed two-character body with Flask's default status."""
    body = 'ok'
    return body
@app.route('/error')
def error():
    """Respond with a 404 so the traced segment is flagged as an error."""
    status_code = 404
    return 'Not Found', status_code
@app.route('/fault')
def fault():
    """Deliberately raise KeyError so the request is recorded as a fault."""
    empty = {}
    # Looking up a missing key raises KeyError('key'); nothing is returned.
    return empty['key']
@app.route('/template')
def template():
    """Render an inline template so a template-render subsegment gets recorded."""
    source = 'hello template'
    return render_template_string(source)
# Attach the X-Ray middleware to the Flask app, backed by a stubbed recorder
# with sampling disabled and a fresh Context.
recorder = get_new_stubbed_recorder()
recorder.configure(service='test', sampling=False, context=Context())
XRayMiddleware(app, recorder)

# Turn on Flask's testing mode (per Flask docs this also propagates view
# exceptions to the caller instead of converting them to error pages).
app.config.update(TESTING=True)

# NOTE: from this point on `app` names the Werkzeug test client, not the
# Flask application; the tests below issue requests through it.
app = app.test_client()

# Expected request URL template; `{}` is filled with the request path.
BASE_URL = 'http://localhost{}'
@pytest.fixture(autouse=True)
def cleanup():
    """Reset the recorder's trace-entity storage around every test.

    Applied automatically to each test in the module. Entities are cleared
    before the test body runs, so nothing leaks in from an earlier test, and
    again afterwards, so this test leaves nothing behind.
    """
    reset = recorder.clear_trace_entities
    reset()
    yield
    reset()
def test_ok():
    """A successful request is traced with complete request/response metadata."""
    target = '/ok'
    app.get(target)

    traced = recorder.emitter.pop()
    assert not traced.in_progress

    http_meta = traced.http
    req, resp = http_meta['request'], http_meta['response']
    expected_request = {
        'method': 'GET',
        'url': BASE_URL.format(target),
        'client_ip': '127.0.0.1',
    }
    for field, wanted in expected_request.items():
        assert req[field] == wanted

    assert resp['status'] == 200
    # The '/ok' view returns the two-character body 'ok'.
    assert resp['content_length'] == 2
def test_error():
    """A 404 returned by the view is traced and flags the segment as an error."""
    target = '/error'
    app.get(target)

    traced = recorder.emitter.pop()
    assert not traced.in_progress
    assert traced.error

    http_meta = traced.http
    req, resp = http_meta['request'], http_meta['response']
    for field, wanted in (
        ('method', 'GET'),
        ('url', BASE_URL.format(target)),
        ('client_ip', '127.0.0.1'),
    ):
        assert req[field] == wanted

    assert resp['status'] == 404
def test_fault():
    """An unhandled exception in the view is traced as a fault with status 500.

    Because the app runs with ``TESTING`` enabled, the view's ``KeyError`` may
    be propagated to the test client. Only that expected exception is
    tolerated here. Any other exception is a genuine failure and is no longer
    silently swallowed.
    """
    path = '/fault'
    try:
        app.get(path)
    except KeyError:
        # Expected: the '/fault' view deliberately raises KeyError.
        pass

    segment = recorder.emitter.pop()
    assert not segment.in_progress
    assert segment.fault

    # Request metadata is still recorded when the view fails.
    request = segment.http['request']
    assert request['method'] == 'GET'
    assert request['url'] == BASE_URL.format(path)
    assert request['client_ip'] == '127.0.0.1'

    response = segment.http['response']
    assert response['status'] == 500

    exception = segment.cause['exceptions'][0]
    assert exception.type == 'KeyError'
def test_render_template():
    """Rendering a template during a request records a finished 'local' subsegment."""
    app.get('/template')

    traced = recorder.emitter.pop()
    assert not traced.in_progress

    # The template render should appear as at least one subsegment.
    children = traced.subsegments
    assert children

    first = children[0]
    assert first.name
    assert first.namespace == 'local'
    assert not first.in_progress
|