# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import http.client
import io
import json
import re
import urllib.parse

import purl
import requests
from urllib3 import HTTPResponse

import requests_mock
from . import base


class MyExc(Exception):
    """Sentinel exception raised by mocked endpoints in these tests."""


class SessionAdapterTests(base.TestCase):
    """Tests for a requests_mock.Adapter mounted on a real requests.Session."""

    # URL scheme the adapter is mounted under; every test URL uses it so the
    # session routes requests to the mock adapter.
    PREFIX = "mock"

    def setUp(self):
        """Create a fresh adapter, mount it on a new session, set defaults."""
        super(SessionAdapterTests, self).setUp()

        self.adapter = requests_mock.Adapter()
        self.session = requests.Session()
        self.session.mount(self.PREFIX, self.adapter)

        self.url = '%s://example.com/test' % self.PREFIX
        self.headers = {'header_a': 'A', 'header_b': 'B'}

    def assertHeaders(self, resp):
        """Assert every header in ``self.headers`` is present on *resp*."""
        for k, v in self.headers.items():
            self.assertEqual(v, resp.headers[k])

    def assertLastRequest(self, method='GET', body=None):
        """Assert the adapter's last request targeted ``self.url``.

        Checks the method, body, and each parsed URL component (scheme,
        netloc, path, query string and parsed query dict) recorded on
        ``self.adapter.last_request``.
        """
        self.assertEqual(self.url, self.adapter.last_request.url)
        self.assertEqual(method, self.adapter.last_request.method)
        self.assertEqual(body, self.adapter.last_request.body)

        url_parts = urllib.parse.urlparse(self.url)
        qs = urllib.parse.parse_qs(url_parts.query)
        self.assertEqual(url_parts.scheme, self.adapter.last_request.scheme)
        self.assertEqual(url_parts.netloc, self.adapter.last_request.netloc)
        self.assertEqual(url_parts.path, self.adapter.last_request.path)
        self.assertEqual(url_parts.query, self.adapter.last_request.query)
        self.assertEqual(qs, self.adapter.last_request.qs)

    def test_content(self):
        data = b'testdata'

        self.adapter.register_uri('GET',
                                  self.url,
                                  content=data,
                                  headers=self.headers)
        resp = self.session.get(self.url)
        self.assertEqual(data, resp.content)
        self.assertHeaders(resp)
        self.assertLastRequest()

    def test_content_callback(self):
        status_code = 401
        data = b'testdata'

        def _content_cb(request, context):
            context.status_code = status_code
            context.headers.update(self.headers)
            return data

        self.adapter.register_uri('GET',
                                  self.url,
                                  content=_content_cb)
        resp = self.session.get(self.url)
        self.assertEqual(status_code, resp.status_code)
        self.assertEqual(data, resp.content)
        self.assertHeaders(resp)
        self.assertLastRequest()

    def test_text(self):
        data = u'testdata'

        self.adapter.register_uri('GET',
                                  self.url,
                                  text=data,
                                  headers=self.headers)
        resp = self.session.get(self.url)
        self.assertEqual(data.encode('utf-8'), resp.content)
        self.assertEqual(data, resp.text)
        self.assertEqual('utf-8', resp.encoding)
        self.assertHeaders(resp)
        self.assertLastRequest()

    def test_text_callback(self):
        status_code = 401
        data = u'testdata'

        def _text_cb(request, context):
            context.status_code = status_code
            context.headers.update(self.headers)
            return data

        self.adapter.register_uri('GET', self.url, text=_text_cb)
        resp = self.session.get(self.url)
        self.assertEqual(status_code, resp.status_code)
        self.assertEqual(data, resp.text)
        self.assertEqual(data.encode('utf-8'), resp.content)
        self.assertEqual('utf-8', resp.encoding)
        self.assertHeaders(resp)
        self.assertLastRequest()

    def test_raw_callback(self):
        status_code = 401
        data = 'testdata'

        def _raw_cb(request, context):
            # Build a urllib3 response by hand; the adapter should pass it
            # through as the raw body rather than synthesizing its own.
            return HTTPResponse(
                status=status_code,
                headers=self.headers,
                body=io.BytesIO(data.encode('utf-8')),
                preload_content=False,
                reason=http.client.responses.get(status_code),
            )

        self.adapter.register_uri('GET', self.url, raw=_raw_cb)
        resp = self.session.get(self.url)
        self.assertEqual(status_code, resp.status_code)
        self.assertEqual(data, resp.text)
        self.assertEqual(data.encode('utf-8'), resp.content)
        self.assertHeaders(resp)
        self.assertLastRequest()

    def test_json(self):
        json_data = {'hello': 'world'}
        self.adapter.register_uri('GET',
                                  self.url,
                                  json=json_data,
                                  headers=self.headers)
        resp = self.session.get(self.url)
        self.assertEqual(b'{"hello": "world"}', resp.content)
        self.assertEqual(u'{"hello": "world"}', resp.text)
        self.assertEqual(json_data, resp.json())
        self.assertEqual('utf-8', resp.encoding)
        self.assertHeaders(resp)
        self.assertLastRequest()

    def test_json_callback(self):
        status_code = 401
        json_data = {'hello': 'world'}
        data = u'{"hello": "world"}'

        def _json_cb(request, context):
            context.status_code = status_code
            context.headers.update(self.headers)
            return json_data

        self.adapter.register_uri('GET', self.url, json=_json_cb)
        resp = self.session.get(self.url)
        self.assertEqual(status_code, resp.status_code)
        self.assertEqual(json_data, resp.json())
        self.assertEqual(data, resp.text)
        self.assertEqual(data.encode('utf-8'), resp.content)
        self.assertEqual('utf-8', resp.encoding)
        self.assertHeaders(resp)
        self.assertLastRequest()

    def test_no_body(self):
        self.adapter.register_uri('GET', self.url)
        resp = self.session.get(self.url)
        self.assertEqual(b'', resp.content)
        self.assertEqual(200, resp.status_code)

    def test_multiple_body_elements(self):
        # register_uri takes (method, url, ...); supplying two body kinds
        # at once must be rejected.
        self.assertRaises(RuntimeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          content=b'b',
                          text=u'u')

    def test_multiple_responses(self):
        inp = [{'status_code': 400, 'text': 'abcd'},
               {'status_code': 300, 'text': 'defg'},
               {'status_code': 200, 'text': 'hijk'}]

        self.adapter.register_uri('GET', self.url, inp)
        out = [self.session.get(self.url) for _ in inp]

        for i, o in zip(inp, out):
            for k, v in i.items():
                self.assertEqual(v, getattr(o, k))

        # Once the list is exhausted the final response keeps repeating.
        last = self.session.get(self.url)
        for k, v in inp[-1].items():
            self.assertEqual(v, getattr(last, k))

    def test_callback_optional_status(self):
        headers = {'a': 'b'}

        def _test_cb(request, context):
            context.headers.update(headers)
            return ''

        self.adapter.register_uri('GET',
                                  self.url,
                                  text=_test_cb,
                                  status_code=300)
        resp = self.session.get(self.url)
        self.assertEqual(300, resp.status_code)

        for k, v in headers.items():
            self.assertEqual(v, resp.headers[k])

    def test_callback_optional_headers(self):
        headers = {'a': 'b'}

        def _test_cb(request, context):
            context.status_code = 300
            return ''

        self.adapter.register_uri('GET',
                                  self.url,
                                  text=_test_cb,
                                  headers=headers)

        resp = self.session.get(self.url)
        self.assertEqual(300, resp.status_code)

        for k, v in headers.items():
            self.assertEqual(v, resp.headers[k])

    def test_latest_register_overrides(self):
        self.adapter.register_uri('GET', self.url, text='abc')
        self.adapter.register_uri('GET', self.url, text='def')

        resp = self.session.get(self.url)
        self.assertEqual('def', resp.text)

    def test_no_last_request(self):
        self.assertIsNone(self.adapter.last_request)
        self.assertEqual(0, len(self.adapter.request_history))

    def test_dont_pass_list_and_kwargs(self):
        self.assertRaises(RuntimeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          [{'text': 'a'}],
                          headers={'a': 'b'})

    def test_empty_string_return(self):
        # '' evaluates as False, so make sure an empty string is not ignored.
        self.adapter.register_uri('GET', self.url, text='')
        resp = self.session.get(self.url)
        self.assertEqual('', resp.text)

    def test_dont_pass_multiple_bodies(self):
        self.assertRaises(RuntimeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          json={'abc': 'def'},
                          text='ghi')

    def test_dont_pass_unexpected_kwargs(self):
        self.assertRaises(TypeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          unknown='argument')

    def test_dont_pass_unicode_as_content(self):
        self.assertRaises(TypeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          content=u'unicode')

    def test_dont_pass_empty_string_as_content(self):
        self.assertRaises(TypeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          content=u'')

    def test_dont_pass_bytes_as_text(self):
        self.assertRaises(TypeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          text=b'bytes')

    def test_dont_pass_empty_string_as_text(self):
        self.assertRaises(TypeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          text=b'')

    def test_dont_pass_non_str_as_content(self):
        self.assertRaises(TypeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          content=5)

    def test_dont_pass_non_str_as_text(self):
        self.assertRaises(TypeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          text=5)

    def test_with_any_method(self):
        self.adapter.register_uri(requests_mock.ANY, self.url, text='resp')

        for m in ('GET', 'HEAD', 'POST', 'UNKNOWN'):
            resp = self.session.request(m, self.url)
            self.assertEqual('resp', resp.text)

    def test_with_any_url(self):
        self.adapter.register_uri('GET', requests_mock.ANY, text='resp')

        for u in ('mock://a', 'mock://b', 'mock://c'):
            resp = self.session.get(u)
            self.assertEqual('resp', resp.text)

    def test_with_regexp(self):
        self.adapter.register_uri('GET', re.compile('tester.com'), text='resp')

        for u in ('mock://www.tester.com/a', 'mock://abc.tester.com'):
            resp = self.session.get(u)
            self.assertEqual('resp', resp.text)

    def test_with_purl(self):
        self.adapter.register_uri('GET',
                                  purl.URL('mock://www.tester.com/a'),
                                  text='resp')

        resp = self.session.get('mock://www.tester.com/a')
        self.assertEqual('resp', resp.text)

    def test_requests_in_history_on_no_match(self):
        self.assertRaises(requests_mock.NoMockAddress,
                          self.session.get,
                          self.url)

        self.assertEqual(self.url, self.adapter.last_request.url)

    def test_requests_in_history_on_exception(self):

        def _test_cb(request, ctx):
            raise MyExc()

        self.adapter.register_uri('GET', self.url, text=_test_cb)

        self.assertRaises(MyExc,
                          self.session.get,
                          self.url)

        self.assertEqual(self.url, self.adapter.last_request.url)

    def test_not_called_and_called_count(self):
        m = self.adapter.register_uri('GET', self.url, text='resp')
        self.assertEqual(0, m.call_count)
        self.assertFalse(m.called)
        self.assertFalse(m.called_once)

        # Same checks against the adapter-wide history.
        self.assertEqual(0, self.adapter.call_count)
        self.assertFalse(self.adapter.called)
        self.assertFalse(self.adapter.called_once)

    def test_called_and_called_count(self):
        m = self.adapter.register_uri('GET', self.url, text='resp')

        resps = [self.session.get(self.url) for _ in range(3)]

        for r in resps:
            self.assertEqual('resp', r.text)
            self.assertEqual(200, r.status_code)

        self.assertEqual(len(resps), m.call_count)
        self.assertTrue(m.called)
        self.assertFalse(m.called_once)

        # Same checks against the adapter-wide history.
        self.assertEqual(len(resps), self.adapter.call_count)
        self.assertTrue(self.adapter.called)
        self.assertFalse(self.adapter.called_once)

    def test_reset_reverts_call_count(self):
        # Create matchers and add calls to history
        call_count = 3
        matcher_count = 3
        for i in range(matcher_count):
            url = self.url + str(i)
            self.adapter.register_uri('GET', url, text='resp')
            for _ in range(call_count):
                self.session.get(url)

        # Verify call counts on adapter and matchers
        self.assertEqual(matcher_count * call_count, self.adapter.call_count)
        for matcher in self.adapter._matchers:
            self.assertEqual(call_count, matcher.call_count)

        self.adapter.reset()

        # Verify call counts are 0 after reset
        self.assertEqual(0, self.adapter.call_count)
        for matcher in self.adapter._matchers:
            self.assertEqual(0, matcher.call_count)

    def test_adapter_picks_correct_adapter(self):
        good = '%s://test3.url/' % self.PREFIX
        self.adapter.register_uri('GET',
                                  '%s://test1.url' % self.PREFIX,
                                  text='bad')
        self.adapter.register_uri('GET',
                                  '%s://test2.url' % self.PREFIX,
                                  text='bad')
        self.adapter.register_uri('GET', good, text='good')
        self.adapter.register_uri('GET',
                                  '%s://test4.url' % self.PREFIX,
                                  text='bad')

        resp = self.session.get(good)

        self.assertEqual('good', resp.text)

    def test_adapter_is_connection(self):
        url = '%s://test.url' % self.PREFIX
        text = 'text'
        self.adapter.register_uri('GET', url, text=text)
        resp = self.session.get(url)

        self.assertEqual(text, resp.text)
        self.assertIs(self.adapter, resp.connection)

    def test_send_to_connection(self):
        url1 = '%s://test1.url/' % self.PREFIX
        url2 = '%s://test2.url/' % self.PREFIX

        text1 = 'text1'
        text2 = 'text2'

        self.adapter.register_uri('GET', url1, text=text1)
        self.adapter.register_uri('GET', url2, text=text2)

        req = requests.Request(method='GET', url=url2).prepare()

        resp1 = self.session.get(url1)
        self.assertEqual(text1, resp1.text)

        # resp.connection is the adapter, so it can send further requests.
        resp2 = resp1.connection.send(req)
        self.assertEqual(text2, resp2.text)

    def test_request_json_with_str_data(self):
        dict_req = {'hello': 'world'}
        dict_resp = {'goodbye': 'world'}

        m = self.adapter.register_uri('POST', self.url, json=dict_resp)

        data = json.dumps(dict_req)
        resp = self.session.post(self.url, data=data)

        self.assertIs(data, m.last_request.body)
        self.assertEqual(dict_resp, resp.json())
        self.assertEqual(dict_req, m.last_request.json())

    def test_request_json_with_bytes_data(self):
        dict_req = {'hello': 'world'}
        dict_resp = {'goodbye': 'world'}

        m = self.adapter.register_uri('POST', self.url, json=dict_resp)

        data = json.dumps(dict_req).encode('utf-8')
        resp = self.session.post(self.url, data=data)

        self.assertIs(data, m.last_request.body)
        self.assertEqual(dict_resp, resp.json())
        self.assertEqual(dict_req, m.last_request.json())

    def test_request_json_with_cb(self):
        dict_req = {'hello': 'world'}
        dict_resp = {'goodbye': 'world'}
        data = json.dumps(dict_req)

        def _cb(req, context):
            self.assertEqual(dict_req, req.json())
            return dict_resp

        m = self.adapter.register_uri('POST', self.url, json=_cb)
        resp = self.session.post(self.url, data=data)

        self.assertEqual(1, m.call_count)
        self.assertTrue(m.called_once)
        self.assertEqual(dict_resp, resp.json())

    def test_raises_exception(self):
        self.adapter.register_uri('GET', self.url, exc=MyExc)

        self.assertRaises(MyExc,
                          self.session.get,
                          self.url)

        self.assertTrue(self.adapter.called_once)
        self.assertEqual(self.url, self.adapter.last_request.url)

    def test_raises_exception_with_body_args_fails(self):
        self.assertRaises(TypeError,
                          self.adapter.register_uri,
                          'GET',
                          self.url,
                          exc=MyExc,
                          text='fail')

    def test_sets_request_matcher_in_history(self):
        url1 = '%s://test1.url/' % self.PREFIX
        url2 = '%s://test2.url/' % self.PREFIX

        text1 = 'text1'
        text2 = 'text2'

        m1 = self.adapter.register_uri('GET', url1, text=text1)
        m2 = self.adapter.register_uri('GET', url2, text=text2)

        resp1 = self.session.get(url1)
        resp2 = self.session.get(url2)

        self.assertEqual(text1, resp1.text)
        self.assertEqual(text2, resp2.text)

        self.assertEqual(2, self.adapter.call_count)
        self.assertFalse(self.adapter.called_once)

        self.assertEqual(url1, self.adapter.request_history[0].url)
        self.assertEqual(url2, self.adapter.request_history[1].url)

        self.assertIs(m1, self.adapter.request_history[0].matcher)
        self.assertIs(m2, self.adapter.request_history[1].matcher)

    def test_sets_request_matcher_on_exception(self):
        m = self.adapter.register_uri('GET', self.url, exc=MyExc)

        self.assertRaises(MyExc,
                          self.session.get,
                          self.url)

        self.assertEqual(self.url, self.adapter.last_request.url)
        self.assertIs(m, self.adapter.last_request.matcher)

    def test_cookies_from_header(self):
        headers = {'Set-Cookie': 'fig=newton; Path=/test; domain=.example.com'}
        self.adapter.register_uri('GET',
                                  self.url,
                                  text='text',
                                  headers=headers)

        resp = self.session.get(self.url)

        self.assertEqual('newton', resp.cookies['fig'])
        self.assertEqual(['/test'], resp.cookies.list_paths())
        self.assertEqual(['.example.com'], resp.cookies.list_domains())

    def test_cookies_from_dict(self):
        # This is a syntax we get from requests. I'm not sure I like it.
        self.adapter.register_uri('GET',
                                  self.url,
                                  text='text',
                                  cookies={'fig': 'newton', 'sugar': 'apple'})

        resp = self.session.get(self.url)

        self.assertEqual('newton', resp.cookies['fig'])
        self.assertEqual('apple', resp.cookies['sugar'])

    def test_cookies_with_jar(self):
        jar = requests_mock.CookieJar()
        jar.set('fig', 'newton', path='/foo', domain='.example.com')
        jar.set('sugar', 'apple', path='/bar', domain='.example.com')

        self.adapter.register_uri('GET', self.url, text='text', cookies=jar)
        resp = self.session.get(self.url)

        self.assertEqual('newton', resp.cookies['fig'])
        self.assertEqual('apple', resp.cookies['sugar'])
        self.assertEqual({'/foo', '/bar'}, set(resp.cookies.list_paths()))
        self.assertEqual(['.example.com'], resp.cookies.list_domains())

    def test_cookies_header_with_cb(self):

        def _cb(request, context):
            val = 'fig=newton; Path=/test; domain=.example.com'
            context.headers['Set-Cookie'] = val
            return 'text'

        self.adapter.register_uri('GET', self.url, text=_cb)
        resp = self.session.get(self.url)

        self.assertEqual('newton', resp.cookies['fig'])
        self.assertEqual(['/test'], resp.cookies.list_paths())
        self.assertEqual(['.example.com'], resp.cookies.list_domains())

    def test_cookies_from_dict_with_cb(self):
        def _cb(request, context):
            # converted into a jar by now
            context.cookies.set('sugar', 'apple', path='/test')
            return 'text'

        self.adapter.register_uri('GET',
                                  self.url,
                                  text=_cb,
                                  cookies={'fig': 'newton'})

        resp = self.session.get(self.url)

        self.assertEqual('newton', resp.cookies['fig'])
        self.assertEqual('apple', resp.cookies['sugar'])
        self.assertEqual(['/', '/test'], resp.cookies.list_paths())

    def test_cookies_with_jar_cb(self):
        def _cb(request, context):
            context.cookies.set('sugar',
                                'apple',
                                path='/bar',
                                domain='.example.com')
            return 'text'

        jar = requests_mock.CookieJar()
        jar.set('fig', 'newton', path='/foo', domain='.example.com')

        self.adapter.register_uri('GET', self.url, text=_cb, cookies=jar)
        resp = self.session.get(self.url)

        self.assertEqual('newton', resp.cookies['fig'])
        self.assertEqual('apple', resp.cookies['sugar'])
        self.assertEqual({'/foo', '/bar'}, set(resp.cookies.list_paths()))
        self.assertEqual(['.example.com'], resp.cookies.list_domains())

    def test_reading_closed_fp(self):
        self.adapter.register_uri('GET', self.url, text='abc')
        resp = self.session.get(self.url)

        # raw will have been closed during the request reading
        self.assertTrue(resp.raw.closed)

        data = resp.raw.read()

        self.assertIsInstance(data, bytes)
        self.assertEqual(0, len(data))

    def test_case_sensitive_headers(self):
        data = 'testdata'
        headers = {'aBcDe': 'FgHiJ'}

        self.adapter.register_uri('GET', self.url, text=data)
        resp = self.session.get(self.url, headers=headers)

        self.assertEqual('GET', self.adapter.last_request.method)
        self.assertEqual(200, resp.status_code)
        self.assertEqual(data, resp.text)

        for k, v in headers.items():
            self.assertEqual(v, self.adapter.last_request.headers[k])

    def test_case_sensitive_history(self):
        self.adapter._case_sensitive = True

        data = 'testdata'
        netloc = 'examPlE.CoM'
        path = '/TesTER'
        query = 'aBC=deF'

        mock_url = '%s://%s%s' % (self.PREFIX, netloc.lower(), path)
        request_url = '%s://%s%s?%s' % (self.PREFIX, netloc, path, query)

        # test that the case of the netloc is ignored when matching: the
        # mixed-case request still hits the lower-cased registration
        self.adapter.register_uri('GET', mock_url, text=data)
        resp = self.session.get(request_url)

        self.assertEqual('GET', self.adapter.last_request.method)
        self.assertEqual(200, resp.status_code)
        self.assertEqual(data, resp.text)

        # but even still the mixed case parameters come out in history
        self.assertEqual(netloc, self.adapter.last_request.netloc)
        self.assertEqual(path, self.adapter.last_request.path)
        self.assertEqual(query, self.adapter.last_request.query)

    def test_stream_none(self):
        text = 'hello world'

        self.adapter.register_uri('GET',
                                  self.url,
                                  text=text,
                                  headers=self.headers)

        resp = self.session.get(self.url, stream=True)
        resps = list(resp.iter_content(None, decode_unicode=True))
        self.assertEqual([text], resps)

    def test_stream_size(self):
        text = 'hello world'

        self.adapter.register_uri('GET',
                                  self.url,
                                  text=text,
                                  headers=self.headers)

        resp = self.session.get(self.url, stream=True)
        resps = list(resp.iter_content(3, decode_unicode=True))
        self.assertEqual(['hel', 'lo ', 'wor', 'ld'], resps)
