# Copyright 2012-2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import copy
import datetime
import io
import operator

import pytest
from dateutil.tz import tzoffset, tzutc

import botocore
from botocore import xform_name
from botocore.awsrequest import AWSRequest, HeadersDict
from botocore.compat import json
from botocore.config import Config
from botocore.endpoint_provider import RuleSetEndpoint
from botocore.exceptions import (
    ClientError,
    ConfigNotFound,
    ConnectionClosedError,
    ConnectTimeoutError,
    InvalidDNSNameError,
    InvalidExpressionError,
    InvalidIMDSEndpointError,
    InvalidIMDSEndpointModeError,
    MetadataRetrievalError,
    ReadTimeoutError,
    SSOTokenLoadError,
    UnsupportedOutpostResourceError,
    UnsupportedS3AccesspointConfigurationError,
    UnsupportedS3ArnError,
)
from botocore.model import (
    DenormalizedStructureBuilder,
    OperationModel,
    ServiceModel,
    ShapeResolver,
)
from botocore.regions import EndpointRulesetResolver
from botocore.session import Session
from botocore.utils import (
    ArgumentGenerator,
    ArnParser,
    CachedProperty,
    ContainerMetadataFetcher,
    IMDSRegionProvider,
    InstanceMetadataFetcher,
    InstanceMetadataRegionFetcher,
    InvalidArnException,
    S3ArnParamHandler,
    S3EndpointSetter,
    S3RegionRedirectorv2,
    SSOTokenLoader,
    calculate_sha256,
    calculate_tree_hash,
    datetime2timestamp,
    deep_merge,
    determine_content_length,
    ensure_boolean,
    fix_s3_host,
    get_encoding_from_headers,
    get_service_module_name,
    has_header,
    instance_cache,
    is_json_value_header,
    is_s3_accelerate_url,
    is_valid_endpoint_url,
    is_valid_ipv6_endpoint_url,
    is_valid_uri,
    lowercase_dict,
    merge_dicts,
    normalize_url_path,
    parse_key_val_file,
    parse_key_val_file_contents,
    parse_timestamp,
    parse_to_aware_datetime,
    percent_encode,
    percent_encode_sequence,
    remove_dot_segments,
    resolve_imds_endpoint_mode,
    set_value_from_jmespath,
    switch_host_s3_accelerate,
    switch_to_virtual_host_style,
    validate_jmespath_for_set,
)
from tests import FreezeTime, RawResponse, create_session, mock, unittest

# Module-level fixtures. They are not referenced in this part of the file;
# presumably tests further down use them (TODO confirm).
# DATE is a naive datetime at midnight on 2021-12-10.
DATE = datetime.datetime(2021, 12, 10)
# strftime/strptime pattern: ISO 8601 style with a literal trailing "Z".
DT_FORMAT = "%Y-%m-%dT%H:%M:%SZ"


class TestEnsureBoolean(unittest.TestCase):
    """Checks how ``ensure_boolean`` coerces bools, strings and other types.

    ``assertIs`` is used instead of ``assertEqual`` so that a truthy or
    falsy non-bool return value (for example ``1`` or ``0``) fails the
    test rather than comparing equal to ``True``/``False``.
    """

    def test_boolean_true(self):
        self.assertIs(ensure_boolean(True), True)

    def test_boolean_false(self):
        self.assertIs(ensure_boolean(False), False)

    def test_string_true(self):
        self.assertIs(ensure_boolean('True'), True)

    def test_string_false(self):
        self.assertIs(ensure_boolean('False'), False)

    def test_string_lowercase_true(self):
        self.assertIs(ensure_boolean('true'), True)

    def test_invalid_type_false(self):
        # Anything that is neither a bool nor a string is expected to
        # come back as False.
        self.assertIs(ensure_boolean({'foo': 'bar'}), False)


class TestResolveIMDSEndpointMode(unittest.TestCase):
    """Exercises ``resolve_imds_endpoint_mode`` for combinations of the
    ``ec2_metadata_service_endpoint_mode`` and ``imds_use_ipv6`` settings.
    """

    def create_session_with_config(self, endpoint_mode, imds_use_IPv6):
        # Build a fresh session and set both IMDS-related config
        # variables on it, endpoint mode first.
        config_values = {
            'ec2_metadata_service_endpoint_mode': endpoint_mode,
            'imds_use_ipv6': imds_use_IPv6,
        }
        session = create_session()
        for variable_name, value in config_values.items():
            session.set_config_variable(variable_name, value)
        return session

    def test_resolve_endpoint_mode_no_config(self):
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config(None, None)
        )
        self.assertEqual(resolved, 'ipv4')

    def test_resolve_endpoint_mode_IPv6(self):
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config('IPv6', None)
        )
        self.assertEqual(resolved, 'ipv6')

    def test_resolve_endpoint_mode_IPv4(self):
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config('IPv4', None)
        )
        self.assertEqual(resolved, 'ipv4')

    def test_resolve_endpoint_mode_none_use_IPv6_true(self):
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config(None, True)
        )
        self.assertEqual(resolved, 'ipv6')

    def test_resolve_endpoint_mode_none_use_IPv6_false(self):
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config(None, False)
        )
        self.assertEqual(resolved, 'ipv4')

    def test_resolve_endpoint_mode_IPv6_use_IPv6_false(self):
        # An explicit endpoint mode is expected to win over imds_use_ipv6.
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config('IPv6', False)
        )
        self.assertEqual(resolved, 'ipv6')

    def test_resolve_endpoint_mode_IPv4_use_IPv6_true(self):
        # An explicit endpoint mode is expected to win over imds_use_ipv6.
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config('IPv4', True)
        )
        self.assertEqual(resolved, 'ipv4')

    def test_resolve_endpoint_mode_IPv6_use_IPv6_true(self):
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config('IPv6', True)
        )
        self.assertEqual(resolved, 'ipv6')

    def test_resolve_endpoint_mode_IPv6_mixed_casing_use_IPv6_true(self):
        # NOTE(review): the test name says "use_IPv6_true" but None is
        # passed for imds_use_ipv6 -- confirm which one was intended.
        resolved = resolve_imds_endpoint_mode(
            self.create_session_with_config('iPv6', None)
        )
        self.assertEqual(resolved, 'ipv6')

    def test_resolve_endpoint_mode_invalid_input(self):
        bad_session = self.create_session_with_config('IPv3', True)
        with self.assertRaises(InvalidIMDSEndpointModeError):
            resolve_imds_endpoint_mode(bad_session)


class TestIsJSONValueHeader(unittest.TestCase):
    """``is_json_value_header`` should only accept string shapes that are
    serialized to a header and flagged as ``jsonvalue``.
    """

    def test_no_serialization_section(self):
        # ``serialization`` is deliberately left unset on the mock.
        shape = mock.Mock(type_name='string')
        self.assertFalse(is_json_value_header(shape))

    def test_non_jsonvalue_shape(self):
        shape = mock.Mock(
            serialization={'location': 'header'},
            type_name='string',
        )
        self.assertFalse(is_json_value_header(shape))

    def test_non_header_jsonvalue_shape(self):
        shape = mock.Mock(
            serialization={'jsonvalue': True},
            type_name='string',
        )
        self.assertFalse(is_json_value_header(shape))

    def test_non_string_jsonvalue_shape(self):
        shape = mock.Mock(
            serialization={'location': 'header', 'jsonvalue': True},
            type_name='integer',
        )
        self.assertFalse(is_json_value_header(shape))

    def test_json_value_header(self):
        shape = mock.Mock(
            serialization={'jsonvalue': True, 'location': 'header'},
            type_name='string',
        )
        self.assertTrue(is_json_value_header(shape))


class TestURINormalization(unittest.TestCase):
    """Tests for ``remove_dot_segments`` and ``normalize_url_path``."""

    def test_remove_dot_segments(self):
        # Each entry is (input path, expected path after removal).
        cases = [
            ('../foo', 'foo'),
            ('../../foo', 'foo'),
            ('./foo', 'foo'),
            ('/./', '/'),
            ('/../', '/'),
            ('/foo/bar/baz/../qux', '/foo/bar/qux'),
            ('/foo/..', '/'),
            ('foo/bar/baz', 'foo/bar/baz'),
            ('..', ''),
            ('.', ''),
            ('/.', '/'),
            ('/.foo', '/.foo'),
            ('/..foo', '/..foo'),
            ('', ''),
            ('/a/b/c/./../../g', '/a/g'),
            ('mid/content=5/../6', 'mid/6'),
            # I don't think this is RFC compliant...
            ('//foo//', '/foo/'),
        ]
        for path, expected in cases:
            self.assertEqual(remove_dot_segments(path), expected)

    def test_empty_url_normalization(self):
        normalized = normalize_url_path('')
        self.assertEqual(normalized, '/')


class TestTransformName(unittest.TestCase):
    """Tests ``xform_name`` conversions from camel case to snake case
    (the default) and to an explicitly supplied separator.
    """

    def test_upper_camel_case(self):
        name = 'UpperCamelCase'
        self.assertEqual(xform_name(name), 'upper_camel_case')
        self.assertEqual(xform_name(name, '-'), 'upper-camel-case')

    def test_lower_camel_case(self):
        name = 'lowerCamelCase'
        self.assertEqual(xform_name(name), 'lower_camel_case')
        self.assertEqual(xform_name(name, '-'), 'lower-camel-case')

    def test_consecutive_upper_case(self):
        name = 'HTTPHeaders'
        self.assertEqual(xform_name(name), 'http_headers')
        self.assertEqual(xform_name(name, '-'), 'http-headers')

    def test_consecutive_upper_case_middle_string(self):
        name = 'MainHTTPHeaders'
        self.assertEqual(xform_name(name), 'main_http_headers')
        self.assertEqual(xform_name(name, '-'), 'main-http-headers')

    def test_s3_prefix(self):
        self.assertEqual(xform_name('S3BucketName'), 's3_bucket_name')

    def test_already_snake_cased(self):
        # Names that are already snake cased must come back unchanged.
        for name in ('leave_alone', 's3_bucket_name', 'bucket_s3_name'):
            self.assertEqual(xform_name(name), name)

    def test_special_cases(self):
        # Some patterns don't actually match the rules we expect.
        self.assertEqual(
            xform_name('SwapEnvironmentCNAMEs'), 'swap_environment_cnames'
        )
        hyphenated_cases = [
            ('SwapEnvironmentCNAMEs', 'swap-environment-cnames'),
            ('CreateCachediSCSIVolume', 'create-cached-iscsi-volume'),
            ('DescribeCachediSCSIVolumes', 'describe-cached-iscsi-volumes'),
            ('DescribeStorediSCSIVolumes', 'describe-stored-iscsi-volumes'),
            ('CreateStorediSCSIVolume', 'create-stored-iscsi-volume'),
            ('sourceServerIDs', 'source-server-ids'),
        ]
        for name, expected in hyphenated_cases:
            self.assertEqual(xform_name(name, '-'), expected)

    def test_special_case_ends_with_s(self):
        result = xform_name('GatewayARNs', '-')
        self.assertEqual(result, 'gateway-arns')

    def test_partial_rename(self):
        # The separator must not show up anywhere in the result.
        for separator in ('-', '_'):
            self.assertEqual(xform_name('IPV6', separator), 'ipv6')

    def test_s3_partial_rename(self):
        expected_by_separator = [
            ('-', 's3-resources'),
            ('_', 's3_resources'),
        ]
        for separator, expected in expected_by_separator:
            self.assertEqual(xform_name('s3Resources', separator), expected)


class TestValidateJMESPathForSet(unittest.TestCase):
    """``validate_jmespath_for_set`` must reject expressions that cannot
    be used as an assignment target.
    """

    def test_invalid_exp(self):
        # Wildcard, index, empty and bare-dot expressions are all
        # expected to be rejected.
        invalid_expressions = (
            'Response.*.Name',
            'Response.Things[0]',
            '',
            '.',
        )
        for expression in invalid_expressions:
            # subTest reports which expression failed and still runs
            # the remaining ones.
            with self.subTest(expression=expression):
                with self.assertRaises(InvalidExpressionError):
                    validate_jmespath_for_set(expression)


class TestSetValueFromJMESPath(unittest.TestCase):
    """``set_value_from_jmespath`` should write a value into a nested
    dict at the location named by a dotted expression, creating any
    missing intermediate keys.
    """

    def setUp(self):
        super().setUp()
        # Fresh target dict for every test; the tests mutate it in place.
        self.data = {
            'Response': {
                'Thing': {
                    'Id': 1,
                    'Name': 'Thing #1',
                }
            },
            'Marker': 'some-token',
        }

    def test_single_depth_existing(self):
        set_value_from_jmespath(self.data, 'Marker', 'new-token')
        self.assertEqual(self.data['Marker'], 'new-token')

    def test_single_depth_new(self):
        # assertNotIn reports the key and the container on failure,
        # unlike assertFalse(key in container).
        self.assertNotIn('Limit', self.data)
        set_value_from_jmespath(self.data, 'Limit', 100)
        self.assertEqual(self.data['Limit'], 100)

    def test_multiple_depth_existing(self):
        set_value_from_jmespath(self.data, 'Response.Thing.Name', 'New Name')
        self.assertEqual(self.data['Response']['Thing']['Name'], 'New Name')

    def test_multiple_depth_new(self):
        self.assertNotIn('Brand', self.data)
        set_value_from_jmespath(self.data, 'Brand.New', {'abc': 123})
        self.assertEqual(self.data['Brand']['New']['abc'], 123)


class TestParseEC2CredentialsFile(unittest.TestCase):
    """Tests for parsing ``key=value`` credential files."""

    def test_parse_ec2_content(self):
        parsed = parse_key_val_file_contents(
            "AWSAccessKeyId=a\nAWSSecretKey=b\n"
        )
        self.assertEqual(parsed, {'AWSAccessKeyId': 'a', 'AWSSecretKey': 'b'})

    def test_parse_ec2_content_empty(self):
        parsed = parse_key_val_file_contents("")
        self.assertEqual(parsed, {})

    def test_key_val_pair_with_blank_lines(self):
        # The \n\n has an extra blank between the access/secret keys.
        parsed = parse_key_val_file_contents(
            "AWSAccessKeyId=a\n\nAWSSecretKey=b\n"
        )
        self.assertEqual(parsed, {'AWSAccessKeyId': 'a', 'AWSSecretKey': 'b'})

    def test_key_val_parser_lenient(self):
        # Ignore any line that does not have a '=' char in it.
        parsed = parse_key_val_file_contents(
            "AWSAccessKeyId=a\nNOTKEYVALLINE\nAWSSecretKey=b\n"
        )
        self.assertEqual(parsed, {'AWSAccessKeyId': 'a', 'AWSSecretKey': 'b'})

    def test_multiple_equals_on_line(self):
        # Only the first '=' separates key from value.
        parsed = parse_key_val_file_contents(
            "AWSAccessKeyId=a\nAWSSecretKey=secret_key_with_equals=b\n"
        )
        expected = {
            'AWSAccessKeyId': 'a',
            'AWSSecretKey': 'secret_key_with_equals=b',
        }
        self.assertEqual(parsed, expected)

    def test_os_error_raises_config_not_found(self):
        # An opener that fails with OSError must surface as ConfigNotFound.
        failing_open = mock.Mock(side_effect=OSError())
        with self.assertRaises(ConfigNotFound):
            parse_key_val_file('badfile', _open=failing_open)


class TestParseTimestamps(unittest.TestCase):
    """``parse_timestamp`` should accept ISO 8601, epoch (int or string)
    and RFC 822 inputs and return UTC-aware datetimes.
    """

    def test_parse_iso8601(self):
        expected = datetime.datetime(1970, 1, 1, 0, 10, tzinfo=tzutc())
        self.assertEqual(
            parse_timestamp('1970-01-01T00:10:00.000Z'), expected
        )

    def test_parse_epoch(self):
        expected = datetime.datetime(2008, 9, 23, 12, 26, 40, tzinfo=tzutc())
        self.assertEqual(parse_timestamp(1222172800), expected)

    def test_parse_epoch_zero_time(self):
        expected = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=tzutc())
        self.assertEqual(parse_timestamp(0), expected)

    def test_parse_epoch_as_string(self):
        expected = datetime.datetime(2008, 9, 23, 12, 26, 40, tzinfo=tzutc())
        self.assertEqual(parse_timestamp('1222172800'), expected)

    def test_parse_rfc822(self):
        expected = datetime.datetime(2002, 10, 2, 13, 0, tzinfo=tzutc())
        self.assertEqual(
            parse_timestamp('Wed, 02 Oct 2002 13:00:00 GMT'), expected
        )

    def test_parse_gmt_in_uk_time(self):
        # In the UK the time switches from GMT to BST and back as part of
        # their daylight savings time. time.tzname will therefore report
        # both time zones. dateutil sees that the time zone is a local time
        # zone and so parses it as local time, but it ends up being BST
        # instead of GMT. To remedy this issue we can provide a time zone
        # context which will enforce GMT == UTC.
        expected = datetime.datetime(2002, 10, 2, 13, 0, tzinfo=tzutc())
        with mock.patch('time.tzname', ('GMT', 'BST')):
            parsed = parse_timestamp('Wed, 02 Oct 2002 13:00:00 GMT')
            self.assertEqual(parsed, expected)

    def test_parse_invalid_timestamp(self):
        with self.assertRaises(ValueError):
            parse_timestamp('invalid date')

    def test_parse_timestamp_fails_with_bad_tzinfo(self):
        # The only tzinfo option on offer raises OSError when called,
        # which is expected to end in a RuntimeError.
        broken_tzinfo = mock.Mock()
        broken_tzinfo.__name__ = 'tzinfo'
        broken_tzinfo.side_effect = OSError()
        tzinfo_options = mock.MagicMock(return_value=(broken_tzinfo,))

        patcher = mock.patch(
            'botocore.utils.get_tzinfo_options', tzinfo_options
        )
        with patcher, self.assertRaises(RuntimeError):
            parse_timestamp(0)


class TestDatetime2Timestamp(unittest.TestCase):
    """``datetime2timestamp`` converts datetimes to epoch seconds."""

    def test_datetime2timestamp_naive(self):
        # The expected value (one day in seconds) treats the naive
        # datetime as UTC.
        naive = datetime.datetime(1970, 1, 2)
        self.assertEqual(datetime2timestamp(naive), 86400)

    def test_datetime2timestamp_aware(self):
        # UTC-03:00, so the result is three hours past one day.
        brst = tzoffset("BRST", -10800)
        aware = datetime.datetime(1970, 1, 2, tzinfo=brst)
        self.assertEqual(datetime2timestamp(aware), 97200)


class TestParseToUTCDatetime(unittest.TestCase):
    """``parse_to_aware_datetime`` should always return a timezone-aware
    datetime, whatever form the input takes.
    """

    def test_handles_utc_time(self):
        utc_dt = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=tzutc())
        result = parse_to_aware_datetime(utc_dt)
        self.assertEqual(result, utc_dt)

    def test_handles_other_timezone(self):
        brst = tzoffset("BRST", -10800)
        aware_dt = datetime.datetime(2014, 1, 1, 0, 0, 0, tzinfo=brst)
        result = parse_to_aware_datetime(aware_dt)
        self.assertEqual(result, aware_dt)

    def test_handles_naive_datetime(self):
        # A naive datetime is expected to be tagged as UTC.
        naive_dt = datetime.datetime(1970, 1, 1, 0, 0, 0)
        epoch_utc = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=tzutc())
        result = parse_to_aware_datetime(naive_dt)
        self.assertEqual(result, epoch_utc)

    def test_handles_string_epoch(self):
        epoch_utc = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=tzutc())
        result = parse_to_aware_datetime('0')
        self.assertEqual(result, epoch_utc)

    def test_handles_int_epoch(self):
        epoch_utc = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=tzutc())
        result = parse_to_aware_datetime(0)
        self.assertEqual(result, epoch_utc)

    def test_handles_full_iso_8601(self):
        epoch_utc = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=tzutc())
        result = parse_to_aware_datetime('1970-01-01T00:00:00Z')
        self.assertEqual(result, epoch_utc)

    def test_year_only_iso_8601(self):
        # NOTE(review): despite the name, the input is a full date
        # (year-month-day) with no time part -- confirm intent.
        epoch_utc = datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=tzutc())
        result = parse_to_aware_datetime('1970-01-01')
        self.assertEqual(result, epoch_utc)


class TestCachedProperty(unittest.TestCase):
    """``CachedProperty`` should compute a value once per instance and
    return that same value on later accesses.
    """

    def test_cached_property_same_value(self):
        class Cached:
            @CachedProperty
            def foo(self):
                return 'foo'

        instance = Cached()
        first_read = instance.foo
        second_read = instance.foo
        self.assertEqual(first_read, 'foo')
        self.assertEqual(second_read, 'foo')

    def test_cached_property_only_called_once(self):
        # Note: you would normally never want to cache
        # a property that returns a new value each time,
        # but this is done to demonstrate the caching behavior.

        class Counting:
            def __init__(self):
                self.counter = 0

            @CachedProperty
            def current_value(self):
                self.counter += 1
                return self.counter

        instance = Counting()
        first_read = instance.current_value
        self.assertEqual(first_read, 1)
        # If the property wasn't cached, the next value would be 2,
        # but because it's cached, we know the value will still be 1.
        second_read = instance.current_value
        self.assertEqual(second_read, 1)


class TestArgumentGenerator(unittest.TestCase):
    """Tests the skeletons ``ArgumentGenerator.generate_skeleton`` builds
    for scalar, list, map, nested and recursive shapes.
    """

    def setUp(self):
        self.arg_generator = ArgumentGenerator()

    def assert_skeleton_from_model_is(self, model, generated_skeleton):
        # Turn the member definitions into a structure shape, generate
        # its skeleton and compare against the expected value.
        builder = DenormalizedStructureBuilder().with_members(model)
        skeleton = self.arg_generator.generate_skeleton(builder.build_model())
        self.assertEqual(skeleton, generated_skeleton)

    def test_generate_string(self):
        members = {'A': {'type': 'string'}}
        self.assert_skeleton_from_model_is(
            model=members, generated_skeleton={'A': ''}
        )

    def test_generate_string_enum(self):
        # Any one of the enum values is acceptable for the member.
        allowed = ['A', 'B', 'C']
        members = {'A': {'type': 'string', 'enum': allowed}}
        builder = DenormalizedStructureBuilder().with_members(members)
        skeleton = self.arg_generator.generate_skeleton(builder.build_model())

        self.assertIn(skeleton['A'], allowed)

    def test_generate_scalars(self):
        members = {
            'A': {'type': 'string'},
            'B': {'type': 'integer'},
            'C': {'type': 'float'},
            'D': {'type': 'boolean'},
            'E': {'type': 'timestamp'},
            'F': {'type': 'double'},
        }
        expected = {
            'A': '',
            'B': 0,
            'C': 0.0,
            'D': True,
            'E': datetime.datetime(1970, 1, 1, 0, 0, 0),
            'F': 0.0,
        }
        self.assert_skeleton_from_model_is(
            model=members, generated_skeleton=expected
        )

    def test_will_use_member_names_for_string_values(self):
        self.arg_generator = ArgumentGenerator(use_member_names=True)
        members = {
            'A': {'type': 'string'},
            'B': {'type': 'integer'},
            'C': {'type': 'float'},
            'D': {'type': 'boolean'},
        }
        # Only the string member picks up its own name as the value.
        expected = {
            'A': 'A',
            'B': 0,
            'C': 0.0,
            'D': True,
        }
        self.assert_skeleton_from_model_is(
            model=members, generated_skeleton=expected
        )

    def test_will_use_member_names_for_string_values_of_list(self):
        self.arg_generator = ArgumentGenerator(use_member_names=True)
        # We're not using assert_skeleton_from_model_is
        # because we can't really control the name of strings shapes
        # being used in the DenormalizedStructureBuilder. We can only
        # control the name of structures and list shapes.
        shapes = {
            'InputShape': {
                'type': 'structure',
                'members': {
                    'StringList': {'shape': 'StringList'},
                },
            },
            'StringList': {
                'type': 'list',
                'member': {'shape': 'StringType'},
            },
            'StringType': {
                'type': 'string',
            },
        }
        input_shape = ShapeResolver(shapes).get_shape_by_name('InputShape')
        skeleton = self.arg_generator.generate_skeleton(input_shape)

        self.assertEqual(skeleton, {'StringList': ['StringType']})

    def test_generate_nested_structure(self):
        members = {
            'A': {
                'type': 'structure',
                'members': {
                    'B': {'type': 'string'},
                },
            }
        }
        self.assert_skeleton_from_model_is(
            model=members, generated_skeleton={'A': {'B': ''}}
        )

    def test_generate_scalar_list(self):
        members = {
            'A': {'type': 'list', 'member': {'type': 'string'}},
        }
        self.assert_skeleton_from_model_is(
            model=members, generated_skeleton={'A': ['']}
        )

    def test_generate_scalar_map(self):
        members = {
            'A': {
                'type': 'map',
                'key': {'type': 'string'},
                'value': {'type': 'string'},
            }
        }
        expected = {
            'A': {
                'KeyName': '',
            }
        }
        self.assert_skeleton_from_model_is(
            model=members, generated_skeleton=expected
        )

    def test_handles_recursive_shapes(self):
        # We're not using assert_skeleton_from_model_is
        # because we can't use a DenormalizedStructureBuilder,
        # we need a normalized model to represent recursive
        # shapes.
        shapes = {
            'InputShape': {
                'type': 'structure',
                'members': {
                    'A': {'shape': 'RecursiveStruct'},
                    'B': {'shape': 'StringType'},
                },
            },
            'RecursiveStruct': {
                'type': 'structure',
                'members': {
                    'C': {'shape': 'RecursiveStruct'},
                    'D': {'shape': 'StringType'},
                },
            },
            'StringType': {
                'type': 'string',
            },
        }
        input_shape = ShapeResolver(shapes).get_shape_by_name('InputShape')
        skeleton = self.arg_generator.generate_skeleton(input_shape)
        expected = {
            'A': {
                'C': {
                    # For recursive shapes, we'll just show
                    # an empty dict.
                },
                'D': '',
            },
            'B': '',
        }
        self.assertEqual(skeleton, expected)


class TestChecksums(unittest.TestCase):
    """``calculate_sha256`` over file-like objects, in hex and raw form."""

    def test_empty_hash(self):
        empty_body = io.BytesIO(b'')
        digest = calculate_sha256(empty_body, as_hex=True)
        self.assertEqual(
            digest,
            'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855',
        )

    def test_as_hex(self):
        body = io.BytesIO(b'hello world')
        digest = calculate_sha256(body, as_hex=True)
        self.assertEqual(
            digest,
            'b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9',
        )

    def test_as_binary(self):
        # Same input as test_as_hex, but the raw digest bytes are expected.
        body = io.BytesIO(b'hello world')
        expected = (
            b"\xb9M'\xb9\x93M>\x08\xa5.R\xd7\xda}\xab\xfa\xc4\x84\xef"
            b"\xe3zS\x80\xee\x90\x88\xf7\xac\xe2\xef\xcd\xe9"
        )
        digest = calculate_sha256(body, as_hex=False)
        self.assertEqual(digest, expected)


class TestTreeHash(unittest.TestCase):
    # Note that for these tests I've independently verified
    # what the expected tree hashes should be from other
    # SDK implementations.

    def test_empty_tree_hash(self):
        body = io.BytesIO(b'')
        tree_hash = calculate_tree_hash(body)
        self.assertEqual(
            tree_hash,
            'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855',
        )

    def test_tree_hash_less_than_one_mb(self):
        body = io.BytesIO(b'a' * 1024)
        tree_hash = calculate_tree_hash(body)
        self.assertEqual(
            tree_hash,
            '2edc986847e209b4016e141a6dc8716d3207350f416969382d431539bf292e4a',
        )

    def test_tree_hash_exactly_one_mb(self):
        payload = b'a' * (1 * 1024 * 1024)
        tree_hash = calculate_tree_hash(io.BytesIO(payload))
        self.assertEqual(
            tree_hash,
            '9bc1b2a288b26af7257a36277ae3816a7d4f16e89c1e7e77d0a5c48bad62b360',
        )

    def test_tree_hash_multiple_of_one_mb(self):
        body = io.BytesIO(b'a' * (4 * 1024 * 1024))
        tree_hash = calculate_tree_hash(body)
        self.assertEqual(
            tree_hash,
            '9491cb2ed1d4e7cd53215f4017c23ec4ad21d7050a1e6bb636c4f67e8cddb844',
        )

    def test_tree_hash_offset_of_one_mb_multiple(self):
        # Four full megabytes followed by a 20 byte tail.
        body = io.BytesIO(b'a' * (4 * 1024 * 1024) + b'a' * 20)
        tree_hash = calculate_tree_hash(body)
        self.assertEqual(
            tree_hash,
            '12f3cbd6101b981cde074039f6f728071da8879d6f632de8afc7cdf00661b08f',
        )


class TestIsValidEndpointURL(unittest.TestCase):
    """Checks which URLs ``is_valid_endpoint_url`` accepts as endpoints."""

    def test_dns_name_is_valid(self):
        self.assertTrue(is_valid_endpoint_url('https://s3.amazonaws.com/'))

    def test_ip_address_is_allowed(self):
        self.assertTrue(is_valid_endpoint_url('https://10.10.10.10/'))

    def test_path_component_ignored(self):
        self.assertTrue(
            is_valid_endpoint_url('https://foo.bar.com/other/path/')
        )

    def test_can_have_port(self):
        self.assertTrue(is_valid_endpoint_url('https://foo.bar.com:12345/'))

    def test_ip_can_have_port(self):
        self.assertTrue(is_valid_endpoint_url('https://10.10.10.10:12345/'))

    def test_cannot_have_spaces(self):
        self.assertFalse(is_valid_endpoint_url('https://my invalid name/'))

    def test_missing_scheme(self):
        self.assertFalse(is_valid_endpoint_url('foo.bar.com'))

    def test_no_new_lines(self):
        self.assertFalse(is_valid_endpoint_url('https://foo.bar.com\nbar/'))

    def test_long_hostname(self):
        # The scheme is spelled correctly (it used to read "htps") so
        # that the over-long hostname is the only thing that can make
        # this URL invalid.
        long_hostname = 'https://%s.com' % ('a' * 256)
        self.assertFalse(is_valid_endpoint_url(long_hostname))

    def test_hostname_can_end_with_dot(self):
        self.assertTrue(is_valid_endpoint_url('https://foo.bar.com./'))

    def test_hostname_no_dots(self):
        self.assertTrue(is_valid_endpoint_url('https://foo/'))


class TestFixS3Host(unittest.TestCase):
    """``fix_s3_host`` should rewrite path-style S3 URLs into
    virtual-hosted style while leaving ``auth_path`` at the original path.
    """

    def test_fix_s3_host_initial(self):
        req = AWSRequest(
            url='https://s3-us-west-2.amazonaws.com/bucket/key.txt',
            method='PUT',
            headers={},
        )
        fix_s3_host(
            request=req,
            signature_version='s3',
            region_name='us-west-2',
        )
        expected_url = 'https://bucket.s3-us-west-2.amazonaws.com/key.txt'
        self.assertEqual(req.url, expected_url)
        self.assertEqual(req.auth_path, '/bucket/key.txt')

    def test_fix_s3_host_only_applied_once(self):
        req = AWSRequest(
            url='https://s3.us-west-2.amazonaws.com/bucket/key.txt',
            method='PUT',
            headers={},
        )
        # Calling the handler a second time should not affect the
        # end result.
        for _ in range(2):
            fix_s3_host(
                request=req,
                signature_version='s3',
                region_name='us-west-2',
            )
        expected_url = 'https://bucket.s3.us-west-2.amazonaws.com/key.txt'
        self.assertEqual(req.url, expected_url)
        # This was a bug previously.  We want to make sure that
        # calling fix_s3_host() again does not alter the auth_path.
        # Otherwise we'll get signature errors.
        self.assertEqual(req.auth_path, '/bucket/key.txt')

    def test_dns_style_not_used_for_get_bucket_location(self):
        location_url = 'https://s3-us-west-2.amazonaws.com/bucket?location'
        req = AWSRequest(url=location_url, method='GET', headers={})
        fix_s3_host(
            request=req,
            signature_version='s3',
            region_name='us-west-2',
        )
        # The request url should not have been modified because this is
        # a request for GetBucketLocation.
        self.assertEqual(req.url, location_url)

    def test_can_provide_default_endpoint_url(self):
        req = AWSRequest(
            url='https://s3-us-west-2.amazonaws.com/bucket/key.txt',
            method='PUT',
            headers={},
        )
        fix_s3_host(
            request=req,
            signature_version='s3',
            region_name='us-west-2',
            default_endpoint_url='foo.s3.amazonaws.com',
        )
        expected_url = 'https://bucket.foo.s3.amazonaws.com/key.txt'
        self.assertEqual(req.url, expected_url)

    def test_no_endpoint_url_uses_request_url(self):
        req = AWSRequest(
            url='https://s3-us-west-2.amazonaws.com/bucket/key.txt',
            method='PUT',
            headers={},
        )
        fix_s3_host(
            request=req,
            signature_version='s3',
            region_name='us-west-2',
            # A value of None means use the url in the current request.
            default_endpoint_url=None,
        )
        expected_url = 'https://bucket.s3-us-west-2.amazonaws.com/key.txt'
        self.assertEqual(req.url, expected_url)


class TestSwitchToVirtualHostStyle(unittest.TestCase):
    """Exercise ``switch_to_virtual_host_style`` on path-style request URLs."""

    def test_switch_to_virtual_host_style(self):
        request = AWSRequest(
            url='https://foo.amazonaws.com/bucket/key.txt',
            method='PUT',
            headers={},
        )
        switch_to_virtual_host_style(
            request=request,
            region_name='us-west-2',
            signature_version='s3',
        )
        # The bucket name moves out of the path and into the host ...
        expected_url = 'https://bucket.foo.amazonaws.com/key.txt'
        self.assertEqual(request.url, expected_url)
        # ... while auth_path keeps the original path-style form.
        self.assertEqual(request.auth_path, '/bucket/key.txt')

    def test_uses_default_endpoint(self):
        request = AWSRequest(
            url='https://foo.amazonaws.com/bucket/key.txt',
            method='PUT',
            headers={},
        )
        switch_to_virtual_host_style(
            request=request,
            region_name='us-west-2',
            signature_version='s3',
            default_endpoint_url='s3.amazonaws.com',
        )
        # The supplied default endpoint replaces the request's own host.
        expected_url = 'https://bucket.s3.amazonaws.com/key.txt'
        self.assertEqual(request.url, expected_url)
        self.assertEqual(request.auth_path, '/bucket/key.txt')

    def test_throws_invalid_dns_name_error(self):
        # The bucket segment here is "mybucket.foo"; the handler is expected
        # to reject it with InvalidDNSNameError.
        request = AWSRequest(
            url='https://foo.amazonaws.com/mybucket.foo/key.txt',
            method='PUT',
            headers={},
        )
        handler_kwargs = {
            'request': request,
            'signature_version': 's3',
            'region_name': 'us-west-2',
        }
        with self.assertRaises(InvalidDNSNameError):
            switch_to_virtual_host_style(**handler_kwargs)

    def test_fix_s3_host_only_applied_once(self):
        request = AWSRequest(
            url='https://foo.amazonaws.com/bucket/key.txt',
            method='PUT',
            headers={},
        )
        # Run the handler twice on the same request; the second pass should
        # not affect the end result.
        for _ in range(2):
            switch_to_virtual_host_style(
                request=request,
                region_name='us-west-2',
                signature_version='s3',
            )
        expected_url = 'https://bucket.foo.amazonaws.com/key.txt'
        self.assertEqual(request.url, expected_url)
        # This was a bug previously.  Calling the handler a second time must
        # not alter the auth_path, otherwise we'll get signature errors.
        self.assertEqual(request.auth_path, '/bucket/key.txt')

    def test_virtual_host_style_for_make_bucket(self):
        # The URL names only the bucket, with no key after it.
        request = AWSRequest(
            url='https://foo.amazonaws.com/bucket', method='PUT', headers={}
        )
        switch_to_virtual_host_style(
            request=request,
            region_name='us-west-2',
            signature_version='s3',
        )
        self.assertEqual(request.url, 'https://bucket.foo.amazonaws.com/')

    def test_virtual_host_style_not_used_for_get_bucket_location(self):
        location_url = 'https://foo.amazonaws.com/bucket?location'
        request = AWSRequest(method='GET', headers={}, url=location_url)
        switch_to_virtual_host_style(
            request=request,
            region_name='us-west-2',
            signature_version='s3',
        )
        # The request url should not have been modified because this is
        # a request for GetBucketLocation.
        self.assertEqual(request.url, location_url)

    def test_virtual_host_style_not_used_for_list_buckets(self):
        service_root_url = 'https://foo.amazonaws.com/'
        request = AWSRequest(method='GET', headers={}, url=service_root_url)
        switch_to_virtual_host_style(
            request=request,
            region_name='us-west-2',
            signature_version='s3',
        )
        # The request url should not have been modified because there is
        # no bucket in the path (this is a ListBuckets style request).
        self.assertEqual(request.url, service_root_url)

    def test_is_unaffected_by_sigv4(self):
        request = AWSRequest(
            url='https://foo.amazonaws.com/bucket/key.txt',
            method='PUT',
            headers={},
        )
        # Same expectation as the 's3' signature version: the URL is still
        # rewritten when the signature version is 's3v4'.
        switch_to_virtual_host_style(
            request=request,
            region_name='us-west-2',
            signature_version='s3v4',
            default_endpoint_url='s3.amazonaws.com',
        )
        expected_url = 'https://bucket.s3.amazonaws.com/key.txt'
        self.assertEqual(request.url, expected_url)


class TestSwitchToChunkedEncodingForNonSeekableObjects(unittest.TestCase):
    """Check the headers produced when a request body is a raw stream."""

    def test_switch_to_chunked_encodeing_for_stream_like_object(self):
        body_stream = io.BufferedIOBase(b"some initial binary data")
        request = AWSRequest(
            url='https://foo.amazonaws.com/bucket/key.txt',
            method='POST',
            headers={},
            data=body_stream,
        )
        prepared = request.prepare()
        # The prepared request is expected to carry a chunked
        # Transfer-Encoding header and nothing else.
        expected_headers = {'Transfer-Encoding': 'chunked'}
        self.assertEqual(prepared.headers, expected_headers)


class TestInstanceCache(unittest.TestCase):
    """Tests for the ``instance_cache`` method decorator."""

    class DummyClass:
        # Minimal host object: the decorated methods below store their
        # results in the dict handed to the constructor.
        def __init__(self, cache):
            self._instance_cache = cache

        @instance_cache
        def add(self, x, y):
            return x + y

        @instance_cache
        def sub(self, x, y):
            return x - y

    def setUp(self):
        self.cache = {}

    def test_cache_single_method_call(self):
        calculator = self.DummyClass(self.cache)
        # The first pass creates the single cache entry; the second pass
        # uses the same args and should reuse that entry.
        for _ in range(2):
            self.assertEqual(calculator.add(2, 1), 3)
            self.assertEqual(len(self.cache), 1)

    def test_can_cache_multiple_methods(self):
        calculator = self.DummyClass(self.cache)
        calculator.add(2, 1)

        # Calling a different method adds a second entry to the cache.
        difference = calculator.sub(2, 1)
        self.assertEqual(difference, 1)
        self.assertEqual(len(self.cache), 2)
        self.assertEqual(calculator.sub(2, 1), 1)

    def test_can_cache_kwargs(self):
        calculator = self.DummyClass(self.cache)
        calculator.add(x=2, y=1)
        # The same keyword arguments map to the same single cache entry.
        total = calculator.add(x=2, y=1)
        self.assertEqual(total, 3)
        self.assertEqual(len(self.cache), 1)


class TestMergeDicts(unittest.TestCase):
    """Tests for ``merge_dicts``; each test inspects the first argument
    after the call to see what was merged into it.
    """

    def test_merge_dicts_overrides(self):
        base = {
            'foo': {'bar': {'baz': {'one': 'ORIGINAL', 'two': 'ORIGINAL'}}}
        }
        override = {'foo': {'bar': {'baz': {'one': 'UPDATE'}}}}

        merge_dicts(base, override)
        leaf = base['foo']['bar']['baz']
        # The value from the second dict wins ...
        self.assertEqual(leaf['one'], 'UPDATE')
        # ... and the key it did not mention is preserved.
        self.assertEqual(leaf['two'], 'ORIGINAL')

    def test_merge_dicts_new_keys(self):
        base = {
            'foo': {'bar': {'baz': {'one': 'ORIGINAL', 'two': 'ORIGINAL'}}}
        }
        addition = {'foo': {'bar': {'baz': {'three': 'UPDATE'}}}}

        merge_dicts(base, addition)
        leaf = base['foo']['bar']['baz']
        self.assertEqual(leaf['one'], 'ORIGINAL')
        self.assertEqual(leaf['two'], 'ORIGINAL')
        self.assertEqual(leaf['three'], 'UPDATE')

    def test_merge_empty_dict_does_nothing(self):
        base = {'foo': {'bar': 'baz'}}
        merge_dicts(base, {})
        self.assertEqual(base, {'foo': {'bar': 'baz'}})

    def test_more_than_one_sub_dict(self):
        base = {
            'one': {'inner': 'ORIGINAL', 'inner2': 'ORIGINAL'},
            'two': {'inner': 'ORIGINAL', 'inner2': 'ORIGINAL'},
        }
        override = {'one': {'inner': 'UPDATE'}, 'two': {'inner': 'UPDATE'}}

        merge_dicts(base, override)
        # Both top-level sub dicts get the same treatment.
        for outer_key in ('one', 'two'):
            self.assertEqual(base[outer_key]['inner'], 'UPDATE')
            self.assertEqual(base[outer_key]['inner2'], 'ORIGINAL')

    def test_new_keys(self):
        base = {'one': {'inner': 'ORIGINAL'}, 'two': {'inner': 'ORIGINAL'}}
        # The incoming dict shares no keys with the base, but we'd still
        # expect its content to be merged in.
        addition = {'three': {'foo': {'bar': 'baz'}}}
        merge_dicts(base, addition)
        self.assertEqual(base['three']['foo']['bar'], 'baz')

    def test_list_values_no_append(self):
        base = {'Foo': ['old_foo_value']}
        override = {'Foo': ['new_foo_value']}
        merge_dicts(base, override)
        self.assertEqual(base, {'Foo': ['new_foo_value']})

    def test_list_values_append(self):
        base = {'Foo': ['old_foo_value']}
        override = {'Foo': ['new_foo_value']}
        merge_dicts(base, override, append_lists=True)
        expected = {'Foo': ['old_foo_value', 'new_foo_value']}
        self.assertEqual(base, expected)

    def test_list_values_mismatching_types(self):
        # The existing value is a plain string, the incoming one a list.
        base = {'Foo': 'old_foo_value'}
        override = {'Foo': ['new_foo_value']}
        merge_dicts(base, override, append_lists=True)
        self.assertEqual(base, {'Foo': ['new_foo_value']})

    def test_list_values_missing_key(self):
        base = {}
        addition = {'Foo': ['foo_value']}
        merge_dicts(base, addition, append_lists=True)
        self.assertEqual(base, {'Foo': ['foo_value']})


class TestLowercaseDict(unittest.TestCase):
    """Tests for ``lowercase_dict``."""

    def test_lowercase_dict_empty(self):
        source = {}
        lowered = lowercase_dict(source)
        self.assertEqual(source, lowered)

    def test_lowercase_dict_original_keys_lower(self):
        # Keys that are already lowercase should compare equal afterwards.
        source = {'lower_key1': 1, 'lower_key2': 2}
        lowered = lowercase_dict(source)
        self.assertEqual(source, lowered)

    def test_lowercase_dict_original_keys_mixed(self):
        source = {
            'SOME_KEY': 'value',
            'AnOTher_OnE': 'anothervalue',
        }
        # Only the keys are expected to change case; values stay as given.
        expected = {
            'some_key': 'value',
            'another_one': 'anothervalue',
        }
        lowered = lowercase_dict(source)
        self.assertEqual(expected, lowered)


class TestGetServiceModuleName(unittest.TestCase):
    """Tests for ``get_service_module_name``.

    Several tests edit ``self.service_description`` after the model has been
    built from it, so they rely on the model reading from that same dict.
    """

    def setUp(self):
        metadata = {
            'serviceFullName': 'AWS MyService',
            'apiVersion': '2014-01-01',
            'endpointPrefix': 'myservice',
            'signatureVersion': 'v4',
            'protocol': 'query',
        }
        self.service_description = {
            'metadata': metadata,
            'operations': {},
            'shapes': {},
        }
        self.service_model = ServiceModel(
            self.service_description, 'myservice'
        )

    def test_default(self):
        module_name = get_service_module_name(self.service_model)
        self.assertEqual(module_name, 'MyService')

    def test_client_name_with_amazon(self):
        metadata = self.service_description['metadata']
        metadata['serviceFullName'] = 'Amazon MyService'
        module_name = get_service_module_name(self.service_model)
        self.assertEqual(module_name, 'MyService')

    def test_client_name_using_abreviation(self):
        metadata = self.service_description['metadata']
        metadata['serviceAbbreviation'] = 'Abbreviation'
        module_name = get_service_module_name(self.service_model)
        self.assertEqual(module_name, 'Abbreviation')

    def test_client_name_with_non_alphabet_characters(self):
        metadata = self.service_description['metadata']
        metadata['serviceFullName'] = 'Amazon My-Service'
        module_name = get_service_module_name(self.service_model)
        self.assertEqual(module_name, 'MyService')

    def test_client_name_with_no_full_name_or_abbreviation(self):
        # With neither name present the expected result is 'myservice',
        # the service name the model was constructed with.
        self.service_description['metadata'].pop('serviceFullName')
        module_name = get_service_module_name(self.service_model)
        self.assertEqual(module_name, 'myservice')


class TestPercentEncodeSequence(unittest.TestCase):
    """Tests for ``percent_encode_sequence`` with dict and pair-list input."""

    def test_percent_encode_empty(self):
        encoded = percent_encode_sequence({})
        self.assertEqual(encoded, '')

    def test_percent_encode_special_chars(self):
        # Spaces, '+' and '/' are all expected to be percent-escaped.
        encoded = percent_encode_sequence({'k1': 'with spaces++/'})
        self.assertEqual(encoded, 'k1=with%20spaces%2B%2B%2F')

    def test_percent_encode_string_string_tuples(self):
        pairs = [('k1', 'v1'), ('k2', 'v2')]
        encoded = percent_encode_sequence(pairs)
        self.assertEqual(encoded, 'k1=v1&k2=v2')

    def test_percent_encode_dict_single_pair(self):
        encoded = percent_encode_sequence({'k1': 'v1'})
        self.assertEqual(encoded, 'k1=v1')

    def test_percent_encode_dict_string_string(self):
        mapping = {'k1': 'v1', 'k2': 'v2'}
        encoded = percent_encode_sequence(mapping)
        self.assertEqual(encoded, 'k1=v1&k2=v2')

    def test_percent_encode_single_list_of_values(self):
        # A list value is expected to repeat its key once per element.
        encoded = percent_encode_sequence({'k1': ['a', 'b', 'c']})
        self.assertEqual(encoded, 'k1=a&k1=b&k1=c')

    def test_percent_encode_list_values_of_string(self):
        mapping = {'k1': ['a', 'list'], 'k2': ['another', 'list']}
        encoded = percent_encode_sequence(mapping)
        self.assertEqual(encoded, 'k1=a&k1=list&k2=another&k2=list')


class TestPercentEncode(unittest.TestCase):
    """Tests for ``percent_encode`` with non-string, text and bytes input."""

    def test_percent_encode_obj(self):
        encoded = percent_encode(1)
        self.assertEqual(encoded, '1')

    def test_percent_encode_text(self):
        text_cases = (
            ('', ''),
            ('a', 'a'),
            ('\u0000', '%00'),
            # Codepoint > 0x7f
            ('\u2603', '%E2%98%83'),
            # Codepoint > 0xffff
            ('\U0001f32e', '%F0%9F%8C%AE'),
        )
        for raw_text, expected in text_cases:
            self.assertEqual(percent_encode(raw_text), expected)

    def test_percent_encode_bytes(self):
        bytes_cases = (
            (b'', ''),
            (b'a', 'a'),
            (b'\x00', '%00'),
            # UTF-8 Snowman
            (b'\xe2\x98\x83', '%E2%98%83'),
            # Arbitrary bytes (not valid UTF-8).
            (b'\x80\x00', '%80%00'),
        )
        for raw_bytes, expected in bytes_cases:
            self.assertEqual(percent_encode(raw_bytes), expected)


class TestSwitchHostS3Accelerate(unittest.TestCase):
    """Tests for the ``switch_host_s3_accelerate`` request handler."""

    def setUp(self):
        self.original_url = 'https://s3.amazonaws.com/foo/key.txt'
        self.client_config = Config()
        self.request = AWSRequest(
            url=self.original_url, method='PUT', headers={}
        )
        # The tests place the client config in the request context before
        # calling the handler.
        self.request.context['client_config'] = self.client_config

    def test_switch_host(self):
        switch_host_s3_accelerate(self.request, 'PutObject')
        expected_url = 'https://s3-accelerate.amazonaws.com/foo/key.txt'
        self.assertEqual(self.request.url, expected_url)

    def test_do_not_switch_black_listed_operations(self):
        # It should not get switched for ListBuckets, DeleteBucket, and
        # CreateBucket
        for operation_name in ('ListBuckets', 'DeleteBucket', 'CreateBucket'):
            switch_host_s3_accelerate(self.request, operation_name)
            self.assertEqual(self.request.url, self.original_url)

    def test_uses_original_endpoint_scheme(self):
        # Start from a plain-http URL; the scheme should survive the switch.
        self.request.url = 'http://s3.amazonaws.com/foo/key.txt'
        switch_host_s3_accelerate(self.request, 'PutObject')
        expected_url = 'http://s3-accelerate.amazonaws.com/foo/key.txt'
        self.assertEqual(self.request.url, expected_url)

    def test_uses_dualstack(self):
        self.client_config.s3 = {'use_dualstack_endpoint': True}
        self.original_url = 'https://s3.dualstack.amazonaws.com/foo/key.txt'
        # Build a fresh request against the dualstack URL and attach the
        # updated client config to it.
        self.request = AWSRequest(
            url=self.original_url, method='PUT', headers={}
        )
        self.request.context['client_config'] = self.client_config
        switch_host_s3_accelerate(self.request, 'PutObject')
        expected_url = (
            'https://s3-accelerate.dualstack.amazonaws.com/foo/key.txt'
        )
        self.assertEqual(self.request.url, expected_url)


class TestDeepMerge(unittest.TestCase):
    """Tests for ``deep_merge``; each test inspects the first argument
    after the call to see what was merged into it.
    """

    def test_simple_merge(self):
        target = {'key': 'value'}
        incoming = {'otherkey': 'othervalue'}
        deep_merge(target, incoming)

        self.assertEqual(target, {'key': 'value', 'otherkey': 'othervalue'})

    def test_merge_list(self):
        # Lists are treated as opaque data and so no effort should be made to
        # combine them.
        target = {'key': ['original']}
        incoming = {'key': ['new']}
        deep_merge(target, incoming)
        self.assertEqual(target, {'key': ['new']})

    def test_merge_number(self):
        # The incoming value is always taken, whichever one is larger.
        for current, replacement in ((10, 45), (45, 10)):
            target = {'key': current}
            deep_merge(target, {'key': replacement})
            self.assertEqual(target, {'key': replacement})

    def test_merge_boolean(self):
        # The incoming value is always taken, in both directions.
        for current, replacement in ((False, True), (True, False)):
            target = {'key': current}
            deep_merge(target, {'key': replacement})
            self.assertEqual(target, {'key': replacement})

    def test_merge_string(self):
        target = {'key': 'value'}
        incoming = {'key': 'othervalue'}
        deep_merge(target, incoming)
        self.assertEqual(target, {'key': 'othervalue'})

    def test_merge_overrides_value(self):
        # The incoming value is always taken, even when it's a different
        # type: first a string replaced by a dict ...
        target = {'key': 'original'}
        deep_merge(target, {'key': {'newkey': 'newvalue'}})
        self.assertEqual(target, {'key': {'newkey': 'newvalue'}})

        # ... then a dict replaced by a string.
        target = {'key': {'anotherkey': 'value'}}
        deep_merge(target, {'key': 'newvalue'})
        self.assertEqual(target, {'key': 'newvalue'})

    def test_deep_merge(self):
        target = {
            'first': {
                'second': {'key': 'value', 'otherkey': 'othervalue'},
                'key': 'value',
            }
        }
        incoming = {
            'first': {
                'second': {
                    'otherkey': 'newvalue',
                    'yetanotherkey': 'yetanothervalue',
                }
            }
        }
        # Keys present on only one side are kept; the shared 'otherkey'
        # takes the incoming value.
        expected = {
            'first': {
                'second': {
                    'key': 'value',
                    'otherkey': 'newvalue',
                    'yetanotherkey': 'yetanothervalue',
                },
                'key': 'value',
            }
        }
        deep_merge(target, incoming)
        self.assertEqual(target, expected)


class TestS3RegionRedirector(unittest.TestCase):
    """Tests for ``S3RegionRedirectorv2``.

    The redirector is built around a mocked client whose endpoint ruleset
    resolver has ``construct_endpoint`` replaced by a mock, so no real
    endpoint resolution or network call takes place.
    """

    def setUp(self):
        self.client = mock.Mock()
        self.client._ruleset_resolver = EndpointRulesetResolver(
            endpoint_ruleset_data={
                'version': '1.0',
                'parameters': {},
                'rules': [],
            },
            partition_data={},
            service_model=None,
            builtins={},
            client_context=None,
            event_emitter=None,
            use_ssl=True,
            requested_auth_scheme=None,
        )
        # Endpoint construction is stubbed out; individual tests override
        # the return value when they care about the resolved endpoint.
        self.client._ruleset_resolver.construct_endpoint = mock.Mock(
            return_value=RuleSetEndpoint(
                url='https://new-endpoint.amazonaws.com',
                properties={},
                headers={},
            )
        )
        self.cache = {}
        self.redirector = S3RegionRedirectorv2(None, self.client)
        self.set_client_response_headers({})
        self.operation = mock.Mock()
        self.operation.name = 'foo'

    def set_client_response_headers(self, headers):
        """Configure the mocked ``head_bucket`` call.

        The first call raises a ``ClientError`` whose response metadata
        carries ``headers``; the second call returns a success response
        carrying the same ``headers``.
        """
        error_response = ClientError(
            {
                'Error': {'Code': '', 'Message': ''},
                'ResponseMetadata': {'HTTPHeaders': headers},
            },
            'HeadBucket',
        )
        success_response = {'ResponseMetadata': {'HTTPHeaders': headers}}
        self.client.head_bucket.side_effect = [
            error_response,
            success_response,
        ]

    def test_set_request_url(self):
        old_url = 'https://us-west-2.amazonaws.com/foo'
        new_endpoint = 'https://eu-central-1.amazonaws.com'
        new_url = self.redirector.set_request_url(old_url, new_endpoint)
        self.assertEqual(new_url, 'https://eu-central-1.amazonaws.com/foo')

    def test_set_request_url_keeps_old_scheme(self):
        old_url = 'http://us-west-2.amazonaws.com/foo'
        new_endpoint = 'https://eu-central-1.amazonaws.com'
        new_url = self.redirector.set_request_url(old_url, new_endpoint)
        self.assertEqual(new_url, 'http://eu-central-1.amazonaws.com/foo')

    def test_sets_signing_context_from_cache(self):
        self.cache['foo'] = 'new-region-1'
        self.redirector = S3RegionRedirectorv2(
            None, self.client, cache=self.cache
        )
        params = {'Bucket': 'foo'}
        builtins = {'AWS::Region': 'old-region-1'}
        self.redirector.redirect_from_cache(builtins, params)
        self.assertEqual(builtins.get('AWS::Region'), 'new-region-1')

    def test_only_changes_context_if_bucket_in_cache(self):
        self.cache['foo'] = 'new-region-1'
        self.redirector = S3RegionRedirectorv2(
            None, self.client, cache=self.cache
        )
        # 'bar' has no cache entry, so the region must be left untouched.
        params = {'Bucket': 'bar'}
        builtins = {'AWS::Region': 'old-region-1'}
        self.redirector.redirect_from_cache(builtins, params)
        self.assertEqual(builtins.get('AWS::Region'), 'old-region-1')

    def test_redirect_from_error(self):
        request_dict = {
            'context': {
                's3_redirect': {
                    'bucket': 'foo',
                    'redirected': False,
                    'params': {'Bucket': 'foo'},
                },
                'signing': {
                    'region': 'us-west-2',
                },
            },
            'url': 'https://us-west-2.amazonaws.com/foo',
        }
        response = (
            None,
            {
                'Error': {
                    'Code': 'PermanentRedirect',
                    'Endpoint': 'foo.eu-central-1.amazonaws.com',
                    'Bucket': 'foo',
                },
                'ResponseMetadata': {
                    'HTTPHeaders': {'x-amz-bucket-region': 'eu-central-1'}
                },
            },
        )

        self.client._ruleset_resolver.construct_endpoint.return_value = (
            RuleSetEndpoint(
                url='https://eu-central-1.amazonaws.com/foo',
                properties={
                    'authSchemes': [
                        {
                            'name': 'sigv4',
                            'signingRegion': 'eu-central-1',
                            'disableDoubleEncoding': True,
                        }
                    ]
                },
                headers={},
            )
        )

        redirect_response = self.redirector.redirect_from_error(
            request_dict, response, self.operation
        )

        # The response needs to be 0 so that there is no retry delay
        self.assertEqual(redirect_response, 0)

        self.assertEqual(
            request_dict['url'], 'https://eu-central-1.amazonaws.com/foo'
        )

        expected_signing_context = {
            'region': 'eu-central-1',
            'disableDoubleEncoding': True,
        }
        signing_context = request_dict['context'].get('signing')
        self.assertEqual(signing_context, expected_signing_context)
        self.assertTrue(
            request_dict['context']['s3_redirect'].get('redirected')
        )

    def test_does_not_redirect_if_previously_redirected(self):
        # The context uses the same 's3_redirect' shape as the other tests
        # here, with the 'redirected' flag already set.
        request_dict = {
            'context': {
                's3_redirect': {
                    'bucket': 'foo',
                    'redirected': True,
                    'params': {'Bucket': 'foo'},
                },
                'signing': {'region': 'us-west-2'},
            },
            'url': 'https://us-west-2.amazonaws.com/foo',
        }
        response = (
            None,
            {
                'Error': {
                    'Code': '400',
                    'Message': 'Bad Request',
                },
                'ResponseMetadata': {
                    'HTTPHeaders': {'x-amz-bucket-region': 'us-west-2'}
                },
            },
        )
        # A 400 on HeadObject is redirected when the request has not been
        # redirected before (see test_redirects_400_head_bucket), so a None
        # result here can only come from the 'redirected' flag.
        self.operation.name = 'HeadObject'
        redirect_response = self.redirector.redirect_from_error(
            request_dict, response, self.operation
        )
        self.assertIsNone(redirect_response)

    def test_does_not_redirect_unless_permanentredirect_recieved(self):
        request_dict = {}
        response = (None, {})
        redirect_response = self.redirector.redirect_from_error(
            request_dict, response, self.operation
        )
        self.assertIsNone(redirect_response)
        self.assertEqual(request_dict, {})

    def test_does_not_redirect_if_region_cannot_be_found(self):
        request_dict = {
            'url': 'https://us-west-2.amazonaws.com/foo',
            'context': {
                's3_redirect': {
                    'bucket': 'foo',
                    'redirected': False,
                    'params': {'Bucket': 'foo'},
                },
                'signing': {},
            },
        }
        response = (
            None,
            {
                'Error': {
                    'Code': 'PermanentRedirect',
                    'Endpoint': 'foo.eu-central-1.amazonaws.com',
                    'Bucket': 'foo',
                },
                'ResponseMetadata': {'HTTPHeaders': {}},
            },
        )

        redirect_response = self.redirector.redirect_from_error(
            request_dict, response, self.operation
        )

        self.assertIsNone(redirect_response)

    def test_redirects_301(self):
        request_dict = {
            'url': 'https://us-west-2.amazonaws.com/foo',
            'context': {
                's3_redirect': {
                    'bucket': 'foo',
                    'redirected': False,
                    'params': {'Bucket': 'foo'},
                },
                'signing': {},
            },
        }
        response = (
            None,
            {
                'Error': {'Code': '301', 'Message': 'Moved Permanently'},
                'ResponseMetadata': {
                    'HTTPHeaders': {'x-amz-bucket-region': 'eu-central-1'}
                },
            },
        )

        # Each call gets its own copy of the request: a successful redirect
        # marks the request as redirected, which would otherwise make the
        # second assertion pass regardless of the operation name.
        self.operation.name = 'HeadObject'
        redirect_response = self.redirector.redirect_from_error(
            copy.deepcopy(request_dict), response, self.operation
        )
        self.assertEqual(redirect_response, 0)

        self.operation.name = 'ListObjects'
        redirect_response = self.redirector.redirect_from_error(
            copy.deepcopy(request_dict), response, self.operation
        )
        self.assertIsNone(redirect_response)

    def test_redirects_400_head_bucket(self):
        request_dict = {
            'url': 'https://us-west-2.amazonaws.com/foo',
            'context': {
                's3_redirect': {
                    'bucket': 'foo',
                    'redirected': False,
                    'params': {'Bucket': 'foo'},
                },
                'signing': {},
            },
        }
        response = (
            None,
            {
                'Error': {'Code': '400', 'Message': 'Bad Request'},
                'ResponseMetadata': {
                    'HTTPHeaders': {'x-amz-bucket-region': 'eu-central-1'}
                },
            },
        )

        # Each call gets its own copy of the request so that the state left
        # behind by the first redirect cannot influence the second call.
        self.operation.name = 'HeadObject'
        redirect_response = self.redirector.redirect_from_error(
            copy.deepcopy(request_dict), response, self.operation
        )
        self.assertEqual(redirect_response, 0)

        self.operation.name = 'ListObjects'
        redirect_response = self.redirector.redirect_from_error(
            copy.deepcopy(request_dict), response, self.operation
        )
        self.assertIsNone(redirect_response)

    def test_does_not_redirect_400_head_bucket_no_region_header(self):
        # We should not redirect a 400 Head* if the region header is not
        # present as this will lead to infinitely calling HeadBucket.
        request_dict = {
            'url': 'https://us-west-2.amazonaws.com/foo',
            'context': {'signing': {'bucket': 'foo'}},
        }
        response = (
            None,
            {
                'Error': {'Code': '400', 'Message': 'Bad Request'},
                'ResponseMetadata': {'HTTPHeaders': {}},
            },
        )

        self.operation.name = 'HeadBucket'
        redirect_response = self.redirector.redirect_from_error(
            request_dict, response, self.operation
        )
        head_bucket_calls = self.client.head_bucket.call_count
        self.assertIsNone(redirect_response)
        # We should not have made an additional head bucket call
        self.assertEqual(head_bucket_calls, 0)

    def test_does_not_redirect_if_None_response(self):
        request_dict = {
            'url': 'https://us-west-2.amazonaws.com/foo',
            'context': {'signing': {'bucket': 'foo'}},
        }
        response = None
        redirect_response = self.redirector.redirect_from_error(
            request_dict, response, self.operation
        )
        self.assertIsNone(redirect_response)

    def test_get_region_from_response(self):
        # The region is available directly in the response headers.
        response = (
            None,
            {
                'Error': {
                    'Code': 'PermanentRedirect',
                    'Endpoint': 'foo.eu-central-1.amazonaws.com',
                    'Bucket': 'foo',
                },
                'ResponseMetadata': {
                    'HTTPHeaders': {'x-amz-bucket-region': 'eu-central-1'}
                },
            },
        )
        region = self.redirector.get_bucket_region('foo', response)
        self.assertEqual(region, 'eu-central-1')

    def test_get_region_from_response_error_body(self):
        # No region header, but the error body names the region.
        response = (
            None,
            {
                'Error': {
                    'Code': 'PermanentRedirect',
                    'Endpoint': 'foo.eu-central-1.amazonaws.com',
                    'Bucket': 'foo',
                    'Region': 'eu-central-1',
                },
                'ResponseMetadata': {'HTTPHeaders': {}},
            },
        )
        region = self.redirector.get_bucket_region('foo', response)
        self.assertEqual(region, 'eu-central-1')

    def test_get_region_from_head_bucket_error(self):
        # The response itself carries no region; only the mocked
        # head_bucket call (which raises first) exposes the region header.
        self.set_client_response_headers(
            {'x-amz-bucket-region': 'eu-central-1'}
        )
        response = (
            None,
            {
                'Error': {
                    'Code': 'PermanentRedirect',
                    'Endpoint': 'foo.eu-central-1.amazonaws.com',
                    'Bucket': 'foo',
                },
                'ResponseMetadata': {'HTTPHeaders': {}},
            },
        )
        region = self.redirector.get_bucket_region('foo', response)
        self.assertEqual(region, 'eu-central-1')

    def test_get_region_from_head_bucket_success(self):
        success_response = {
            'ResponseMetadata': {
                'HTTPHeaders': {'x-amz-bucket-region': 'eu-central-1'}
            }
        }
        # Drop the raise-then-succeed sequence installed by setUp so that
        # head_bucket succeeds on the first call.
        self.client.head_bucket.side_effect = None
        self.client.head_bucket.return_value = success_response
        response = (
            None,
            {
                'Error': {
                    'Code': 'PermanentRedirect',
                    'Endpoint': 'foo.eu-central-1.amazonaws.com',
                    'Bucket': 'foo',
                },
                'ResponseMetadata': {'HTTPHeaders': {}},
            },
        )
        region = self.redirector.get_bucket_region('foo', response)
        self.assertEqual(region, 'eu-central-1')

    def test_no_redirect_from_error_for_accesspoint(self):
        # The 'bucket' in the redirect context is an ARN rather than a
        # plain bucket name.
        request_dict = {
            'url': (
                'https://myendpoint-123456789012.s3-accesspoint.'
                'us-west-2.amazonaws.com/key'
            ),
            'context': {
                's3_redirect': {
                    'redirected': False,
                    'bucket': 'arn:aws:s3:us-west-2:123456789012:myendpoint',
                    'params': {},
                }
            },
        }
        response = (
            None,
            {
                'Error': {'Code': '400', 'Message': 'Bad Request'},
                'ResponseMetadata': {
                    'HTTPHeaders': {'x-amz-bucket-region': 'eu-central-1'}
                },
            },
        )

        self.operation.name = 'HeadObject'
        redirect_response = self.redirector.redirect_from_error(
            request_dict, response, self.operation
        )
        self.assertIsNone(redirect_response)

    def test_no_redirect_from_error_for_mrap_accesspoint(self):
        mrap_arn = 'arn:aws:s3::123456789012:accesspoint:mfzwi23gnjvgw.mrap'
        request_dict = {
            'url': (
                'https://mfzwi23gnjvgw.mrap.accesspoint.'
                's3-global.amazonaws.com'
            ),
            'context': {
                's3_redirect': {
                    'redirected': False,
                    'bucket': mrap_arn,
                    'params': {},
                }
            },
        }
        response = (
            None,
            {
                'Error': {'Code': '400', 'Message': 'Bad Request'},
                'ResponseMetadata': {
                    'HTTPHeaders': {'x-amz-bucket-region': 'eu-central-1'}
                },
            },
        )

        self.operation.name = 'HeadObject'
        redirect_response = self.redirector.redirect_from_error(
            request_dict, response, self.operation
        )
        self.assertIsNone(redirect_response)


class TestArnParser(unittest.TestCase):
    """Exercises ``ArnParser.parse_arn`` with valid and malformed ARNs."""

    def setUp(self):
        self.parser = ArnParser()

    def test_parse(self):
        # Every section after the ``arn`` prefix is populated.
        parsed = self.parser.parse_arn(
            'arn:aws:s3:us-west-2:1023456789012:myresource'
        )
        expected = {
            'partition': 'aws',
            'service': 's3',
            'region': 'us-west-2',
            'account': '1023456789012',
            'resource': 'myresource',
        }
        self.assertEqual(parsed, expected)

    def test_parse_invalid_arn(self):
        # Only three colon-separated sections: rejected as an invalid ARN.
        truncated_arn = 'arn:aws:s3'
        with self.assertRaises(InvalidArnException):
            self.parser.parse_arn(truncated_arn)

    def test_parse_arn_with_resource_type(self):
        # A colon inside the trailing section stays part of 'resource'.
        parsed = self.parser.parse_arn(
            'arn:aws:s3:us-west-2:1023456789012:bucket_name:mybucket'
        )
        expected = {
            'partition': 'aws',
            'service': 's3',
            'region': 'us-west-2',
            'account': '1023456789012',
            'resource': 'bucket_name:mybucket',
        }
        self.assertEqual(parsed, expected)

    def test_parse_arn_with_empty_elements(self):
        # Omitted region and account are reported as empty strings.
        parsed = self.parser.parse_arn('arn:aws:s3:::mybucket')
        expected = {
            'partition': 'aws',
            'service': 's3',
            'region': '',
            'account': '',
            'resource': 'mybucket',
        }
        self.assertEqual(parsed, expected)


class TestS3ArnParamHandler(unittest.TestCase):
    """Tests for ``S3ArnParamHandler.handle_arn`` and its registration.

    ``handle_arn`` is called with the operation parameters, a mocked
    operation model and a context dict; the tests inspect how the
    ``Bucket`` parameter and the context are mutated.
    """

    def setUp(self):
        self.arn_handler = S3ArnParamHandler()
        operation_model = mock.Mock(OperationModel)
        operation_model.name = 'GetObject'
        self.model = operation_model

    def test_register(self):
        emitter = mock.Mock()
        self.arn_handler.register(emitter)
        emitter.register.assert_called_with(
            'before-parameter-build.s3', self.arn_handler.handle_arn
        )

    def test_accesspoint_arn(self):
        # Slash-delimited access point resource.
        bucket_arn = 'arn:aws:s3:us-west-2:123456789012:accesspoint/endpoint'
        params = {'Bucket': bucket_arn}
        context = {}
        self.arn_handler.handle_arn(params, self.model, context)
        # The ARN is replaced by the bare access point name ...
        self.assertEqual(params, {'Bucket': 'endpoint'})
        # ... and its parsed pieces are recorded in the context.
        expected_accesspoint = {
            'name': 'endpoint',
            'account': '123456789012',
            'region': 'us-west-2',
            'partition': 'aws',
            'service': 's3',
        }
        self.assertEqual(context, {'s3_accesspoint': expected_accesspoint})

    def test_accesspoint_arn_with_colon(self):
        # Colon-delimited access point resource; same expected outcome as
        # the slash-delimited form.
        bucket_arn = 'arn:aws:s3:us-west-2:123456789012:accesspoint:endpoint'
        params = {'Bucket': bucket_arn}
        context = {}
        self.arn_handler.handle_arn(params, self.model, context)
        self.assertEqual(params, {'Bucket': 'endpoint'})
        expected_accesspoint = {
            'name': 'endpoint',
            'account': '123456789012',
            'region': 'us-west-2',
            'partition': 'aws',
            'service': 's3',
        }
        self.assertEqual(context, {'s3_accesspoint': expected_accesspoint})

    def test_errors_for_non_accesspoint_arn(self):
        # A resource type other than an access point is rejected.
        bucket_arn = 'arn:aws:s3:us-west-2:123456789012:unsupported:resource'
        params = {'Bucket': bucket_arn}
        context = {}
        with self.assertRaises(UnsupportedS3ArnError):
            self.arn_handler.handle_arn(params, self.model, context)

    def test_outpost_arn_with_colon(self):
        # Colon-delimited outpost access point ARN.
        bucket_arn = (
            'arn:aws:s3-outposts:us-west-2:123456789012:outpost:'
            'op-01234567890123456:accesspoint:myaccesspoint'
        )
        params = {'Bucket': bucket_arn}
        context = {}
        self.arn_handler.handle_arn(params, self.model, context)
        self.assertEqual(params, {'Bucket': 'myaccesspoint'})
        expected_accesspoint = {
            'name': 'myaccesspoint',
            'outpost_name': 'op-01234567890123456',
            'account': '123456789012',
            'region': 'us-west-2',
            'partition': 'aws',
            'service': 's3-outposts',
        }
        self.assertEqual(context, {'s3_accesspoint': expected_accesspoint})

    def test_outpost_arn_with_slash(self):
        # Slash-delimited outpost access point ARN.
        bucket_arn = (
            'arn:aws:s3-outposts:us-west-2:123456789012:outpost/'
            'op-01234567890123456/accesspoint/myaccesspoint'
        )
        params = {'Bucket': bucket_arn}
        context = {}
        self.arn_handler.handle_arn(params, self.model, context)
        self.assertEqual(params, {'Bucket': 'myaccesspoint'})
        expected_accesspoint = {
            'name': 'myaccesspoint',
            'outpost_name': 'op-01234567890123456',
            'account': '123456789012',
            'region': 'us-west-2',
            'partition': 'aws',
            'service': 's3-outposts',
        }
        self.assertEqual(context, {'s3_accesspoint': expected_accesspoint})

    def test_outpost_arn_errors_for_missing_fields(self):
        # The access point name after 'accesspoint' is missing entirely.
        bucket_arn = (
            'arn:aws:s3-outposts:us-west-2:123456789012:outpost/'
            'op-01234567890123456/accesspoint'
        )
        params = {'Bucket': bucket_arn}
        with self.assertRaises(UnsupportedOutpostResourceError):
            self.arn_handler.handle_arn(params, self.model, {})

    def test_outpost_arn_errors_for_empty_fields(self):
        # The outpost id between the two slashes is empty.
        bucket_arn = (
            'arn:aws:s3-outposts:us-west-2:123456789012:outpost/'
            '/accesspoint/myaccesspoint'
        )
        params = {'Bucket': bucket_arn}
        with self.assertRaises(UnsupportedOutpostResourceError):
            self.arn_handler.handle_arn(params, self.model, {})

    def test_ignores_bucket_names(self):
        # A plain bucket name leaves both params and context untouched.
        params = {'Bucket': 'mybucket'}
        context = {}
        self.arn_handler.handle_arn(params, self.model, context)
        self.assertEqual(params, {'Bucket': 'mybucket'})
        self.assertEqual(context, {})

    def test_ignores_create_bucket(self):
        # For CreateBucket even an access point ARN is passed through as-is.
        self.model.name = 'CreateBucket'
        arn = 'arn:aws:s3:us-west-2:123456789012:accesspoint/endpoint'
        params = {'Bucket': arn}
        context = {}
        self.arn_handler.handle_arn(params, self.model, context)
        self.assertEqual(params, {'Bucket': arn})
        self.assertEqual(context, {})


class TestS3EndpointSetter(unittest.TestCase):
    def setUp(self):
        self.operation_name = 'GetObject'
        self.signature_version = 's3v4'
        self.region_name = 'us-west-2'
        self.service = 's3'
        self.account = '123456789012'
        self.bucket = 'mybucket'
        self.key = 'key.txt'
        self.accesspoint_name = 'myaccesspoint'
        self.outpost_name = 'op-123456789012'
        self.partition = 'aws'
        self.endpoint_resolver = mock.Mock()
        self.dns_suffix = 'amazonaws.com'
        self.endpoint_resolver.construct_endpoint.return_value = {
            'dnsSuffix': self.dns_suffix
        }
        self.endpoint_setter = self.get_endpoint_setter()

    def get_endpoint_setter(self, **kwargs):
        setter_kwargs = {
            'endpoint_resolver': self.endpoint_resolver,
            'region': self.region_name,
        }
        setter_kwargs.update(kwargs)
        return S3EndpointSetter(**setter_kwargs)

    def get_s3_request(
        self, bucket=None, key=None, scheme='https://', querystring=None
    ):
        url = scheme + 's3.us-west-2.amazonaws.com/'
        if bucket:
            url += bucket
        if key:
            url += '/%s' % key
        if querystring:
            url += '?%s' % querystring
        return AWSRequest(method='GET', headers={}, url=url)

    def get_s3_outpost_request(self, **s3_request_kwargs):
        request = self.get_s3_request(
            self.accesspoint_name, **s3_request_kwargs
        )
        accesspoint_context = self.get_s3_accesspoint_context(
            name=self.accesspoint_name, outpost_name=self.outpost_name
        )
        request.context['s3_accesspoint'] = accesspoint_context
        return request

    def get_s3_accesspoint_request(
        self,
        accesspoint_name=None,
        accesspoint_context=None,
        **s3_request_kwargs
    ):
        if not accesspoint_name:
            accesspoint_name = self.accesspoint_name
        request = self.get_s3_request(accesspoint_name, **s3_request_kwargs)
        if accesspoint_context is None:
            accesspoint_context = self.get_s3_accesspoint_context(
                name=accesspoint_name
            )
        request.context['s3_accesspoint'] = accesspoint_context
        return request

    def get_s3_accesspoint_context(self, **overrides):
        accesspoint_context = {
            'name': self.accesspoint_name,
            'account': self.account,
            'region': self.region_name,
            'partition': self.partition,
            'service': self.service,
        }
        accesspoint_context.update(overrides)
        return accesspoint_context

    def call_set_endpoint(self, endpoint_setter, request, **kwargs):
        set_endpoint_kwargs = {
            'request': request,
            'operation_name': self.operation_name,
            'signature_version': self.signature_version,
            'region_name': self.region_name,
        }
        set_endpoint_kwargs.update(kwargs)
        endpoint_setter.set_endpoint(**set_endpoint_kwargs)

    def test_register(self):
        event_emitter = mock.Mock()
        self.endpoint_setter.register(event_emitter)
        event_emitter.register.assert_has_calls(
            [
                mock.call('before-sign.s3', self.endpoint_setter.set_endpoint),
                mock.call('choose-signer.s3', self.endpoint_setter.set_signer),
                mock.call(
                    'before-call.s3.WriteGetObjectResponse',
                    self.endpoint_setter.update_endpoint_to_s3_object_lambda,
                ),
            ]
        )

    def test_outpost_endpoint(self):
        request = self.get_s3_outpost_request()
        self.call_set_endpoint(self.endpoint_setter, request=request)
        expected_url = 'https://{}-{}.{}.s3-outposts.{}.amazonaws.com/'.format(
            self.accesspoint_name,
            self.account,
            self.outpost_name,
            self.region_name,
        )
        self.assertEqual(request.url, expected_url)

    def test_outpost_endpoint_preserves_key_in_path(self):
        request = self.get_s3_outpost_request(key=self.key)
        self.call_set_endpoint(self.endpoint_setter, request=request)
        expected_url = (
            'https://{}-{}.{}.s3-outposts.{}.amazonaws.com/{}'.format(
                self.accesspoint_name,
                self.account,
                self.outpost_name,
                self.region_name,
                self.key,
            )
        )
        self.assertEqual(request.url, expected_url)

    def test_accesspoint_endpoint(self):
        request = self.get_s3_accesspoint_request()
        self.call_set_endpoint(self.endpoint_setter, request=request)
        expected_url = 'https://{}-{}.s3-accesspoint.{}.amazonaws.com/'.format(
            self.accesspoint_name, self.account, self.region_name
        )
        self.assertEqual(request.url, expected_url)

    def test_accesspoint_preserves_key_in_path(self):
        request = self.get_s3_accesspoint_request(key=self.key)
        self.call_set_endpoint(self.endpoint_setter, request=request)
        expected_url = (
            'https://{}-{}.s3-accesspoint.{}.amazonaws.com/{}'.format(
                self.accesspoint_name, self.account, self.region_name, self.key
            )
        )
        self.assertEqual(request.url, expected_url)

    def test_accesspoint_preserves_scheme(self):
        request = self.get_s3_accesspoint_request(scheme='http://')
        self.call_set_endpoint(self.endpoint_setter, request=request)
        expected_url = 'http://{}-{}.s3-accesspoint.{}.amazonaws.com/'.format(
            self.accesspoint_name,
            self.account,
            self.region_name,
        )
        self.assertEqual(request.url, expected_url)

    def test_accesspoint_preserves_query_string(self):
        request = self.get_s3_accesspoint_request(querystring='acl')
        self.call_set_endpoint(self.endpoint_setter, request=request)
        expected_url = (
            'https://{}-{}.s3-accesspoint.{}.amazonaws.com/?acl'.format(
                self.accesspoint_name,
                self.account,
                self.region_name,
            )
        )
        self.assertEqual(request.url, expected_url)

    def test_uses_resolved_dns_suffix(self):
        self.endpoint_resolver.construct_endpoint.return_value = {
            'dnsSuffix': 'mysuffix.com'
        }
        request = self.get_s3_accesspoint_request()
        self.call_set_endpoint(self.endpoint_setter, request=request)
        expected_url = 'https://{}-{}.s3-accesspoint.{}.mysuffix.com/'.format(
            self.accesspoint_name,
            self.account,
            self.region_name,
        )
        self.assertEqual(request.url, expected_url)

    def test_uses_region_of_client_if_use_arn_disabled(self):
        client_region = 'client-region'
        self.endpoint_setter = self.get_endpoint_setter(
            region=client_region, s3_config={'use_arn_region': False}
        )
        request = self.get_s3_accesspoint_request()
        self.call_set_endpoint(self.endpoint_setter, request=request)
        expected_url = 'https://{}-{}.s3-accesspoint.{}.amazonaws.com/'.format(
            self.accesspoint_name,
            self.account,
            client_region,
        )
        self.assertEqual(request.url, expected_url)

    def test_accesspoint_supports_custom_endpoint(self):
        endpoint_setter = self.get_endpoint_setter(
            endpoint_url='https://custom.com'
        )
        request = self.get_s3_accesspoint_request()
        self.call_set_endpoint(endpoint_setter, request=request)
        expected_url = 'https://{}-{}.custom.com/'.format(
            self.accesspoint_name,
            self.account,
        )
        self.assertEqual(request.url, expected_url)

    def test_errors_for_mismatching_partition(self):
        endpoint_setter = self.get_endpoint_setter(partition='aws-cn')
        accesspoint_context = self.get_s3_accesspoint_context(partition='aws')
        request = self.get_s3_accesspoint_request(
            accesspoint_context=accesspoint_context
        )
        with self.assertRaises(UnsupportedS3AccesspointConfigurationError):
            self.call_set_endpoint(endpoint_setter, request=request)

    def test_errors_for_mismatching_partition_when_using_client_region(self):
        endpoint_setter = self.get_endpoint_setter(
            s3_config={'use_arn_region': False}, partition='aws-cn'
        )
        accesspoint_context = self.get_s3_accesspoint_context(partition='aws')
        request = self.get_s3_accesspoint_request(
            accesspoint_context=accesspoint_context
        )
        with self.assertRaises(UnsupportedS3AccesspointConfigurationError):
            self.call_set_endpoint(endpoint_setter, request=request)

    def test_set_endpoint_for_auto(self):
        endpoint_setter = self.get_endpoint_setter(
            s3_config={'addressing_style': 'auto'}
        )
        request = self.get_s3_request(self.bucket, self.key)
        self.call_set_endpoint(endpoint_setter, request)
        expected_url = 'https://{}.s3.us-west-2.amazonaws.com/{}'.format(
            self.bucket, self.key
        )
        self.assertEqual(request.url, expected_url)

    def test_set_endpoint_for_virtual(self):
        endpoint_setter = self.get_endpoint_setter(
            s3_config={'addressing_style': 'virtual'}
        )
        request = self.get_s3_request(self.bucket, self.key)
        self.call_set_endpoint(endpoint_setter, request)
        expected_url = 'https://{}.s3.us-west-2.amazonaws.com/{}'.format(
            self.bucket, self.key
        )
        self.assertEqual(request.url, expected_url)

    def test_set_endpoint_for_path(self):
        endpoint_setter = self.get_endpoint_setter(
            s3_config={'addressing_style': 'path'}
        )
        request = self.get_s3_request(self.bucket, self.key)
        self.call_set_endpoint(endpoint_setter, request)
        expected_url = 'https://s3.us-west-2.amazonaws.com/{}/{}'.format(
            self.bucket, self.key
        )
        self.assertEqual(request.url, expected_url)

    def test_set_endpoint_for_accelerate(self):
        endpoint_setter = self.get_endpoint_setter(
            s3_config={'use_accelerate_endpoint': True}
        )
        request = self.get_s3_request(self.bucket, self.key)
        self.call_set_endpoint(endpoint_setter, request)
        expected_url = 'https://{}.s3-accelerate.amazonaws.com/{}'.format(
            self.bucket, self.key
        )
        self.assertEqual(request.url, expected_url)


class TestContainerMetadataFetcher(unittest.TestCase):
    def setUp(self):
        self.responses = []
        self.http = mock.Mock()
        self.sleep = mock.Mock()

    def create_fetcher(self):
        return ContainerMetadataFetcher(self.http, sleep=self.sleep)

    def fake_response(self, status_code, body):
        response = mock.Mock()
        response.status_code = status_code
        response.content = body
        return response

    def set_http_responses_to(self, *responses):
        http_responses = []
        for response in responses:
            if isinstance(response, Exception):
                # Simulating an error condition.
                http_response = response
            elif hasattr(response, 'status_code'):
                # It's a precreated fake_response.
                http_response = response
            else:
                http_response = self.fake_response(
                    status_code=200, body=json.dumps(response).encode('utf-8')
                )
            http_responses.append(http_response)
        self.http.send.side_effect = http_responses

    def assert_request(self, method, url, headers):
        request = self.http.send.call_args[0][0]
        self.assertEqual(request.method, method)
        self.assertEqual(request.url, url)
        self.assertEqual(request.headers, headers)

    def assert_can_retrieve_metadata_from(self, full_uri):
        response_body = {'foo': 'bar'}
        self.set_http_responses_to(response_body)
        fetcher = self.create_fetcher()
        response = fetcher.retrieve_full_uri(full_uri)
        self.assertEqual(response, response_body)
        self.assert_request('GET', full_uri, {'Accept': 'application/json'})

    def assert_host_is_not_allowed(self, full_uri):
        response_body = {'foo': 'bar'}
        self.set_http_responses_to(response_body)
        fetcher = self.create_fetcher()
        with self.assertRaisesRegex(ValueError, 'Unsupported host'):
            fetcher.retrieve_full_uri(full_uri)
        self.assertFalse(self.http.send.called)

    def test_can_specify_extra_headers_are_merged(self):
        headers = {
            # The 'Accept' header will override the
            # default Accept header of application/json.
            'Accept': 'application/not-json',
            'X-Other-Header': 'foo',
        }
        self.set_http_responses_to({'foo': 'bar'})
        fetcher = self.create_fetcher()
        fetcher.retrieve_full_uri('http://localhost', headers)
        self.assert_request('GET', 'http://localhost', headers)

    def test_can_retrieve_uri(self):
        json_body = {
            "AccessKeyId": "a",
            "SecretAccessKey": "b",
            "Token": "c",
            "Expiration": "d",
        }
        self.set_http_responses_to(json_body)

        fetcher = self.create_fetcher()
        response = fetcher.retrieve_uri('/foo?id=1')

        self.assertEqual(response, json_body)
        # Ensure we made calls to the right endpoint.
        headers = {'Accept': 'application/json'}
        self.assert_request('GET', 'http://169.254.170.2/foo?id=1', headers)

    def test_can_retry_requests(self):
        success_response = {
            "AccessKeyId": "a",
            "SecretAccessKey": "b",
            "Token": "c",
            "Expiration": "d",
        }
        self.set_http_responses_to(
            # First response is a connection error, should
            # be retried.
            ConnectionClosedError(endpoint_url=''),
            # Second response is the successful JSON response
            # with credentials.
            success_response,
        )
        fetcher = self.create_fetcher()
        response = fetcher.retrieve_uri('/foo?id=1')
        self.assertEqual(response, success_response)

    def test_propagates_credential_error_on_http_errors(self):
        self.set_http_responses_to(
            # In this scenario, we never get a successful response.
            ConnectionClosedError(endpoint_url=''),
            ConnectionClosedError(endpoint_url=''),
            ConnectionClosedError(endpoint_url=''),
            ConnectionClosedError(endpoint_url=''),
            ConnectionClosedError(endpoint_url=''),
        )
        # As a result, we expect an appropriate error to be raised.
        fetcher = self.create_fetcher()
        with self.assertRaises(MetadataRetrievalError):
            fetcher.retrieve_uri('/foo?id=1')
        self.assertEqual(self.http.send.call_count, fetcher.RETRY_ATTEMPTS)

    def test_error_raised_on_non_200_response(self):
        self.set_http_responses_to(
            self.fake_response(status_code=404, body=b'Error not found'),
            self.fake_response(status_code=404, body=b'Error not found'),
            self.fake_response(status_code=404, body=b'Error not found'),
        )
        fetcher = self.create_fetcher()
        with self.assertRaises(MetadataRetrievalError):
            fetcher.retrieve_uri('/foo?id=1')
        # Should have tried up to RETRY_ATTEMPTS.
        self.assertEqual(self.http.send.call_count, fetcher.RETRY_ATTEMPTS)

    def test_error_raised_on_no_json_response(self):
        # If the service returns a sucess response but with a body that
        # does not contain JSON, we should still retry up to RETRY_ATTEMPTS,
        # but after exhausting retries we propagate the exception.
        self.set_http_responses_to(
            self.fake_response(status_code=200, body=b'Not JSON'),
            self.fake_response(status_code=200, body=b'Not JSON'),
            self.fake_response(status_code=200, body=b'Not JSON'),
        )
        fetcher = self.create_fetcher()
        with self.assertRaises(MetadataRetrievalError) as e:
            fetcher.retrieve_uri('/foo?id=1')
        self.assertNotIn('Not JSON', str(e.exception))
        # Should have tried up to RETRY_ATTEMPTS.
        self.assertEqual(self.http.send.call_count, fetcher.RETRY_ATTEMPTS)

    def test_can_retrieve_full_uri_with_fixed_ip(self):
        self.assert_can_retrieve_metadata_from(
            'http://%s/foo?id=1' % ContainerMetadataFetcher.IP_ADDRESS
        )

    def test_localhost_http_is_allowed(self):
        self.assert_can_retrieve_metadata_from('http://localhost/foo')

    def test_localhost_with_port_http_is_allowed(self):
        self.assert_can_retrieve_metadata_from('http://localhost:8000/foo')

    def test_localhost_https_is_allowed(self):
        self.assert_can_retrieve_metadata_from('https://localhost/foo')

    def test_can_use_127_ip_addr(self):
        self.assert_can_retrieve_metadata_from('https://127.0.0.1/foo')

    def test_can_use_127_ip_addr_with_port(self):
        self.assert_can_retrieve_metadata_from('https://127.0.0.1:8080/foo')

    def test_link_local_http_is_not_allowed(self):
        self.assert_host_is_not_allowed('http://169.254.0.1/foo')

    def test_link_local_https_is_not_allowed(self):
        self.assert_host_is_not_allowed('https://169.254.0.1/foo')

    def test_non_link_local_nonallowed_url(self):
        self.assert_host_is_not_allowed('http://169.1.2.3/foo')

    def test_error_raised_on_nonallowed_url(self):
        self.assert_host_is_not_allowed('http://somewhere.com/foo')

    def test_external_host_not_allowed_if_https(self):
        self.assert_host_is_not_allowed('https://somewhere.com/foo')


class TestUnsigned(unittest.TestCase):
    """``botocore.UNSIGNED`` must come back as the same object when copied."""

    def test_copy_returns_same_object(self):
        shallow = copy.copy(botocore.UNSIGNED)
        self.assertIs(botocore.UNSIGNED, shallow)

    def test_deepcopy_returns_same_object(self):
        deep = copy.deepcopy(botocore.UNSIGNED)
        self.assertIs(botocore.UNSIGNED, deep)


class TestInstanceMetadataFetcher(unittest.TestCase):
    def setUp(self):
        # Patch URLLib3Session.send so that every send() is answered from
        # the queue of canned IMDS responses via get_imds_response().
        self._imds_responses = []
        self._urllib3_patch = mock.patch(
            'botocore.httpsession.URLLib3Session.send'
        )
        self._send = self._urllib3_patch.start()
        self._send.side_effect = self.get_imds_response
        self._role_name = 'role-name'
        # Credentials payload in the shape queued as the IMDS response body.
        self._creds = {
            'AccessKeyId': 'spam',
            'SecretAccessKey': 'eggs',
            'Token': 'spam-token',
            'Expiration': 'something',
        }
        # The same values under the key names the tests compare the
        # fetcher's result against.
        renamed_keys = (
            ('access_key', 'AccessKeyId'),
            ('secret_key', 'SecretAccessKey'),
            ('token', 'Token'),
            ('expiry_time', 'Expiration'),
        )
        self._expected_creds = {
            result_key: self._creds[imds_key]
            for result_key, imds_key in renamed_keys
        }
        self._expected_creds['role_name'] = self._role_name

    def tearDown(self):
        self._urllib3_patch.stop()

    def add_imds_response(self, body, status_code=200):
        # Queue an AWSResponse wrapping ``body`` to be handed out by a later
        # send() call.
        raw_body = RawResponse(body)
        canned = botocore.awsrequest.AWSResponse(
            url='http://169.254.169.254/',
            status_code=status_code,
            headers={},
            raw=raw_body,
        )
        self._imds_responses.append(canned)

    def add_get_role_name_imds_response(self, role_name=None):
        if role_name is None:
            role_name = self._role_name
        self.add_imds_response(body=role_name.encode('utf-8'))

    def add_get_credentials_imds_response(self, creds=None):
        if creds is None:
            creds = self._creds
        self.add_imds_response(body=json.dumps(creds).encode('utf-8'))

    def add_get_token_imds_response(self, token, status_code=200):
        self.add_imds_response(
            body=token.encode('utf-8'), status_code=status_code
        )

    def add_metadata_token_not_supported_response(self):
        self.add_imds_response(b'', status_code=404)

    def add_imds_connection_error(self, exception):
        self._imds_responses.append(exception)

    def add_default_imds_responses(self):
        self.add_get_token_imds_response(token='token')
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

    def get_imds_response(self, request):
        response = self._imds_responses.pop(0)
        if isinstance(response, Exception):
            raise response
        return response

    def _test_imds_base_url(self, config, expected_url):
        # Queue one full token/role/credentials exchange, then check both
        # the returned credentials and the base URL reported by the fetcher.
        self.add_default_imds_responses()

        fetcher = InstanceMetadataFetcher(config=config)
        credentials = fetcher.retrieve_iam_role_credentials()
        self.assertEqual(credentials, self._expected_creds)

        base_url = fetcher.get_base_url()
        self.assertEqual(base_url, expected_url)

    def test_disabled_by_environment(self):
        # With AWS_EC2_METADATA_DISABLED set to 'true' the fetcher returns
        # an empty result and never calls send().
        fetcher = InstanceMetadataFetcher(
            env={'AWS_EC2_METADATA_DISABLED': 'true'}
        )
        credentials = fetcher.retrieve_iam_role_credentials()
        self.assertEqual(credentials, {})
        self._send.assert_not_called()

    def test_disabled_by_environment_mixed_case(self):
        # The disabling value is matched regardless of letter case.
        fetcher = InstanceMetadataFetcher(
            env={'AWS_EC2_METADATA_DISABLED': 'tRuE'}
        )
        credentials = fetcher.retrieve_iam_role_credentials()
        self.assertEqual(credentials, {})
        self._send.assert_not_called()

    def test_disabling_env_var_not_true(self):
        # A value of 'false' leaves the fetcher enabled, so the queued
        # exchange is consumed and credentials are returned.
        self.add_default_imds_responses()

        fetcher = InstanceMetadataFetcher(
            base_url='https://example.com/',
            env={'AWS_EC2_METADATA_DISABLED': 'false'},
        )
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, self._expected_creds)

    def test_ec2_metadata_endpoint_service_mode(self):
        # Each case pairs a fetcher config with the base URL it is expected
        # to resolve to.
        ipv4_url = 'http://169.254.169.254/'
        ipv6_url = 'http://[fd00:ec2::254]/'
        custom_ipv6_url = 'http://[fd00:ec2::010]/'
        configs = [
            ({'ec2_metadata_service_endpoint_mode': 'ipv6'}, ipv6_url),
            # NOTE(review): identical to the previous case; possibly meant
            # to cover a different spelling of the mode - confirm.
            ({'ec2_metadata_service_endpoint_mode': 'ipv6'}, ipv6_url),
            ({'ec2_metadata_service_endpoint_mode': 'ipv4'}, ipv4_url),
            # An unrecognised mode is expected to resolve to the IPv4 URL.
            ({'ec2_metadata_service_endpoint_mode': 'foo'}, ipv4_url),
            # An explicit endpoint is expected to win over the mode.
            (
                {
                    'ec2_metadata_service_endpoint_mode': 'ipv6',
                    'ec2_metadata_service_endpoint': custom_ipv6_url,
                },
                custom_ipv6_url,
            ),
        ]

        for config, expected_url in configs:
            self._test_imds_base_url(config, expected_url)

    def test_metadata_endpoint(self):
        # Unbracketed IPv6, bracketed IPv6 and IPv4 hosts are all expected
        # to be accepted by is_valid_uri().
        candidate_urls = (
            'http://fd00:ec2:0000:0000:0000:0000:0000:0000/',
            'http://[fd00:ec2::010]/',
            'http://192.168.1.1/',
        )
        for candidate in candidate_urls:
            self.assertTrue(is_valid_uri(candidate))

    def test_ipv6_endpoint_no_brackets_env_var_set(self):
        # A compressed IPv6 host without brackets is not a valid IPv6
        # endpoint URL.
        is_valid = is_valid_ipv6_endpoint_url('http://fd00:ec2::010/')
        self.assertFalse(is_valid)

    def test_ipv6_invalid_endpoint(self):
        # A malformed endpoint in the config is rejected when the fetcher is
        # constructed.
        config = {'ec2_metadata_service_endpoint': 'not.a:valid:dom@in'}
        with self.assertRaises(InvalidIMDSEndpointError):
            InstanceMetadataFetcher(config=config)

    def test_ipv6_endpoint_env_var_set_and_args(self):
        # When both are supplied, the explicit ``base_url`` argument is
        # expected to win over the endpoint from the config.
        configured_url = 'http://[fd00:ec2::254]/'
        url_arg = 'http://fd00:ec2:0000:0000:0000:8a2e:0370:7334/'

        self.add_default_imds_responses()

        fetcher = InstanceMetadataFetcher(
            config={'ec2_metadata_service_endpoint': configured_url},
            base_url=url_arg,
        )
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, self._expected_creds)
        self.assertEqual(fetcher.get_base_url(), url_arg)

    def test_ipv6_imds_not_allocated(self):
        # A single 400 response is queued; the fetcher is expected to give
        # up and return an empty result.
        config = {
            'ec2_metadata_service_endpoint': (
                'http://fd00:ec2:0000:0000:0000:0000:0000:0000/'
            )
        }

        self.add_imds_response(status_code=400, body=b'{}')

        fetcher = InstanceMetadataFetcher(config=config)
        credentials = fetcher.retrieve_iam_role_credentials()
        self.assertEqual(credentials, {})

    def test_ipv6_imds_empty_config(self):
        # Empty or missing settings are all expected to resolve to the IPv4
        # URL.
        default_url = 'http://169.254.169.254/'
        empty_configs = [
            {'ec2_metadata_service_endpoint': ''},
            {'ec2_metadata_service_endpoint_mode': ''},
            {},
            None,
        ]

        for config in empty_configs:
            self._test_imds_base_url(config, default_url)

    def test_includes_user_agent_header(self):
        # Every request sent during the token / role name / credentials
        # exchange is expected to carry the configured User-Agent header.
        user_agent = 'my-user-agent'
        self.add_default_imds_responses()

        InstanceMetadataFetcher(
            user_agent=user_agent
        ).retrieve_iam_role_credentials()

        self.assertEqual(self._send.call_count, 3)
        # ``call_args_list`` is the mock's record of invocations. An
        # arbitrary attribute such as ``calls`` is just an auto-created
        # child mock that iterates as empty, which would skip the checks.
        for call in self._send.call_args_list:
            # call[0] is the positional-args tuple; its first item is the
            # request handed to send().
            sent_request = call[0][0]
            # assertEqual compares the values; assertTrue(value, msg) would
            # only check truthiness and use the second argument as message.
            self.assertEqual(sent_request.headers['User-Agent'], user_agent)

    def test_non_200_response_for_role_name_is_retried(self):
        # A non-200 (here 429) response queued in place of the role name
        # should be retried; the next queued response supplies the name.
        self.add_get_token_imds_response(token='token')
        self.add_imds_response(
            body=b'{"message": "Slow down"}', status_code=429
        )
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(num_attempts=2)
        credentials = fetcher.retrieve_iam_role_credentials()
        self.assertEqual(credentials, self._expected_creds)

    def test_http_connection_error_for_role_name_is_retried(self):
        # A connection error in place of the role name response should be
        # retried; with two attempts allowed the credentials still come back.
        self.add_get_token_imds_response(token='token')
        connection_error = ConnectionClosedError(endpoint_url='')
        self.add_imds_connection_error(connection_error)
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(num_attempts=2)
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, self._expected_creds)

    def test_empty_response_for_role_name_is_retried(self):
        # A response with an empty body in place of the role name should be
        # retried; with two attempts allowed the credentials still come back.
        self.add_get_token_imds_response(token='token')
        empty_body = b''
        self.add_imds_response(body=empty_body)
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(num_attempts=2)
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, self._expected_creds)

    def test_non_200_response_is_retried(self):
        self.add_get_token_imds_response(token='token')
        self.add_get_role_name_imds_response()
        # A non-200 answer (here a 429) in place of the credentials response
        # should be retried.
        throttled_body = b'{"message": "Slow down"}'
        self.add_imds_response(status_code=429, body=throttled_body)
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(num_attempts=2)
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, self._expected_creds)

    def test_http_connection_errors_is_retried(self):
        self.add_get_token_imds_response(token='token')
        self.add_get_role_name_imds_response()
        # A connection error in place of the credentials response should
        # be retried.
        connection_error = ConnectionClosedError(endpoint_url='')
        self.add_imds_connection_error(connection_error)
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(num_attempts=2)
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, self._expected_creds)

    def test_empty_response_is_retried(self):
        self.add_get_token_imds_response(token='token')
        self.add_get_role_name_imds_response()
        # An empty body in place of the credentials response should be
        # retried.
        empty_body = b''
        self.add_imds_response(body=empty_body)
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(num_attempts=2)
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, self._expected_creds)

    def test_invalid_json_is_retried(self):
        self.add_get_token_imds_response(token='token')
        self.add_get_role_name_imds_response()
        # A truncated (invalid) JSON document in place of the credentials
        # response should be retried.
        truncated_json = b'{"AccessKey":'
        self.add_imds_response(body=truncated_json)
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(num_attempts=2)
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, self._expected_creds)

    def test_exhaust_retries_on_role_name_request(self):
        # With a single attempt allowed, a 400 following the token response
        # is expected to produce an empty result.
        self.add_get_token_imds_response(token='token')
        self.add_imds_response(status_code=400, body=b'')

        fetcher = InstanceMetadataFetcher(num_attempts=1)
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, {})

    def test_exhaust_retries_on_credentials_request(self):
        # With a single attempt allowed, a 400 following the token and role
        # name responses is expected to produce an empty result.
        self.add_get_token_imds_response(token='token')
        self.add_get_role_name_imds_response()
        self.add_imds_response(status_code=400, body=b'')

        fetcher = InstanceMetadataFetcher(num_attempts=1)
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, {})

    def test_missing_fields_in_credentials_response(self):
        self.add_get_token_imds_response(token='token')
        self.add_get_role_name_imds_response()
        # The final response has the default status code but its JSON body
        # describes an error instead of credentials. We do not necessarily
        # want to retry this; an empty result is expected.
        error_body = (
            b'{"Code":"AssumeRoleUnauthorizedAccess","Message":"error"}'
        )
        self.add_imds_response(body=error_body)

        fetcher = InstanceMetadataFetcher()
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, {})

    def test_token_is_included(self):
        self.add_default_imds_responses()

        fetcher = InstanceMetadataFetcher(user_agent='my-user-agent')
        credentials = fetcher.retrieve_iam_role_credentials()

        # Every call after the first one must send the token back in the
        # metadata token header.
        self.assertEqual(self._send.call_count, 3)
        follow_up_calls = self._send.call_args_list[1:]
        for sent in follow_up_calls:
            request = sent[0][0]
            self.assertEqual(
                request.headers['x-aws-ec2-metadata-token'], 'token'
            )
        self.assertEqual(credentials, self._expected_creds)

    def test_metadata_token_not_supported_404(self):
        # The first response is a 404; the requests that follow must not
        # carry a metadata token header and credentials are still returned.
        self.add_imds_response(status_code=404, body=b'')
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(user_agent='my-user-agent')
        credentials = fetcher.retrieve_iam_role_credentials()

        for sent in self._send.call_args_list[1:]:
            request = sent[0][0]
            self.assertNotIn('x-aws-ec2-metadata-token', request.headers)
        self.assertEqual(credentials, self._expected_creds)

    def test_metadata_token_not_supported_403(self):
        # The first response is a 403; the requests that follow must not
        # carry a metadata token header and credentials are still returned.
        self.add_imds_response(status_code=403, body=b'')
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(user_agent='my-user-agent')
        credentials = fetcher.retrieve_iam_role_credentials()

        for sent in self._send.call_args_list[1:]:
            request = sent[0][0]
            self.assertNotIn('x-aws-ec2-metadata-token', request.headers)
        self.assertEqual(credentials, self._expected_creds)

    def test_metadata_token_not_supported_405(self):
        # The first response is a 405; the requests that follow must not
        # carry a metadata token header and credentials are still returned.
        self.add_imds_response(status_code=405, body=b'')
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(user_agent='my-user-agent')
        credentials = fetcher.retrieve_iam_role_credentials()

        for sent in self._send.call_args_list[1:]:
            request = sent[0][0]
            self.assertNotIn('x-aws-ec2-metadata-token', request.headers)
        self.assertEqual(credentials, self._expected_creds)

    def test_metadata_token_not_supported_timeout(self):
        # The first request hits a read timeout; the requests that follow
        # must not carry a metadata token header and credentials are still
        # returned.
        read_timeout = ReadTimeoutError(endpoint_url='url')
        self.add_imds_connection_error(read_timeout)
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(user_agent='my-user-agent')
        credentials = fetcher.retrieve_iam_role_credentials()

        for sent in self._send.call_args_list[1:]:
            request = sent[0][0]
            self.assertNotIn('x-aws-ec2-metadata-token', request.headers)
        self.assertEqual(credentials, self._expected_creds)

    def test_token_not_supported_exhaust_retries(self):
        # The first request hits a connect timeout; the requests that follow
        # must not carry a metadata token header and credentials are still
        # returned.
        connect_timeout = ConnectTimeoutError(endpoint_url='url')
        self.add_imds_connection_error(connect_timeout)
        self.add_get_role_name_imds_response()
        self.add_get_credentials_imds_response()

        fetcher = InstanceMetadataFetcher(user_agent='my-user-agent')
        credentials = fetcher.retrieve_iam_role_credentials()

        for sent in self._send.call_args_list[1:]:
            request = sent[0][0]
            self.assertNotIn('x-aws-ec2-metadata-token', request.headers)
        self.assertEqual(credentials, self._expected_creds)

    def test_metadata_token_bad_request_yields_no_credentials(self):
        # When the only (first) response is a 400, an empty result is
        # expected.
        self.add_imds_response(status_code=400, body=b'')

        fetcher = InstanceMetadataFetcher(user_agent='my-user-agent')
        credentials = fetcher.retrieve_iam_role_credentials()

        self.assertEqual(credentials, {})

    def _get_datetime(self, dt=None, offset=None, offset_func=operator.add):
        if dt is None:
            dt = datetime.datetime.utcnow()
        if offset is not None:
            dt = offset_func(dt, offset)

        return dt

    def _get_default_creds(self, overrides=None):
        if overrides is None:
            overrides = {}

        creds = {
            'AccessKeyId': 'access',
            'SecretAccessKey': 'secret',
            'Token': 'token',
            'Expiration': '1970-01-01T00:00:00',
        }
        creds.update(overrides)
        return creds

    def _convert_creds_to_imds_fetcher(self, creds):
        return {
            'access_key': creds['AccessKeyId'],
            'secret_key': creds['SecretAccessKey'],
            'token': creds['Token'],
            'expiry_time': creds['Expiration'],
            'role_name': self._role_name,
        }

    def _add_default_imds_response(self, status_code=200, creds=''):
        self.add_get_token_imds_response(token='token')
        self.add_get_role_name_imds_response()
        self.add_imds_response(
            status_code=200, body=json.dumps(creds).encode('utf-8')
        )

    def mock_randint(self, int_val=600):
        """Return a ``mock.Mock`` that yields ``int_val`` whenever called.

        The tests in this class use it as a stand-in when patching
        ``random.randint``.
        """
        return mock.Mock(return_value=int_val)

    @FreezeTime(module=botocore.utils.datetime, date=DATE)
    def test_expiry_time_extension(self):
        # Credentials expiring 14 minutes from now are expected to come back
        # with an expiry 20 minutes from now while ``random.randint`` is
        # patched to return 600.
        now = self._get_datetime()
        original_expiry = self._get_datetime(
            dt=now, offset=datetime.timedelta(minutes=14)
        )
        extended_expiry = self._get_datetime(
            dt=now, offset=datetime.timedelta(minutes=20)
        )

        imds_creds = self._get_default_creds(
            {"Expiration": original_expiry.strftime(DT_FORMAT)}
        )
        expected = self._convert_creds_to_imds_fetcher(imds_creds)
        expected["expiry_time"] = extended_expiry.strftime(DT_FORMAT)

        self._add_default_imds_response(200, imds_creds)

        with mock.patch("random.randint", self.mock_randint()):
            actual = InstanceMetadataFetcher().retrieve_iam_role_credentials()
            assert actual == expected

    @FreezeTime(module=botocore.utils.datetime, date=DATE)
    def test_expired_expiry_extension(self):
        # Credentials that expired 14 minutes ago are expected to come back
        # with an expiry 20 minutes from now while ``random.randint`` is
        # patched to return 600.
        now = self._get_datetime()
        original_expiry = self._get_datetime(
            dt=now,
            offset=datetime.timedelta(minutes=14),
            offset_func=operator.sub,
        )
        extended_expiry = self._get_datetime(
            dt=now, offset=datetime.timedelta(minutes=20)
        )
        # Sanity-check the fixture: already expired, extension in the future.
        assert now > original_expiry
        assert extended_expiry > now

        imds_creds = self._get_default_creds(
            {"Expiration": original_expiry.strftime(DT_FORMAT)}
        )
        expected = self._convert_creds_to_imds_fetcher(imds_creds)
        expected["expiry_time"] = extended_expiry.strftime(DT_FORMAT)

        self._add_default_imds_response(200, imds_creds)

        with mock.patch("random.randint", self.mock_randint()):
            actual = InstanceMetadataFetcher().retrieve_iam_role_credentials()
            assert actual == expected

    @FreezeTime(module=botocore.utils.datetime, date=DATE)
    def test_expiry_extension_with_config(self):
        # With ``ec2_credential_refresh_window`` set to 15 minutes,
        # credentials that expired 14 minutes ago are expected to come back
        # with an expiry 25 minutes from now while ``random.randint`` is
        # patched to return 600.
        now = self._get_datetime()
        original_expiry = self._get_datetime(
            dt=now,
            offset=datetime.timedelta(minutes=14),
            offset_func=operator.sub,
        )
        extended_expiry = self._get_datetime(
            dt=now, offset=datetime.timedelta(minutes=25)
        )
        # Sanity-check the fixture: already expired, extension in the future.
        assert now > original_expiry
        assert extended_expiry > now

        imds_creds = self._get_default_creds(
            {"Expiration": original_expiry.strftime(DT_FORMAT)}
        )
        expected = self._convert_creds_to_imds_fetcher(imds_creds)
        expected["expiry_time"] = extended_expiry.strftime(DT_FORMAT)

        self._add_default_imds_response(200, imds_creds)

        refresh_config = {"ec2_credential_refresh_window": 15 * 60}
        with mock.patch("random.randint", self.mock_randint()):
            fetcher = InstanceMetadataFetcher(config=refresh_config)
            actual = fetcher.retrieve_iam_role_credentials()
            assert actual == expected

    @FreezeTime(module=botocore.utils.datetime, date=DATE)
    def test_expiry_extension_with_bad_datetime(self):
        # An expiration string in an unexpected format is expected to be
        # reported back unchanged as ``expiry_time``.
        unparseable = "May 20th, 2020 19:00:00"
        imds_creds = self._get_default_creds({"Expiration": unparseable})
        self._add_default_imds_response(200, imds_creds)

        refresh_config = {"ec2_credential_refresh_window": 15 * 60}
        fetcher = InstanceMetadataFetcher(config=refresh_config)
        actual = fetcher.retrieve_iam_role_credentials()

        assert actual['expiry_time'] == unparseable


class TestIMDSRegionProvider(unittest.TestCase):
    """Tests for ``IMDSRegionProvider``.

    ``os.environ`` is replaced with an empty dict for the duration of each
    test so the process environment cannot leak into the sessions created
    here.
    """

    def setUp(self):
        self.environ = {}
        self.environ_patch = mock.patch('os.environ', self.environ)
        self.environ_patch.start()

    def tearDown(self):
        self.environ_patch.stop()

    def assert_does_provide_expected_value(
        self,
        fetcher_region=None,
        expected_result=None,
    ):
        """Assert that ``provide()`` returns ``expected_result``.

        :param fetcher_region: Value the fake region fetcher returns from
            ``retrieve_region``.
        :param expected_result: Value ``provide()`` is expected to return.
        """
        fake_session = mock.Mock(spec=Session)
        fake_fetcher = mock.Mock(spec=InstanceMetadataRegionFetcher)
        fake_fetcher.retrieve_region.return_value = fetcher_region
        provider = IMDSRegionProvider(fake_session, fetcher=fake_fetcher)
        value = provider.provide()
        self.assertEqual(value, expected_result)

    def test_does_provide_region_when_present(self):
        self.assert_does_provide_expected_value(
            fetcher_region='us-mars-2',
            expected_result='us-mars-2',
        )

    def test_does_provide_none(self):
        self.assert_does_provide_expected_value(
            fetcher_region=None,
            expected_result=None,
        )

    @mock.patch('botocore.httpsession.URLLib3Session.send')
    def test_use_truncated_user_agent(self, send):
        # A single session is enough; the duplicated ``Session()`` call
        # only built an object that was thrown away immediately.
        session = Session()
        session.user_agent_version = '3.0'
        provider = IMDSRegionProvider(session)
        provider.provide()
        # ``send.call_args`` is the most recent call; its first positional
        # argument is the request that was sent.
        args, _ = send.call_args
        self.assertIn('Botocore/3.0', args[0].headers['User-Agent'])

    @mock.patch('botocore.httpsession.URLLib3Session.send')
    def test_can_use_ipv6(self, send):
        session = Session()
        session.set_config_variable('imds_use_ipv6', True)
        provider = IMDSRegionProvider(session)
        provider.provide()
        args, _ = send.call_args
        self.assertIn('[fd00:ec2::254]', args[0].url)

    @mock.patch('botocore.httpsession.URLLib3Session.send')
    def test_use_ipv4_by_default(self, send):
        session = Session()
        provider = IMDSRegionProvider(session)
        provider.provide()
        args, _ = send.call_args
        self.assertIn('169.254.169.254', args[0].url)

    @mock.patch('botocore.httpsession.URLLib3Session.send')
    def test_can_set_imds_endpoint_mode_to_ipv4(self, send):
        session = Session()
        session.set_config_variable(
            'ec2_metadata_service_endpoint_mode', 'ipv4'
        )
        provider = IMDSRegionProvider(session)
        provider.provide()
        args, _ = send.call_args
        self.assertIn('169.254.169.254', args[0].url)

    @mock.patch('botocore.httpsession.URLLib3Session.send')
    def test_can_set_imds_endpoint_mode_to_ipv6(self, send):
        session = Session()
        session.set_config_variable(
            'ec2_metadata_service_endpoint_mode', 'ipv6'
        )
        provider = IMDSRegionProvider(session)
        provider.provide()
        args, _ = send.call_args
        self.assertIn('[fd00:ec2::254]', args[0].url)

    @mock.patch('botocore.httpsession.URLLib3Session.send')
    def test_can_set_imds_service_endpoint(self, send):
        session = Session()
        session.set_config_variable(
            'ec2_metadata_service_endpoint', 'http://myendpoint/'
        )
        provider = IMDSRegionProvider(session)
        provider.provide()
        args, _ = send.call_args
        self.assertIn('http://myendpoint/', args[0].url)

    @mock.patch('botocore.httpsession.URLLib3Session.send')
    def test_can_set_imds_service_endpoint_custom(self, send):
        # The configured endpoint has no trailing slash here; the request
        # URL is still expected to contain a well-formed metadata path.
        session = Session()
        session.set_config_variable(
            'ec2_metadata_service_endpoint', 'http://myendpoint'
        )
        provider = IMDSRegionProvider(session)
        provider.provide()
        args, _ = send.call_args
        self.assertIn('http://myendpoint/latest/meta-data', args[0].url)

    @mock.patch('botocore.httpsession.URLLib3Session.send')
    def test_imds_service_endpoint_overrides_ipv6_endpoint(self, send):
        # Both the endpoint mode and an explicit endpoint are configured;
        # the explicit endpoint is the one expected in the request URL.
        session = Session()
        session.set_config_variable(
            'ec2_metadata_service_endpoint_mode', 'ipv6'
        )
        session.set_config_variable(
            'ec2_metadata_service_endpoint', 'http://myendpoint/'
        )
        provider = IMDSRegionProvider(session)
        provider.provide()
        args, _ = send.call_args
        self.assertIn('http://myendpoint/', args[0].url)


class TestSSOTokenLoader(unittest.TestCase):
    """Tests for ``SSOTokenLoader`` backed by a plain dict as its cache."""

    def setUp(self):
        super().setUp()
        self.session_name = 'admin'
        self.start_url = 'https://d-abc123.awsapps.com/start'
        # Cache keys the loader is expected to use for ``start_url`` and
        # ``session_name`` respectively.
        # NOTE(review): presumably SHA-1 hex digests of those strings -
        # confirm against SSOTokenLoader.
        self.cache_key = '40a89917e3175433e361b710a9d43528d7f1890a'
        self.session_cache_key = 'd033e22ae348aeb5660fc2140aec35850c4da997'
        self.access_token = 'totally.a.token'
        self.cached_token = {
            'accessToken': self.access_token,
            'expiresAt': '2002-10-18T03:52:38UTC',
        }
        self.cache = {}
        self.loader = SSOTokenLoader(cache=self.cache)

    def test_can_load_token_exists(self):
        self.cache[self.cache_key] = self.cached_token
        loaded = self.loader(self.start_url)
        self.assertEqual(self.cached_token, loaded)

    def test_can_handle_does_not_exist(self):
        # Nothing has been placed in the cache.
        self.assertRaises(SSOTokenLoadError, self.loader, self.start_url)

    def test_can_handle_invalid_cache(self):
        # The cache entry exists but holds none of the token fields.
        self.cache[self.cache_key] = {}
        self.assertRaises(SSOTokenLoadError, self.loader, self.start_url)

    def test_can_save_token(self):
        self.loader.save_token(self.start_url, self.cached_token)
        loaded = self.loader(self.start_url)
        self.assertEqual(self.cached_token, loaded)

    def test_can_save_token_sso_session(self):
        session_kwargs = {'session_name': self.session_name}
        self.loader.save_token(
            self.start_url, self.cached_token, **session_kwargs
        )
        loaded = self.loader(self.start_url, **session_kwargs)
        self.assertEqual(self.cached_token, loaded)

    def test_can_load_token_exists_sso_session_name(self):
        self.cache[self.session_cache_key] = self.cached_token
        loaded = self.loader(self.start_url, session_name=self.session_name)
        self.assertEqual(self.cached_token, loaded)


# Each case is (header name to look up, headers mapping, expected result).
# The table covers both plain dicts and HeadersDict instances.
@pytest.mark.parametrize(
    'header_name, headers, expected',
    (
        ('test_header', {'test_header': 'foo'}, True),
        # Lookups are expected to match regardless of case on either side.
        ('Test_Header', {'test_header': 'foo'}, True),
        ('test_header', {'Test_Header': 'foo'}, True),
        ('missing_header', {'Test_Header': 'foo'}, False),
        # A None header name is expected to report False.
        (None, {'Test_Header': 'foo'}, False),
        ('test_header', HeadersDict({'test_header': 'foo'}), True),
        ('Test_Header', HeadersDict({'test_header': 'foo'}), True),
        ('test_header', HeadersDict({'Test_Header': 'foo'}), True),
        ('missing_header', HeadersDict({'Test_Header': 'foo'}), False),
        (None, HeadersDict({'Test_Header': 'foo'}), False),
    ),
)
def test_has_header(header_name, headers, expected):
    """Check that ``has_header`` returns exactly the expected boolean."""
    # ``is`` rather than ``==``: the result must be a real bool.
    assert has_header(header_name, headers) is expected


class TestDetermineContentLength(unittest.TestCase):
    """Exercise ``determine_content_length`` with different body types."""

    def test_basic_bytes(self):
        self.assertEqual(determine_content_length(b'hello'), 5)

    def test_empty_bytes(self):
        self.assertEqual(determine_content_length(b''), 0)

    def test_buffered_io_base(self):
        # A bare BufferedIOBase is expected to yield no length.
        body = io.BufferedIOBase()
        self.assertIsNone(determine_content_length(body))

    def test_none(self):
        self.assertEqual(determine_content_length(None), 0)

    def test_basic_len_obj(self):
        # Any object exposing ``__len__`` is expected to report that length.
        class Sized:
            def __len__(self):
                return 12

        self.assertEqual(determine_content_length(Sized()), 12)

    def test_non_seekable_fileobj(self):
        # Only ``read`` is available: no ``seek``/``tell`` and no
        # ``__len__``, so no length is expected.
        class ReadOnly:
            def read(self, *args, **kwargs):
                pass

        self.assertIsNone(determine_content_length(ReadOnly()))

    def test_seekable_fileobj(self):
        # ``tell`` reports 0 until ``seek`` is called; any ``seek`` moves
        # the reported position to 50.
        class FakeSeekable:
            _offset = 0

            def read(self, *args, **kwargs):
                pass

            def tell(self, *args, **kwargs):
                return self._offset

            def seek(self, *args, **kwargs):
                self._offset = 50

        self.assertEqual(determine_content_length(FakeSeekable()), 50)


# Each case is (url, whether it is expected to count as an S3 accelerate URL).
@pytest.mark.parametrize(
    'url, expected',
    (
        ('https://s3-accelerate.amazonaws.com', True),
        ('https://s3-accelerate.amazonaws.com/', True),
        ('https://s3-accelerate.amazonaws.com/key', True),
        ('http://s3-accelerate.amazonaws.com/key', True),
        ('https://s3-accelerate.foo.amazonaws.com/key', False),
        # bucket prefixes are not allowed
        ('https://bucket.s3-accelerate.amazonaws.com/key', False),
        # S3 accelerate can be combined with dualstack
        ('https://s3-accelerate.dualstack.amazonaws.com/key', True),
        ('https://bucket.s3-accelerate.dualstack.amazonaws.com/key', False),
        ('https://s3-accelerate.dualstack.dualstack.amazonaws.com/key', False),
        ('https://s3-accelerate.dualstack.foo.amazonaws.com/key', False),
        ('https://dualstack.s3-accelerate.amazonaws.com/key', False),
        # assorted other ways for URLs to not be valid for s3-accelerate
        ('ftp://s3-accelerate.dualstack.foo.amazonaws.com/key', False),
        ('https://s3-accelerate.dualstack.foo.c2s.ic.gov/key', False),
        # None-valued url is accepted
        (None, False),
    ),
)
def test_is_s3_accelerate_url(url, expected):
    """Check ``is_s3_accelerate_url`` against every URL in the table."""
    assert is_s3_accelerate_url(url) == expected


# Each case is (raw headers, default encoding passed in, expected charset).
@pytest.mark.parametrize(
    'headers, default, expected',
    (
        # No Content-Type header: None is expected, not the default.
        ({}, 'ISO-8859-1', None),
        # An explicit charset (quoted or not) is expected to win.
        ({'Content-Type': 'text/html; charset=utf-8'}, 'default', 'utf-8'),
        ({'Content-Type': 'text/html; charset="utf-8"'}, 'default', 'utf-8'),
        # text/html without a charset: the default is expected.
        ({'Content-Type': 'text/html'}, 'ascii', 'ascii'),
        # application/json without a charset: None is expected.
        ({'Content-Type': 'application/json'}, 'ISO-8859-1', None),
    ),
)
def test_get_encoding_from_headers(headers, default, expected):
    """Check the charset derived from the headers for each table entry."""
    # The raw dict is wrapped in HeadersDict before being handed over.
    charset = get_encoding_from_headers(HeadersDict(headers), default=default)
    assert charset == expected
