# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import argparse
import itertools
import logging
import uuid

import mock
from oslo_config import cfg
from oslo_config import fixture as config
from oslo_serialization import jsonutils
import requests
import six
from testtools import matchers

from keystoneclient import adapter
from keystoneclient.auth import base
from keystoneclient import exceptions
from keystoneclient.i18n import _
from keystoneclient import session as client_session
from keystoneclient.tests.unit import utils


class SessionTests(utils.TestCase):
    """Tests for the HTTP verbs, error mapping and logging of Session."""

    TEST_URL = 'http://127.0.0.1:5000/'

    def setUp(self):
        super(SessionTests, self).setUp()
        # Every test in this class runs with deprecations marked as expected
        # on the fixture provided by the base test case.
        self.deprecations.expect_deprecations()

    def test_get(self):
        """A GET is issued with the right method and returns the body."""
        session = client_session.Session()
        self.stub_url('GET', text='response')
        resp = session.get(self.TEST_URL)

        self.assertEqual('GET', self.requests_mock.last_request.method)
        self.assertEqual(resp.text, 'response')
        self.assertTrue(resp.ok)

    def test_post(self):
        """A POST sends the JSON body and returns the stubbed response."""
        session = client_session.Session()
        self.stub_url('POST', text='response')
        resp = session.post(self.TEST_URL, json={'hello': 'world'})

        self.assertEqual('POST', self.requests_mock.last_request.method)
        self.assertEqual(resp.text, 'response')
        self.assertTrue(resp.ok)
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_head(self):
        """A HEAD is issued with an empty request body."""
        session = client_session.Session()
        self.stub_url('HEAD')
        resp = session.head(self.TEST_URL)

        self.assertEqual('HEAD', self.requests_mock.last_request.method)
        self.assertTrue(resp.ok)
        self.assertRequestBodyIs('')

    def test_put(self):
        """A PUT sends the JSON body and returns the stubbed response."""
        session = client_session.Session()
        self.stub_url('PUT', text='response')
        resp = session.put(self.TEST_URL, json={'hello': 'world'})

        self.assertEqual('PUT', self.requests_mock.last_request.method)
        self.assertEqual(resp.text, 'response')
        self.assertTrue(resp.ok)
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_delete(self):
        """A DELETE is issued and the stubbed response is returned."""
        session = client_session.Session()
        self.stub_url('DELETE', text='response')
        resp = session.delete(self.TEST_URL)

        self.assertEqual('DELETE', self.requests_mock.last_request.method)
        self.assertTrue(resp.ok)
        self.assertEqual(resp.text, 'response')

    def test_patch(self):
        """A PATCH sends the JSON body and returns the stubbed response."""
        session = client_session.Session()
        self.stub_url('PATCH', text='response')
        resp = session.patch(self.TEST_URL, json={'hello': 'world'})

        self.assertEqual('PATCH', self.requests_mock.last_request.method)
        self.assertTrue(resp.ok)
        self.assertEqual(resp.text, 'response')
        self.assertRequestBodyIs(json={'hello': 'world'})

    def test_user_agent(self):
        """Check User-Agent precedence: kwarg, then header, then session."""
        session = client_session.Session(user_agent='test-agent')
        self.stub_url('GET', text='response')
        resp = session.get(self.TEST_URL)

        # Session-level default is used when nothing else is given.
        self.assertTrue(resp.ok)
        self.assertRequestHeaderEqual('User-Agent', 'test-agent')

        # An explicit header replaces the session default.
        resp = session.get(self.TEST_URL, headers={'User-Agent': 'new-agent'})
        self.assertTrue(resp.ok)
        self.assertRequestHeaderEqual('User-Agent', 'new-agent')

        # The per-request user_agent argument wins over the header.
        resp = session.get(self.TEST_URL, headers={'User-Agent': 'new-agent'},
                           user_agent='overrides-agent')
        self.assertTrue(resp.ok)
        self.assertRequestHeaderEqual('User-Agent', 'overrides-agent')

    def test_http_session_opts(self):
        """cert, verify and timeout are forwarded to the requests session."""
        session = client_session.Session(cert='cert.pem', timeout=5,
                                         verify='certs')

        FAKE_RESP = utils.test_response(text='resp')
        RESP = mock.Mock(return_value=FAKE_RESP)

        with mock.patch.object(session.session, 'request', RESP) as mocked:
            session.post(self.TEST_URL, data='value')

            mock_args, mock_kwargs = mocked.call_args

            self.assertEqual(mock_args[0], 'POST')
            self.assertEqual(mock_args[1], self.TEST_URL)
            self.assertEqual(mock_kwargs['data'], 'value')
            self.assertEqual(mock_kwargs['cert'], 'cert.pem')
            self.assertEqual(mock_kwargs['verify'], 'certs')
            self.assertEqual(mock_kwargs['timeout'], 5)

    def test_not_found(self):
        """A 404 response is raised as exceptions.NotFound."""
        session = client_session.Session()
        self.stub_url('GET', status_code=404)
        self.assertRaises(exceptions.NotFound, session.get, self.TEST_URL)

    def test_server_error(self):
        """A 500 response is raised as exceptions.InternalServerError."""
        session = client_session.Session()
        self.stub_url('GET', status_code=500)
        self.assertRaises(exceptions.InternalServerError,
                          session.get, self.TEST_URL)

    def test_session_debug_output(self):
        """Test request and response headers in debug logs.

        Security-sensitive headers must be redacted in the debug output
        while the response object itself keeps the real values.
        """
        session = client_session.Session(verify=False)
        headers = {'HEADERA': 'HEADERVALB',
                   'Content-Type': 'application/json'}
        security_headers = {'Authorization': uuid.uuid4().hex,
                            'X-Auth-Token': uuid.uuid4().hex,
                            'X-Subject-Token': uuid.uuid4().hex,
                            'X-Service-Token': uuid.uuid4().hex}
        body = '{"a": "b"}'
        data = '{"c": "d"}'
        all_headers = dict(
            itertools.chain(headers.items(), security_headers.items()))
        self.stub_url('POST', text=body, headers=all_headers)
        resp = session.post(self.TEST_URL, headers=all_headers, data=data)
        self.assertEqual(resp.status_code, 200)

        self.assertIn('curl', self.logger.output)
        self.assertIn('POST', self.logger.output)
        self.assertIn('--insecure', self.logger.output)
        self.assertIn(body, self.logger.output)
        self.assertIn("'%s'" % data, self.logger.output)

        for k, v in headers.items():
            self.assertIn(k, self.logger.output)
            self.assertIn(v, self.logger.output)

        # Assert that the response headers contain the actual values and
        # that only the debug logs have been masked.
        for k, v in security_headers.items():
            self.assertIn('%s: {SHA1}' % k, self.logger.output)
            self.assertEqual(v, resp.headers[k])
            self.assertNotIn(v, self.logger.output)

    def test_logs_failed_output(self):
        """Test that output is logged even for failed requests."""
        session = client_session.Session()
        body = {uuid.uuid4().hex: uuid.uuid4().hex}

        self.stub_url('GET', json=body, status_code=400,
                      headers={'Content-Type': 'application/json'})
        resp = session.get(self.TEST_URL, raise_exc=False)

        self.assertEqual(resp.status_code, 400)
        self.assertIn(list(body.keys())[0], self.logger.output)
        self.assertIn(list(body.values())[0], self.logger.output)

    def test_logging_body_only_for_specified_content_types(self):
        """Verify response body is only logged in specific content types.

        Response bodies are logged only when the response's Content-Type
        header is set to application/json. This avoids an unexpected
        MemoryError when reading arbitrary responses, such as streams.
        """
        OMITTED_BODY = ('Omitted, Content-Type is set to %s. Only '
                        'application/json responses have their bodies logged.')
        session = client_session.Session(verify=False)

        # Content-Type is not set
        body = jsonutils.dumps({'token': {'id': '...'}})
        self.stub_url('POST', text=body)
        session.post(self.TEST_URL)
        self.assertNotIn(body, self.logger.output)
        self.assertIn(OMITTED_BODY % None, self.logger.output)

        # Content-Type is set to text/xml
        body = '<token><id>...</id></token>'
        self.stub_url('POST', text=body, headers={'Content-Type': 'text/xml'})
        session.post(self.TEST_URL)
        self.assertNotIn(body, self.logger.output)
        self.assertIn(OMITTED_BODY % 'text/xml', self.logger.output)

        # Content-Type is set to application/json
        body = jsonutils.dumps({'token': {'id': '...'}})
        self.stub_url('POST', text=body,
                      headers={'Content-Type': 'application/json'})
        session.post(self.TEST_URL)
        self.assertIn(body, self.logger.output)
        self.assertNotIn(OMITTED_BODY % 'application/json', self.logger.output)

        # Content-Type is set to application/json; charset=UTF-8
        body = jsonutils.dumps({'token': {'id': '...'}})
        self.stub_url(
            'POST', text=body,
            headers={'Content-Type': 'application/json; charset=UTF-8'})
        session.post(self.TEST_URL)
        self.assertIn(body, self.logger.output)
        self.assertNotIn(OMITTED_BODY % 'application/json; charset=UTF-8',
                         self.logger.output)

    def test_unicode_data_in_debug_output(self):
        """Verify that unicode text data is logged without modification."""
        session = client_session.Session(verify=False)

        body = 'RESP'
        data = u'αβγδ'
        self.stub_url('POST', text=body)
        session.post(self.TEST_URL, data=data)

        self.assertIn("'%s'" % data, self.logger.output)

    def test_binary_data_not_in_debug_output(self):
        """Verify that non-ascii-encodable data causes replacement."""
        if not six.PY2:
            # Python 3 logging handles binary data well, so there is nothing
            # to verify; report a skip rather than a silent, empty pass.
            self.skipTest('binary request data is only an issue on Python 2')

        data = "my data" + chr(255)

        session = client_session.Session(verify=False)

        body = 'RESP'
        self.stub_url('POST', text=body)

        # Forced mixed unicode and byte strings in request
        # elements to make sure that all joins are appropriately
        # handled (any join of unicode and byte strings should
        # raise a UnicodeDecodeError)
        session.post(six.text_type(self.TEST_URL), data=data)

        self.assertNotIn('my data', self.logger.output)

    def test_logging_cacerts(self):
        """A CA bundle path passed as verify shows up in the logged command."""
        path_to_certs = '/path/to/certs'
        session = client_session.Session(verify=path_to_certs)

        self.stub_url('GET', text='text')
        session.get(self.TEST_URL)

        self.assertIn('--cacert', self.logger.output)
        self.assertIn(path_to_certs, self.logger.output)

    def test_connect_retries(self):
        """Timeouts are retried connect_retries times with growing delays."""

        def _timeout_error(request, context):
            raise requests.exceptions.Timeout()

        self.stub_url('GET', text=_timeout_error)

        session = client_session.Session()
        retries = 3

        # time.sleep is patched so the back-off does not slow the test down.
        with mock.patch('time.sleep') as m:
            self.assertRaises(exceptions.RequestTimeout,
                              session.get,
                              self.TEST_URL, connect_retries=retries)

            self.assertEqual(retries, m.call_count)
            # 3 retries finishing with 2.0 means 0.5, 1.0 and 2.0
            m.assert_called_with(2.0)

        # we count retries so there will be one initial request + 3 retries
        self.assertThat(self.requests_mock.request_history,
                        matchers.HasLength(retries + 1))

    def test_uses_tcp_keepalive_by_default(self):
        """A default Session mounts TCPKeepAliveAdapter for http and https."""
        session = client_session.Session()
        requests_session = session.session
        self.assertIsInstance(requests_session.adapters['http://'],
                              client_session.TCPKeepAliveAdapter)
        self.assertIsInstance(requests_session.adapters['https://'],
                              client_session.TCPKeepAliveAdapter)

    def test_does_not_set_tcp_keepalive_on_custom_sessions(self):
        """A caller-supplied requests session gets no adapters mounted."""
        mock_session = mock.Mock()
        client_session.Session(session=mock_session)
        self.assertFalse(mock_session.mount.called)

    def test_ssl_error_message(self):
        """An SSL failure is re-raised with the URL and the original error."""
        error = uuid.uuid4().hex

        def _ssl_error(request, context):
            raise requests.exceptions.SSLError(error)

        self.stub_url('GET', text=_ssl_error)
        session = client_session.Session()

        # The exception should contain the URL and details about the SSL error
        msg = _('SSL exception connecting to %(url)s: %(error)s') % {
            'url': self.TEST_URL, 'error': error}
        six.assertRaisesRegex(self,
                              exceptions.SSLError,
                              msg,
                              session.get,
                              self.TEST_URL)

    def test_mask_password_in_http_log_response(self):
        """Passwords in a JSON response body never reach the debug log."""
        session = client_session.Session()

        def fake_debug(msg):
            # Runs for every logger.debug() call made while logging.
            self.assertNotIn('verybadpass', msg)

        logger = mock.Mock(isEnabledFor=mock.Mock(return_value=True))
        logger.debug = mock.Mock(side_effect=fake_debug)
        body = {
            "connection_info": {
                "driver_volume_type": "iscsi",
                "data": {
                    "auth_password": "verybadpass",
                    "target_discovered": False,
                    "encrypted": False,
                    "qos_specs": None,
                    "target_iqn": ("iqn.2010-10.org.openstack:volume-"
                                   "744d2085-8e78-40a5-8659-ef3cffb2480e"),
                    "target_portal": "172.99.69.228:3260",
                    "volume_id": "744d2085-8e78-40a5-8659-ef3cffb2480e",
                    "target_lun": 1,
                    "access_mode": "rw",
                    "auth_username": "verybadusername",
                    "auth_method": "CHAP"}}}
        body_json = jsonutils.dumps(body)
        response = mock.Mock(text=body_json, status_code=200,
                             headers={'content-type': 'application/json'})
        session._http_log_response(response, logger)
        self.assertEqual(1, logger.debug.call_count)


class TCPKeepAliveAdapter(utils.TestCase):
    """Check the socket options TCPKeepAliveAdapter passes to its parent."""

    @mock.patch.object(client_session, 'socket')
    @mock.patch('requests.adapters.HTTPAdapter.init_poolmanager')
    def test_init_poolmanager_all_options(self, mock_parent_init_poolmanager,
                                          mock_socket):
        """All three optional keepalive options are passed when available."""
        # Install the optional constants on the patched socket module and
        # remember the sentinel standing in for each of them.
        wanted_options = []
        for name in ('TCP_KEEPIDLE', 'TCP_KEEPCNT', 'TCP_KEEPINTVL'):
            marker = getattr(mock.sentinel, name)
            setattr(mock_socket, name, marker)
            wanted_options.append(marker)

        keepalive_adapter = client_session.TCPKeepAliveAdapter()
        keepalive_adapter.init_poolmanager()

        _args, parent_kwargs = mock_parent_init_poolmanager.call_args
        passed_options = [
            option
            for (_level, option, _value) in parent_kwargs['socket_options']]
        for wanted in wanted_options:
            self.assertIn(wanted, passed_options)

    @mock.patch.object(client_session, 'socket')
    @mock.patch('requests.adapters.HTTPAdapter.init_poolmanager')
    def test_init_poolmanager(self, mock_parent_init_poolmanager, mock_socket):
        """Only the base options are passed when no TCP_KEEP* names exist."""
        # Restrict the patched socket module to these attributes only.
        mock_socket.mock_add_spec(
            ['IPPROTO_TCP', 'TCP_NODELAY', 'SOL_SOCKET', 'SO_KEEPALIVE'])

        keepalive_adapter = client_session.TCPKeepAliveAdapter()
        keepalive_adapter.init_poolmanager()

        _args, parent_kwargs = mock_parent_init_poolmanager.call_args
        passed_options = [
            option
            for (_level, option, _value) in parent_kwargs['socket_options']]
        expected_options = [mock_socket.TCP_NODELAY, mock_socket.SO_KEEPALIVE]
        self.assertEqual(expected_options, passed_options)


class RedirectTests(utils.TestCase):
    """Tests for how Session follows (or refuses to follow) redirects."""

    # Each URL redirects to the next one; the last URL is the real resource.
    REDIRECT_CHAIN = ['http://myhost:3445/',
                      'http://anotherhost:6555/',
                      'http://thirdhost/',
                      'http://finaldestination:55/']

    DEFAULT_REDIRECT_BODY = 'Redirect'
    DEFAULT_RESP_BODY = 'Found'

    def setUp(self):
        super(RedirectTests, self).setUp()
        self.deprecations.expect_deprecations()

    def setup_redirects(self, method='GET', status_code=305,
                        redirect_kwargs=None, final_kwargs=None):
        """Register the whole REDIRECT_CHAIN with requests_mock.

        :param method: HTTP method every URL in the chain responds to.
        :param status_code: status returned by each redirecting URL.
        :param redirect_kwargs: extra register_uri() arguments for the
                                redirecting URLs; 'text' defaults to
                                DEFAULT_REDIRECT_BODY.
        :param final_kwargs: extra register_uri() arguments for the final
                             URL; 'status_code' defaults to 200 and 'text'
                             to DEFAULT_RESP_BODY.
        """
        # Work on copies so the setdefault() calls below never modify the
        # dictionaries owned by the caller.
        redirect_kwargs = dict(redirect_kwargs or {})
        final_kwargs = dict(final_kwargs or {})
        redirect_kwargs.setdefault('text', self.DEFAULT_REDIRECT_BODY)

        # Pair each URL with its successor: s redirects to d.
        for s, d in zip(self.REDIRECT_CHAIN, self.REDIRECT_CHAIN[1:]):
            self.requests_mock.register_uri(method, s, status_code=status_code,
                                            headers={'Location': d},
                                            **redirect_kwargs)

        final_kwargs.setdefault('status_code', 200)
        final_kwargs.setdefault('text', self.DEFAULT_RESP_BODY)
        self.requests_mock.register_uri(method, self.REDIRECT_CHAIN[-1],
                                        **final_kwargs)

    def assertResponse(self, resp):
        """Assert resp is the successful response of the final URL."""
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.text, self.DEFAULT_RESP_BODY)

    def test_basic_get(self):
        """A single redirect is followed for GET by default."""
        session = client_session.Session()
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[-2])
        self.assertResponse(resp)

    def test_basic_post_keeps_correct_method(self):
        """A redirected POST is re-sent as POST to the new location."""
        session = client_session.Session()
        self.setup_redirects(method='POST', status_code=301)
        resp = session.post(self.REDIRECT_CHAIN[-2])
        self.assertResponse(resp)

    def test_redirect_forever(self):
        """redirect=True follows the chain all the way to the final URL."""
        session = client_session.Session(redirect=True)
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[0])
        self.assertResponse(resp)
        # One history entry per redirect hop, i.e. every URL in the chain
        # except the final destination.
        self.assertEqual(len(self.REDIRECT_CHAIN) - 1, len(resp.history))

    def test_no_redirect(self):
        """redirect=False returns the first redirect response untouched."""
        session = client_session.Session(redirect=False)
        self.setup_redirects()
        resp = session.get(self.REDIRECT_CHAIN[0])
        self.assertEqual(resp.status_code, 305)
        self.assertEqual(resp.url, self.REDIRECT_CHAIN[0])

    def test_redirect_limit(self):
        """An integer redirect value caps the number of hops followed."""
        self.setup_redirects()
        for i in (1, 2):
            session = client_session.Session(redirect=i)
            resp = session.get(self.REDIRECT_CHAIN[0])
            # After i hops the session stops on a redirect response.
            self.assertEqual(resp.status_code, 305)
            self.assertEqual(resp.url, self.REDIRECT_CHAIN[i])
            self.assertEqual(resp.text, self.DEFAULT_REDIRECT_BODY)

    def test_history_matches_requests(self):
        """The session's redirect history mirrors what requests records."""
        self.setup_redirects(status_code=301)
        session = client_session.Session(redirect=True)
        req_resp = requests.get(self.REDIRECT_CHAIN[0],
                                allow_redirects=True)

        ses_resp = session.get(self.REDIRECT_CHAIN[0])

        self.assertEqual(len(req_resp.history), len(ses_resp.history))

        for r, s in zip(req_resp.history, ses_resp.history):
            self.assertEqual(r.url, s.url)
            self.assertEqual(r.status_code, s.status_code)


class ConstructSessionFromArgsTests(utils.TestCase):
    """Tests for building a Session through Session.construct()."""

    KEY = 'keyfile'
    CERT = 'certfile'
    CACERT = 'cacert-path'

    def _s(self, k=None, **kwargs):
        """Construct a Session from dict ``k`` or, if empty, from kwargs."""
        construct_args = k if k else kwargs
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.construct(construct_args)

    def test_verify(self):
        """Check how insecure, verify and cacert combine into ``verify``."""
        insecure_only = self._s(insecure=True)
        self.assertFalse(insecure_only.verify)

        verify_and_insecure = self._s(verify=True, insecure=True)
        self.assertTrue(verify_and_insecure.verify)

        no_verify_and_insecure = self._s(verify=False, insecure=True)
        self.assertFalse(no_verify_and_insecure.verify)

        with_cacert = self._s(cacert=self.CACERT)
        self.assertEqual(with_cacert.verify, self.CACERT)

    def test_cert(self):
        """Check that cert and key are combined into a (cert, key) tuple."""
        cert_and_key = (self.CERT, self.KEY)

        from_tuple = self._s(cert=cert_and_key)
        self.assertEqual(from_tuple.cert, cert_and_key)

        from_parts = self._s(cert=self.CERT, key=self.KEY)
        self.assertEqual(from_parts.cert, cert_and_key)

        key_only = self._s(key=self.KEY)
        self.assertIsNone(key_only.cert)

    def test_pass_through(self):
        """Check that plain options reach the Session and leave the dict."""
        # A number is used for every option because timeout requires one.
        value = 42
        for key in ['timeout', 'session', 'original_ip', 'user_agent']:
            args = {key: value}
            sess = self._s(args)
            self.assertEqual(getattr(sess, key), value)
            # The same dict object was handed to construct(); the key must
            # be gone from it afterwards.
            self.assertNotIn(key, args)


class AuthPlugin(base.BaseAuthPlugin):
    """Minimal auth plugin used to drive Session in tests.

    It returns a fixed token, looks endpoints up in a static table and
    reports a configurable result from invalidate().
    """

    TEST_TOKEN = utils.TestCase.TEST_TOKEN
    TEST_USER_ID = 'aUser'
    TEST_PROJECT_ID = 'aProject'

    # service_type -> interface -> endpoint URL
    SERVICE_URLS = {
        'identity': {'public': 'http://identity-public:1111/v2.0',
                     'admin': 'http://identity-admin:1111/v2.0'},
        'compute': {'public': 'http://compute-public:2222/v1.0',
                    'admin': 'http://compute-admin:2222/v1.0'},
        'image': {'public': 'http://image-public:3333/v2.0',
                  'admin': 'http://image-admin:3333/v2.0'}
    }

    def __init__(self, token=TEST_TOKEN, invalidate=True):
        self._invalidate = invalidate
        self.token = token

    def get_token(self, session):
        """Return the token this plugin was created with."""
        return self.token

    def get_endpoint(self, session, service_type=None, interface=None,
                     **kwargs):
        """Return the SERVICE_URLS entry for the filter, or None if absent."""
        try:
            by_interface = self.SERVICE_URLS[service_type]
            return by_interface[interface]
        except (KeyError, AttributeError):
            return None

    def invalidate(self):
        """Return the ``invalidate`` value given to the constructor."""
        return self._invalidate

    def get_user_id(self, session):
        """Return the fixed test user id."""
        return self.TEST_USER_ID

    def get_project_id(self, session):
        """Return the fixed test project id."""
        return self.TEST_PROJECT_ID


class CalledAuthPlugin(base.BaseAuthPlugin):
    """Auth plugin that records which of its methods have been called."""

    ENDPOINT = 'http://fakeendpoint/'

    def __init__(self, invalidate=True):
        self._invalidate = invalidate
        # Call-tracking state, updated by the methods below.
        self.get_token_called = False
        self.get_endpoint_called = False
        self.invalidate_called = False
        self.endpoint_arguments = {}

    def get_token(self, session):
        """Record the call and return the shared test token."""
        self.get_token_called = True
        return utils.TestCase.TEST_TOKEN

    def get_endpoint(self, session, **kwargs):
        """Record the call and its filter arguments; return ENDPOINT."""
        self.endpoint_arguments = kwargs
        self.get_endpoint_called = True
        return self.ENDPOINT

    def invalidate(self):
        """Record the call and return the configured invalidate result."""
        self.invalidate_called = True
        return self._invalidate


class SessionAuthTests(utils.TestCase):
    """Tests for how Session cooperates with authentication plugins."""

    TEST_URL = 'http://127.0.0.1:5000/'
    TEST_JSON = {'hello': 'world'}

    def setUp(self):
        super(SessionAuthTests, self).setUp()
        self.deprecations.expect_deprecations()

    def stub_service_url(self, service_type, interface, path,
                         method='GET', **kwargs):
        """Register ``path`` under the matching AuthPlugin service URL.

        Extra keyword arguments are passed on to requests_mock.
        """
        base_url = AuthPlugin.SERVICE_URLS[service_type][interface]
        # Join with exactly one slash regardless of how the parts are written.
        uri = "%s/%s" % (base_url.rstrip('/'), path.lstrip('/'))

        self.requests_mock.register_uri(method, uri, **kwargs)

    def test_auth_plugin_default_with_plugin(self):
        """Requests are authenticated by default when a plugin is set."""
        self.stub_url('GET', base_url=self.TEST_URL, json=self.TEST_JSON)

        # if there is an auth_plugin then it should default to authenticated
        auth = AuthPlugin()
        sess = client_session.Session(auth=auth)
        resp = sess.get(self.TEST_URL)
        self.assertEqual(resp.json(), self.TEST_JSON)

        self.assertRequestHeaderEqual('X-Auth-Token', AuthPlugin.TEST_TOKEN)

    def test_auth_plugin_disable(self):
        """authenticated=False suppresses the token header."""
        self.stub_url('GET', base_url=self.TEST_URL, json=self.TEST_JSON)

        auth = AuthPlugin()
        sess = client_session.Session(auth=auth)
        resp = sess.get(self.TEST_URL, authenticated=False)
        self.assertEqual(resp.json(), self.TEST_JSON)

        self.assertRequestHeaderEqual('X-Auth-Token', None)

    def test_service_type_urls(self):
        """A relative path is resolved against the filtered endpoint."""
        service_type = 'compute'
        interface = 'public'
        path = '/instances'
        status = 200
        body = 'SUCCESS'

        self.stub_service_url(service_type=service_type,
                              interface=interface,
                              path=path,
                              status_code=status,
                              text=body)

        sess = client_session.Session(auth=AuthPlugin())
        resp = sess.get(path,
                        endpoint_filter={'service_type': service_type,
                                         'interface': interface})

        self.assertEqual(self.requests_mock.last_request.url,
                         AuthPlugin.SERVICE_URLS['compute']['public'] + path)
        self.assertEqual(resp.text, body)
        self.assertEqual(resp.status_code, status)

    def test_service_url_raises_if_no_auth_plugin(self):
        """An endpoint filter without any plugin raises MissingAuthPlugin."""
        sess = client_session.Session()
        self.assertRaises(exceptions.MissingAuthPlugin,
                          sess.get, '/path',
                          endpoint_filter={'service_type': 'compute',
                                           'interface': 'public'})

    def test_service_url_raises_if_no_url_returned(self):
        """A filter the plugin cannot resolve raises EndpointNotFound."""
        sess = client_session.Session(auth=AuthPlugin())
        self.assertRaises(exceptions.EndpointNotFound,
                          sess.get, '/path',
                          endpoint_filter={'service_type': 'unknown',
                                           'interface': 'public'})

    def test_raises_exc_only_when_asked(self):
        """HTTP errors raise unless raise_exc=False is passed."""
        # A request that returns an HTTP error should raise an exception by
        # default; if you specify raise_exc=False then it will not.
        self.requests_mock.get(self.TEST_URL, status_code=401)

        sess = client_session.Session()
        self.assertRaises(exceptions.Unauthorized, sess.get, self.TEST_URL)

        resp = sess.get(self.TEST_URL, raise_exc=False)
        self.assertEqual(401, resp.status_code)

    def test_passed_auth_plugin(self):
        """A plugin passed per request is used for token and endpoint."""
        passed = CalledAuthPlugin()
        sess = client_session.Session()

        self.requests_mock.get(CalledAuthPlugin.ENDPOINT + 'path',
                               status_code=200)
        endpoint_filter = {'service_type': 'identity'}

        # no plugin with authenticated won't work
        self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path',
                          authenticated=True)

        # no plugin with an endpoint filter won't work
        self.assertRaises(exceptions.MissingAuthPlugin, sess.get, 'path',
                          authenticated=False, endpoint_filter=endpoint_filter)

        resp = sess.get('path', auth=passed, endpoint_filter=endpoint_filter)

        self.assertEqual(200, resp.status_code)
        self.assertTrue(passed.get_endpoint_called)
        self.assertTrue(passed.get_token_called)

    def test_passed_auth_plugin_overrides(self):
        """A per-request plugin takes precedence over the session's plugin."""
        fixed = CalledAuthPlugin()
        passed = CalledAuthPlugin()

        sess = client_session.Session(fixed)

        self.requests_mock.get(CalledAuthPlugin.ENDPOINT + 'path',
                               status_code=200)

        resp = sess.get('path', auth=passed,
                        endpoint_filter={'service_type': 'identity'})

        self.assertEqual(200, resp.status_code)
        self.assertTrue(passed.get_endpoint_called)
        self.assertTrue(passed.get_token_called)
        self.assertFalse(fixed.get_endpoint_called)
        self.assertFalse(fixed.get_token_called)

    def test_requests_auth_plugin(self):
        """requests_auth is handed to requests as its ``auth`` argument."""
        sess = client_session.Session()

        requests_auth = object()

        FAKE_RESP = utils.test_response(text='resp')
        RESP = mock.Mock(return_value=FAKE_RESP)

        with mock.patch.object(sess.session, 'request', RESP) as mocked:
            sess.get(self.TEST_URL, requests_auth=requests_auth)

            mocked.assert_called_once_with('GET', self.TEST_URL,
                                           headers=mock.ANY,
                                           allow_redirects=mock.ANY,
                                           auth=requests_auth,
                                           verify=mock.ANY)

    def test_reauth_called(self):
        """A 401 invalidates the plugin and the request is retried."""
        auth = CalledAuthPlugin(invalidate=True)
        sess = client_session.Session(auth=auth)

        # First response fails with 401, the retry succeeds.
        self.requests_mock.get(self.TEST_URL,
                               [{'text': 'Failed', 'status_code': 401},
                                {'text': 'Hello', 'status_code': 200}])

        # allow_reauth=True is the default
        resp = sess.get(self.TEST_URL, authenticated=True)

        self.assertEqual(200, resp.status_code)
        self.assertEqual('Hello', resp.text)
        self.assertTrue(auth.invalidate_called)

    def test_reauth_not_called(self):
        """With allow_reauth=False a 401 is raised without invalidating."""
        auth = CalledAuthPlugin(invalidate=True)
        sess = client_session.Session(auth=auth)

        self.requests_mock.get(self.TEST_URL,
                               [{'text': 'Failed', 'status_code': 401},
                                {'text': 'Hello', 'status_code': 200}])

        self.assertRaises(exceptions.Unauthorized, sess.get, self.TEST_URL,
                          authenticated=True, allow_reauth=False)
        self.assertFalse(auth.invalidate_called)

    def test_endpoint_override_overrides_filter(self):
        """endpoint_override is used instead of asking the plugin."""
        auth = CalledAuthPlugin()
        sess = client_session.Session(auth=auth)

        override_base = 'http://mytest/'
        path = 'path'
        override_url = override_base + path
        resp_text = uuid.uuid4().hex

        self.requests_mock.get(override_url, text=resp_text)

        resp = sess.get(path,
                        endpoint_override=override_base,
                        endpoint_filter={'service_type': 'identity'})

        self.assertEqual(resp_text, resp.text)
        self.assertEqual(override_url, self.requests_mock.last_request.url)

        self.assertTrue(auth.get_token_called)
        self.assertFalse(auth.get_endpoint_called)

    def test_endpoint_override_ignore_full_url(self):
        """A fully qualified URL ignores both override and filter."""
        auth = CalledAuthPlugin()
        sess = client_session.Session(auth=auth)

        path = 'path'
        url = self.TEST_URL + path

        resp_text = uuid.uuid4().hex
        self.requests_mock.get(url, text=resp_text)

        resp = sess.get(url,
                        endpoint_override='http://someother.url',
                        endpoint_filter={'service_type': 'identity'})

        self.assertEqual(resp_text, resp.text)
        self.assertEqual(url, self.requests_mock.last_request.url)

        self.assertTrue(auth.get_token_called)
        self.assertFalse(auth.get_endpoint_called)

    def test_user_and_project_id(self):
        """User and project ids are fetched from the auth plugin."""
        auth = AuthPlugin()
        sess = client_session.Session(auth=auth)

        self.assertEqual(auth.TEST_USER_ID, sess.get_user_id())
        self.assertEqual(auth.TEST_PROJECT_ID, sess.get_project_id())

    def test_logger_object_passed(self):
        """A logger passed per request receives the output exclusively."""
        logger = logging.getLogger(uuid.uuid4().hex)
        logger.setLevel(logging.DEBUG)
        # Keep the records away from the default test logger.
        logger.propagate = False

        io = six.StringIO()
        handler = logging.StreamHandler(io)
        logger.addHandler(handler)
        # Detach the handler again once the test is over so that the
        # logger does not keep it, and its buffer, attached.
        self.addCleanup(logger.removeHandler, handler)

        auth = AuthPlugin()
        sess = client_session.Session(auth=auth)
        response = {uuid.uuid4().hex: uuid.uuid4().hex}

        self.stub_url('GET',
                      json=response,
                      headers={'Content-Type': 'application/json'})

        resp = sess.get(self.TEST_URL, logger=logger)

        self.assertEqual(response, resp.json())
        output = io.getvalue()

        self.assertIn(self.TEST_URL, output)
        self.assertIn(list(response.keys())[0], output)
        self.assertIn(list(response.values())[0], output)

        self.assertNotIn(self.TEST_URL, self.logger.output)
        self.assertNotIn(list(response.keys())[0], self.logger.output)
        self.assertNotIn(list(response.values())[0], self.logger.output)


class AdapterTest(utils.TestCase):
    """Tests for ``adapter.Adapter`` and ``adapter.LegacyJsonAdapter``."""

    # Random values so a test can only pass when the adapter really hands
    # the configured value on.
    SERVICE_TYPE = uuid.uuid4().hex
    SERVICE_NAME = uuid.uuid4().hex
    INTERFACE = uuid.uuid4().hex
    REGION_NAME = uuid.uuid4().hex
    USER_AGENT = uuid.uuid4().hex
    VERSION = uuid.uuid4().hex

    TEST_URL = CalledAuthPlugin.ENDPOINT

    def setUp(self):
        super(AdapterTest, self).setUp()
        self.deprecations.expect_deprecations()

    def _create_loaded_adapter(self):
        """Return an Adapter with every class-level setting populated.

        The auth plugin is attached to the adapter, not to the session.
        """
        auth = CalledAuthPlugin()
        sess = client_session.Session()
        return adapter.Adapter(sess,
                               auth=auth,
                               service_type=self.SERVICE_TYPE,
                               service_name=self.SERVICE_NAME,
                               interface=self.INTERFACE,
                               region_name=self.REGION_NAME,
                               user_agent=self.USER_AGENT,
                               version=self.VERSION)

    def _verify_endpoint_called(self, adpt):
        """Assert the plugin's recorded endpoint arguments match ours."""
        expected_arguments = (
            ('service_type', self.SERVICE_TYPE),
            ('service_name', self.SERVICE_NAME),
            ('interface', self.INTERFACE),
            ('region_name', self.REGION_NAME),
            ('version', self.VERSION),
        )
        for name, expected in expected_arguments:
            self.assertEqual(expected, adpt.auth.endpoint_arguments[name])

    def test_setting_variables_on_request(self):
        """Adapter settings are applied when a request is made."""
        response = uuid.uuid4().hex
        self.stub_url('GET', text=response)
        adpt = self._create_loaded_adapter()
        resp = adpt.get('/')
        self.assertEqual(resp.text, response)

        self._verify_endpoint_called(adpt)
        self.assertTrue(adpt.auth.get_token_called)
        self.assertRequestHeaderEqual('User-Agent', self.USER_AGENT)

    def test_setting_variables_on_get_endpoint(self):
        """Adapter settings are applied when looking up the endpoint."""
        adpt = self._create_loaded_adapter()
        url = adpt.get_endpoint()

        self.assertEqual(self.TEST_URL, url)
        self._verify_endpoint_called(adpt)

    def test_legacy_binding(self):
        """LegacyJsonAdapter returns the response and its decoded body."""
        key = uuid.uuid4().hex
        val = uuid.uuid4().hex
        response = jsonutils.dumps({key: val})

        self.stub_url('GET', text=response)

        auth = CalledAuthPlugin()
        sess = client_session.Session(auth=auth)
        adpt = adapter.LegacyJsonAdapter(sess,
                                         service_type=self.SERVICE_TYPE,
                                         user_agent=self.USER_AGENT)

        resp, body = adpt.get('/')
        self.assertEqual(self.SERVICE_TYPE,
                         auth.endpoint_arguments['service_type'])
        self.assertEqual(resp.text, response)
        self.assertEqual(val, body[key])

    def test_legacy_binding_non_json_resp(self):
        """LegacyJsonAdapter yields a ``None`` body for a text/html reply."""
        response = uuid.uuid4().hex
        self.stub_url('GET', text=response,
                      headers={'Content-Type': 'text/html'})

        auth = CalledAuthPlugin()
        sess = client_session.Session(auth=auth)
        adpt = adapter.LegacyJsonAdapter(sess,
                                         service_type=self.SERVICE_TYPE,
                                         user_agent=self.USER_AGENT)

        resp, body = adpt.get('/')
        self.assertEqual(self.SERVICE_TYPE,
                         auth.endpoint_arguments['service_type'])
        self.assertEqual(resp.text, response)
        self.assertIsNone(body)

    def test_methods(self):
        """Each HTTP helper calls ``request`` with the upper-cased verb."""
        sess = client_session.Session()
        adpt = adapter.Adapter(sess)
        url = 'http://url'

        for method in ['get', 'head', 'post', 'put', 'patch', 'delete']:
            with mock.patch.object(adpt, 'request') as m:
                getattr(adpt, method)(url)
                m.assert_called_once_with(url, method.upper())

    def test_setting_endpoint_override(self):
        """An adapter-level endpoint override is the base for requests."""
        endpoint_override = 'http://overrideurl'
        path = '/path'
        endpoint_url = endpoint_override + path

        auth = CalledAuthPlugin()
        sess = client_session.Session(auth=auth)
        adpt = adapter.Adapter(sess, endpoint_override=endpoint_override)

        response = uuid.uuid4().hex
        self.requests_mock.get(endpoint_url, text=response)

        resp = adpt.get(path)

        self.assertEqual(response, resp.text)
        self.assertEqual(endpoint_url, self.requests_mock.last_request.url)

        self.assertEqual(endpoint_override, adpt.get_endpoint())

    def test_adapter_invalidate(self):
        """``invalidate`` is delegated to the adapter's auth plugin."""
        auth = CalledAuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=auth)

        adpt.invalidate()

        self.assertTrue(auth.invalidate_called)

    def test_adapter_get_token(self):
        """``get_token`` is delegated to the adapter's auth plugin."""
        auth = CalledAuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=auth)

        self.assertEqual(self.TEST_TOKEN, adpt.get_token())
        self.assertTrue(auth.get_token_called)

    def test_adapter_connect_retries(self):
        """``connect_retries`` set on the adapter governs retry count."""
        retries = 2
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, connect_retries=retries)

        def _refused_error(request, context):
            raise requests.exceptions.ConnectionError()

        self.stub_url('GET', text=_refused_error)

        # time.sleep is patched so the waits between attempts cost nothing
        # and can be counted.
        with mock.patch('time.sleep') as m:
            self.assertRaises(exceptions.ConnectionRefused,
                              adpt.get, self.TEST_URL)
            self.assertEqual(retries, m.call_count)

        # we count retries so there will be one initial request + 2 retries
        self.assertThat(self.requests_mock.request_history,
                        matchers.HasLength(retries + 1))

    def test_user_and_project_id(self):
        """The adapter reports the IDs exposed by its auth plugin."""
        auth = AuthPlugin()
        sess = client_session.Session()
        adpt = adapter.Adapter(sess, auth=auth)

        self.assertEqual(auth.TEST_USER_ID, adpt.get_user_id())
        self.assertEqual(auth.TEST_PROJECT_ID, adpt.get_project_id())

    def test_logger_object_passed(self):
        """A logger set on the adapter receives the request's output.

        The URL and the response body are expected in the stream attached
        to the adapter's logger and must be absent from
        ``self.logger.output``.
        """
        logger = logging.getLogger(uuid.uuid4().hex)
        logger.setLevel(logging.DEBUG)
        # Keep records from bubbling up to ancestor loggers so they cannot
        # end up in the capture that is checked at the end of the test.
        logger.propagate = False

        io = six.StringIO()
        handler = logging.StreamHandler(io)
        logger.addHandler(handler)
        # Named loggers are kept by the logging module after the test ends;
        # detach the handler so it and its buffer are not retained.
        self.addCleanup(logger.removeHandler, handler)

        auth = AuthPlugin()
        sess = client_session.Session(auth=auth)
        adpt = adapter.Adapter(sess, auth=auth, logger=logger)

        response = {uuid.uuid4().hex: uuid.uuid4().hex}
        # The response holds exactly one pair; pull it out once.
        key, value = list(response.items())[0]

        self.stub_url('GET', json=response,
                      headers={'Content-Type': 'application/json'})

        # No per-request logger here: passing one would bypass the logger
        # configured on the adapter, which is what this test is about.
        resp = adpt.get(self.TEST_URL)

        self.assertEqual(response, resp.json())
        output = io.getvalue()

        self.assertIn(self.TEST_URL, output)
        self.assertIn(key, output)
        self.assertIn(value, output)

        self.assertNotIn(self.TEST_URL, self.logger.output)
        self.assertNotIn(key, self.logger.output)
        self.assertNotIn(value, self.logger.output)


class ConfLoadingTests(utils.TestCase):
    """Session option registration and loading through oslo.config."""

    GROUP = 'sessiongroup'

    def setUp(self):
        super(ConfLoadingTests, self).setUp()

        self.conf_fixture = self.useFixture(config.Config())
        conf = self.conf_fixture.conf
        client_session.Session.register_conf_options(conf, self.GROUP)

    def config(self, **kwargs):
        """Apply option overrides within the test's option group."""
        overrides = dict(kwargs, group=self.GROUP)
        self.conf_fixture.config(**overrides)

    def get_session(self, **kwargs):
        """Build a Session from the fixture's configuration."""
        conf = self.conf_fixture.conf
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.load_from_conf_options(
                conf, self.GROUP, **kwargs)

    def test_insecure_timeout(self):
        self.config(insecure=True, timeout=5)
        session = self.get_session()

        self.assertFalse(session.verify)
        self.assertEqual(5, session.timeout)

    def test_client_certs(self):
        cert_path = '/path/to/certfile'
        key_path = '/path/to/keyfile'

        self.config(certfile=cert_path, keyfile=key_path)
        session = self.get_session()

        self.assertTrue(session.verify)
        self.assertEqual((cert_path, key_path), session.cert)

    def test_cacert(self):
        ca_path = '/path/to/cacert'

        self.config(cafile=ca_path)
        session = self.get_session()

        self.assertEqual(ca_path, session.verify)

    def test_deprecated(self):
        """Deprecated aliases supplied per option land on that option."""
        opt_names = ['cafile', 'certfile', 'keyfile', 'insecure', 'timeout']

        # One randomly named deprecated alias for every option name.
        depr = {}
        for name in opt_names:
            alias = cfg.DeprecatedOpt(uuid.uuid4().hex,
                                      group=uuid.uuid4().hex)
            depr[name] = [alias]

        opts = client_session.Session.get_conf_options(deprecated_opts=depr)

        self.assertThat(opt_names, matchers.HasLength(len(opts)))
        for opt in opts:
            expected_alias = depr[opt.name][0]
            self.assertIn(expected_alias, opt.deprecated_opts)


class CliLoadingTests(utils.TestCase):
    """Session option registration and loading through argparse."""

    def setUp(self):
        super(CliLoadingTests, self).setUp()

        self.parser = argparse.ArgumentParser()
        client_session.Session.register_cli_options(self.parser)

    def get_session(self, val, **kwargs):
        """Parse ``val`` as a command line and build a Session from it."""
        argv = val.split()
        namespace = self.parser.parse_args(argv)
        with self.deprecations.expect_deprecations_here():
            return client_session.Session.load_from_cli_options(namespace,
                                                                **kwargs)

    def test_insecure_timeout(self):
        session = self.get_session('--insecure --timeout 5.5')

        self.assertFalse(session.verify)
        self.assertEqual(5.5, session.timeout)

    def test_client_certs(self):
        cert_path = '/path/to/certfile'
        key_path = '/path/to/keyfile'

        command_line = '--os-cert %s --os-key %s' % (cert_path, key_path)
        session = self.get_session(command_line)

        self.assertTrue(session.verify)
        self.assertEqual((cert_path, key_path), session.cert)

    def test_cacert(self):
        ca_path = '/path/to/cacert'

        command_line = '--os-cacert %s' % ca_path
        session = self.get_session(command_line)

        self.assertEqual(ca_path, session.verify)
