from __future__ import absolute_import, division, with_statement

import collections
from contextlib import closing
import gzip
import logging
import re
import socket

from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
from tornado.test.httpclient_test import HTTPClientCommonTestCase, ChunkHandler, CountdownHandler, HelloWorldHandler
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, LogTrapTestCase
from tornado.util import b
from tornado.web import RequestHandler, Application, asynchronous, url


class SimpleHTTPClientCommonTestCase(HTTPClientCommonTestCase):
    """Runs the shared HTTP client test suite against SimpleAsyncHTTPClient."""

    def get_http_client(self):
        # force_instance=True bypasses the per-IOLoop instance cache, so each
        # call hands back a brand new client bound to this test's IOLoop.
        http_client = SimpleAsyncHTTPClient(force_instance=True,
                                            io_loop=self.io_loop)
        # Sanity check that we really got the simple implementation.
        self.assertTrue(isinstance(http_client, SimpleAsyncHTTPClient))
        return http_client

# Remove the imported base class from this module's namespace so that the
# unittest module doesn't find it here and try to run it a second time.
del HTTPClientCommonTestCase


class TriggerHandler(RequestHandler):
    """Handler whose GET requests stay open until the test finishes them.

    Each request appends its own ``finish`` method to ``queue``; the test
    later pops and calls those callables to complete the responses.
    """

    def initialize(self, queue, wake_callback):
        # Both values are supplied through the handler's url() spec.
        self.wake_callback = wake_callback
        self.queue = queue

    @asynchronous
    def get(self):
        logging.info("queuing trigger")
        self.queue.append(self.finish)
        # "wake" defaults to "true"; any other value suppresses the callback.
        should_wake = self.get_argument("wake", "true") == "true"
        if should_wake:
            self.wake_callback()


class HangHandler(RequestHandler):
    """Handler whose GET does nothing at all."""

    @asynchronous
    def get(self):
        # Deliberately empty: finish() is never called from here.
        return None


class ContentLengthHandler(RequestHandler):
    """Writes "ok" with a Content-Length header taken from the query string."""

    def get(self):
        # The required "value" argument is copied verbatim into the header,
        # letting tests send arbitrary (even malformed) Content-Length values.
        requested_length = self.get_argument("value")
        self.set_header("Content-Length", requested_length)
        self.write("ok")


class HeadHandler(RequestHandler):
    """Answers HEAD requests with a Content-Length of 7 and writes no body."""

    def head(self):
        advertised_length = "7"
        self.set_header("Content-Length", advertised_length)


class OptionsHandler(RequestHandler):
    """Answers OPTIONS requests with a wildcard CORS header and body "ok"."""

    def options(self):
        allowed_origin = "*"
        self.set_header("Access-Control-Allow-Origin", allowed_origin)
        self.write("ok")


class NoContentHandler(RequestHandler):
    """Responds 204; with ?error=... also sets a non-zero Content-Length."""

    def get(self):
        # Any truthy "error" argument adds the Content-Length: 7 header.
        wants_malformed = self.get_argument("error", None)
        if wants_malformed:
            self.set_header("Content-Length", "7")
        self.set_status(204)


class SeeOther303PostHandler(RequestHandler):
    """Accepts a POST with body "blah" and redirects to /303_get with a 303."""

    def post(self):
        # The test always posts exactly this body.
        assert self.request.body == b("blah")
        self.set_status(303)
        self.set_header("Location", "/303_get")


class SeeOther303GetHandler(RequestHandler):
    """Target of the 303 redirect; asserts the request has no body."""

    def get(self):
        request_body = self.request.body
        assert not request_body
        self.write("ok")


class HostEchoHandler(RequestHandler):
    """Echoes the request's Host header back as the response body."""

    def get(self):
        host_header = self.request.headers["Host"]
        self.write(host_header)


class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
    def setUp(self):
        super(SimpleHTTPClientTestCase, self).setUp()
        self.http_client = SimpleAsyncHTTPClient(self.io_loop)

    def get_app(self):
        # callable objects to finish pending /trigger requests
        self.triggers = collections.deque()
        return Application([
            url("/trigger", TriggerHandler, dict(queue=self.triggers,
                                                 wake_callback=self.stop)),
            url("/chunk", ChunkHandler),
            url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
            url("/hang", HangHandler),
            url("/hello", HelloWorldHandler),
            url("/content_length", ContentLengthHandler),
            url("/head", HeadHandler),
            url("/options", OptionsHandler),
            url("/no_content", NoContentHandler),
            url("/303_post", SeeOther303PostHandler),
            url("/303_get", SeeOther303GetHandler),
            url("/host_echo", HostEchoHandler),
            ], gzip=True)

    def test_singleton(self):
        # Class "constructor" reuses objects on the same IOLoop
        self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is
                        SimpleAsyncHTTPClient(self.io_loop))
        # unless force_instance is used
        self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
                        SimpleAsyncHTTPClient(self.io_loop,
                                              force_instance=True))
        # different IOLoops use different objects
        io_loop2 = IOLoop()
        self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
                        SimpleAsyncHTTPClient(io_loop2))

    def test_connection_limit(self):
        client = SimpleAsyncHTTPClient(self.io_loop, max_clients=2,
                                       force_instance=True)
        self.assertEqual(client.max_clients, 2)
        seen = []
        # Send 4 requests.  Two can be sent immediately, while the others
        # will be queued
        for i in range(4):
            client.fetch(self.get_url("/trigger"),
                         lambda response, i=i: (seen.append(i), self.stop()))
        self.wait(condition=lambda: len(self.triggers) == 2)
        self.assertEqual(len(client.queue), 2)

        # Finish the first two requests and let the next two through
        self.triggers.popleft()()
        self.triggers.popleft()()
        self.wait(condition=lambda: (len(self.triggers) == 2 and
                                     len(seen) == 2))
        self.assertEqual(set(seen), set([0, 1]))
        self.assertEqual(len(client.queue), 0)

        # Finish all the pending requests
        self.triggers.popleft()()
        self.triggers.popleft()()
        self.wait(condition=lambda: len(seen) == 4)
        self.assertEqual(set(seen), set([0, 1, 2, 3]))
        self.assertEqual(len(self.triggers), 0)

    def test_redirect_connection_limit(self):
        # following redirects should not consume additional connections
        client = SimpleAsyncHTTPClient(self.io_loop, max_clients=1,
                                       force_instance=True)
        client.fetch(self.get_url('/countdown/3'), self.stop,
                     max_redirects=3)
        response = self.wait()
        response.rethrow()

    def test_default_certificates_exist(self):
        open(_DEFAULT_CA_CERTS).close()

    def test_gzip(self):
        # All the tests in this file should be using gzip, but this test
        # ensures that it is in fact getting compressed.
        # Setting Accept-Encoding manually bypasses the client's
        # decompression so we can see the raw data.
        response = self.fetch("/chunk", use_gzip=False,
                              headers={"Accept-Encoding": "gzip"})
        self.assertEqual(response.headers["Content-Encoding"], "gzip")
        self.assertNotEqual(response.body, b("asdfqwer"))
        # Our test data gets bigger when gzipped.  Oops.  :)
        self.assertEqual(len(response.body), 34)
        f = gzip.GzipFile(mode="r", fileobj=response.buffer)
        self.assertEqual(f.read(), b("asdfqwer"))

    def test_max_redirects(self):
        response = self.fetch("/countdown/5", max_redirects=3)
        self.assertEqual(302, response.code)
        # We requested 5, followed three redirects for 4, 3, 2, then the last
        # unfollowed redirect is to 1.
        self.assertTrue(response.request.url.endswith("/countdown/5"))
        self.assertTrue(response.effective_url.endswith("/countdown/2"))
        self.assertTrue(response.headers["Location"].endswith("/countdown/1"))

    def test_header_reuse(self):
        # Apps may reuse a headers object if they are only passing in constant
        # headers like user-agent.  The header object should not be modified.
        headers = HTTPHeaders({'User-Agent': 'Foo'})
        self.fetch("/hello", headers=headers)
        self.assertEqual(list(headers.get_all()), [('User-Agent', 'Foo')])

    def test_303_redirect(self):
        response = self.fetch("/303_post", method="POST", body="blah")
        self.assertEqual(200, response.code)
        self.assertTrue(response.request.url.endswith("/303_post"))
        self.assertTrue(response.effective_url.endswith("/303_get"))
        #request is the original request, is a POST still
        self.assertEqual("POST", response.request.method)

    def test_request_timeout(self):
        response = self.fetch('/trigger?wake=false', request_timeout=0.1)
        self.assertEqual(response.code, 599)
        self.assertTrue(0.099 < response.request_time < 0.11, response.request_time)
        self.assertEqual(str(response.error), "HTTP 599: Timeout")
        # trigger the hanging request to let it clean up after itself
        self.triggers.popleft()()

    def test_ipv6(self):
        if not socket.has_ipv6:
            # python compiled without ipv6 support, so skip this test
            return
        try:
            self.http_server.listen(self.get_http_port(), address='::1')
        except socket.gaierror, e:
            if e.args[0] == socket.EAI_ADDRFAMILY:
                # python supports ipv6, but it's not configured on the network
                # interface, so skip this test.
                return
            raise
        url = self.get_url("/hello").replace("localhost", "[::1]")

        # ipv6 is currently disabled by default and must be explicitly requested
        self.http_client.fetch(url, self.stop)
        response = self.wait()
        self.assertEqual(response.code, 599)

        self.http_client.fetch(url, self.stop, allow_ipv6=True)
        response = self.wait()
        self.assertEqual(response.body, b("Hello world!"))

    def test_multiple_content_length_accepted(self):
        response = self.fetch("/content_length?value=2,2")
        self.assertEqual(response.body, b("ok"))
        response = self.fetch("/content_length?value=2,%202,2")
        self.assertEqual(response.body, b("ok"))

        response = self.fetch("/content_length?value=2,4")
        self.assertEqual(response.code, 599)
        response = self.fetch("/content_length?value=2,%202,3")
        self.assertEqual(response.code, 599)

    def test_head_request(self):
        response = self.fetch("/head", method="HEAD")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.headers["content-length"], "7")
        self.assertFalse(response.body)

    def test_options_request(self):
        response = self.fetch("/options", method="OPTIONS")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.headers["content-length"], "2")
        self.assertEqual(response.headers["access-control-allow-origin"], "*")
        self.assertEqual(response.body, b("ok"))

    def test_no_content(self):
        response = self.fetch("/no_content")
        self.assertEqual(response.code, 204)
        # 204 status doesn't need a content-length, but tornado will
        # add a zero content-length anyway.
        self.assertEqual(response.headers["Content-length"], "0")

        # 204 status with non-zero content length is malformed
        response = self.fetch("/no_content?error=1")
        self.assertEqual(response.code, 599)

    def test_host_header(self):
        host_re = re.compile(b("^localhost:[0-9]+$"))
        response = self.fetch("/host_echo")
        self.assertTrue(host_re.match(response.body))

        url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
        self.http_client.fetch(url, self.stop)
        response = self.wait()
        self.assertTrue(host_re.match(response.body), response.body)


class CreateAsyncHTTPClientTestCase(AsyncTestCase, LogTrapTestCase):
    """Tests how AsyncHTTPClient construction resolves max_clients.

    The class-level AsyncHTTPClient configuration is saved in setUp and
    restored in tearDown so configure() calls made by a test do not leak
    into other tests.
    """

    def setUp(self):
        super(CreateAsyncHTTPClientTestCase, self).setUp()
        self.saved = AsyncHTTPClient._save_configuration()

    def tearDown(self):
        AsyncHTTPClient._restore_configuration(self.saved)
        super(CreateAsyncHTTPClientTestCase, self).tearDown()

    def test_max_clients(self):
        # The max_clients argument is tricky because it was originally
        # allowed to be passed positionally; newer arguments are keyword-only.
        AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
        # Default value.
        with closing(AsyncHTTPClient(
                self.io_loop, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 10)
        # Positional argument.
        with closing(AsyncHTTPClient(
                self.io_loop, 11, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 11)
        # Keyword argument.
        with closing(AsyncHTTPClient(
                self.io_loop, max_clients=11, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 11)

        # Now configure max_clients statically and try overriding it
        # with each way max_clients can be passed
        AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
        # No per-instance value: the configured default applies.
        with closing(AsyncHTTPClient(
                self.io_loop, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 12)
        # Positional override.  Distinct values (13 vs 14) make it clear
        # which form failed if one of them regresses.
        with closing(AsyncHTTPClient(
                self.io_loop, 13, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 13)
        # Keyword override.
        with closing(AsyncHTTPClient(
                self.io_loop, max_clients=14, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 14)
