1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
|
import anyio
from aiobotocore.session import AioSession
from ...mock_server import AIOServer
from .. import ClientHTTPStubber
def get_captured_ua_strings(stubber):
    """Return the ``User-Agent`` header of every request the stubber captured.

    Each header value is decoded with ``.decode()`` (default UTF-8) so the
    result is a list of ``str``, in the order the requests were recorded.

    :type stubber: tests.BaseHTTPStubber
    :rtype: list[str]
    """
    user_agents = []
    for captured_request in stubber.requests:
        raw_header = captured_request.headers['User-Agent']
        user_agents.append(raw_header.decode())
    return user_agents
def parse_registered_feature_ids(ua_string):
    """Parse registered feature ids in user agent string.

    The user agent is split on single spaces, and the first field that starts
    with ``m/`` is taken as the feature-id field. Its comma-separated payload
    is then returned, e.g. ``'m/C,B'`` -> ``['C', 'B']``.

    :type ua_string: str
    :rtype: list[str]
    :raises ValueError: if no ``m/`` field is present. The message includes
        the offending user agent so a failing test is easy to diagnose.
    """
    # Stop at the first matching field instead of materializing every match.
    feature_field = next(
        (field for field in ua_string.split(' ') if field.startswith('m/')),
        None,
    )
    if feature_field is None:
        raise ValueError(
            f'No "m/" feature-id field in user agent string: {ua_string!r}'
        )
    return feature_field.removeprefix('m/').split(',')
async def test_user_agent_has_registered_feature_id():
    """A paginated call should carry the ``C`` feature id in its user agent."""
    session = AioSession()
    async with AIOServer() as server:
        async with session.create_client(
            's3',
            endpoint_url=server.endpoint_url,
            aws_access_key_id='xxx',
            aws_secret_access_key='xxx',
        ) as client:
            with ClientHTTPStubber(client) as http_stubber:
                http_stubber.add_response()
                # `paginate()` registers the feature id `'PAGINATOR': 'C'`.
                pages = client.get_paginator('list_buckets').paginate()
                async for _page in pages:
                    pass

            first_user_agent = get_captured_ua_strings(http_stubber)[0]
            assert 'C' in parse_registered_feature_ids(first_user_agent)
async def test_registered_feature_ids_dont_bleed_between_requests():
    """A feature id registered for one request must not leak into the next.

    The same client first runs a waiter, which should carry ``B``. It then
    runs a paginator, which should carry ``C`` but not ``B``.
    """
    session = AioSession()
    async with AIOServer() as server:
        async with session.create_client(
            's3',
            endpoint_url=server.endpoint_url,
            aws_access_key_id='xxx',
            aws_secret_access_key='xxx',
        ) as client:
            with ClientHTTPStubber(client) as http_stubber:
                # Request 1: `wait()` registers `'WAITER': 'B'`.
                http_stubber.add_response()
                bucket_waiter = client.get_waiter('bucket_exists')
                await bucket_waiter.wait(Bucket='mybucket')

                # Request 2: `paginate()` registers `'PAGINATOR': 'C'`.
                http_stubber.add_response()
                pages = client.get_paginator('list_buckets').paginate()
                async for _page in pages:
                    pass

            captured = get_captured_ua_strings(http_stubber)

            waiter_ids = parse_registered_feature_ids(captured[0])
            assert 'B' in waiter_ids

            paginator_ids = parse_registered_feature_ids(captured[1])
            assert 'C' in paginator_ids
            assert 'B' not in paginator_ids
# Despite the "threads" in its name, this test checks that context does not
# bleed across concurrently running async tasks.
async def test_registered_feature_ids_dont_bleed_across_threads():
    """Concurrent tasks sharing one client must not see each other's ids.

    One task runs a waiter, which should carry ``B`` only. A second task runs
    a paginator, which should carry ``C`` only. Both tasks run inside the same
    anyio task group.
    """
    session = AioSession()
    async with AIOServer() as server:
        async with session.create_client(
            's3',
            endpoint_url=server.endpoint_url,
            aws_access_key_id='xxx',
            aws_secret_access_key='xxx',
        ) as client:
            seen_by_waiter = []
            seen_by_paginator = []

            # NOTE(review): both tasks attach a ClientHTTPStubber to the same
            # client at overlapping times; confirm each stubber only answers
            # its own task's request, or this test may be order-sensitive.
            async def run_waiter():
                with ClientHTTPStubber(client) as http_stubber:
                    http_stubber.add_response()
                    # `wait()` registers `'WAITER': 'B'`.
                    bucket_waiter = client.get_waiter('bucket_exists')
                    await bucket_waiter.wait(Bucket='mybucket')
                    user_agent = get_captured_ua_strings(http_stubber)[0]
                    seen_by_waiter.extend(
                        parse_registered_feature_ids(user_agent)
                    )

            async def run_paginator():
                with ClientHTTPStubber(client) as http_stubber:
                    http_stubber.add_response()
                    # `paginate()` registers `'PAGINATOR': 'C'`.
                    pages = client.get_paginator('list_buckets').paginate()
                    async for _page in pages:
                        pass
                    user_agent = get_captured_ua_strings(http_stubber)[0]
                    seen_by_paginator.extend(
                        parse_registered_feature_ids(user_agent)
                    )

            async with anyio.create_task_group() as task_group:
                task_group.start_soon(run_waiter)
                task_group.start_soon(run_paginator)

            assert 'B' in seen_by_waiter
            assert 'C' not in seen_by_waiter
            assert 'C' in seen_by_paginator
            assert 'B' not in seen_by_paginator
|