import select
import socket
import contextlib
import threading
from tests import unittest
from contextlib import contextmanager

import botocore.session
from botocore.config import Config
from six.moves import BaseHTTPServer, socketserver
from botocore.exceptions import (
    ConnectTimeoutError, ReadTimeoutError, EndpointConnectionError,
    ConnectionClosedError,
)
from requests import exceptions as requests_exceptions


class TestClientHTTPBehavior(unittest.TestCase):
    """Check how a botocore EC2 client reacts to misbehaving local peers.

    Each test runs a single-request server (or a bare bound socket) on a
    freshly chosen local port in a background thread and then issues
    ``describe_regions()`` through the client.
    """

    def setUp(self):
        """Pick a free port and create a fresh botocore session."""
        self.port = unused_port()
        self.localhost = 'http://localhost:%s/' % self.port
        self.session = botocore.session.get_session()

    @unittest.skip('Test has suddenly become extremely flakey.')
    def test_can_proxy_https_request_with_auth(self):
        """Credentials embedded in the proxy URL reach the proxy as Basic auth."""
        proxy_url = 'http://user:pass@localhost:%s/' % self.port
        config = Config(proxies={'https': proxy_url}, region_name='us-west-1')
        client = self.session.create_client('ec2', config=config)

        class AuthProxyHandler(ProxyHandler):
            # Set by run_server once the proxy is bound and activated.
            event = threading.Event()

            def validate_auth(self):
                proxy_auth = self.headers.get('Proxy-Authorization')
                # 'dXNlcjpwYXNz' is base64 for 'user:pass'.
                return proxy_auth == 'Basic dXNlcjpwYXNz'

        try:
            with background(run_server, args=(AuthProxyHandler, self.port)):
                AuthProxyHandler.event.wait(timeout=60)
                client.describe_regions()
        except BackgroundTaskFailed:
            self.fail('Background task did not exit, proxy was not used.')

    def _read_timeout_server(self):
        """Call a fake EC2 endpoint whose body is withheld past read_timeout.

        The handler blocks in ``get_body`` until the client call has ended,
        so the 0.1 second read timeout elapses first. Whatever exception the
        client raises propagates to the caller.
        """
        config = Config(
            read_timeout=0.1,
            retries={'max_attempts': 0},
            region_name='us-weast-2',
        )
        client = self.session.create_client('ec2', endpoint_url=self.localhost,
                                            config=config)
        client_call_ended_event = threading.Event()

        class FakeEC2(SimpleHandler):
            # Set by run_server once the server is bound and activated.
            event = threading.Event()
            msg = b'<response/>'

            def get_length(self):
                return len(self.msg)

            def get_body(self):
                # Hold the body back until the client has given up.
                client_call_ended_event.wait(timeout=60)
                return self.msg

        try:
            with background(run_server, args=(FakeEC2, self.port)):
                try:
                    FakeEC2.event.wait(timeout=60)
                    client.describe_regions()
                finally:
                    # Always release the handler so the server thread exits.
                    client_call_ended_event.set()
        except BackgroundTaskFailed:
            self.fail('Fake EC2 service was not called.')

    def test_read_timeout_exception(self):
        """A stalled response body raises botocore's ReadTimeoutError."""
        with self.assertRaises(ReadTimeoutError):
            self._read_timeout_server()

    def test_old_read_timeout_exception(self):
        """The same failure is also catchable as requests' ReadTimeout."""
        with self.assertRaises(requests_exceptions.ReadTimeout):
            self._read_timeout_server()

    @unittest.skip('The current implementation will fail to timeout on linux')
    def test_connect_timeout_exception(self):
        """A peer that binds but never accepts raises ConnectTimeoutError."""
        config = Config(
            connect_timeout=0.2,
            retries={'max_attempts': 0},
            region_name='us-weast-2',
        )
        client = self.session.create_client('ec2', endpoint_url=self.localhost,
                                            config=config)
        server_bound_event = threading.Event()
        client_call_ended_event = threading.Event()

        def no_accept_server():
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(('', self.port))
                server_bound_event.set()
                client_call_ended_event.wait(timeout=60)
            finally:
                # Release the port even if bind fails.
                sock.close()

        with background(no_accept_server):
            try:
                server_bound_event.wait(timeout=60)
                with self.assertRaises(ConnectTimeoutError):
                    client.describe_regions()
            finally:
                # Release the background thread even when the assertion
                # fails; otherwise it waits out its full timeout and the
                # real failure can be replaced by BackgroundTaskFailed.
                client_call_ended_event.set()

    def test_invalid_host_gaierror(self):
        """An unresolvable endpoint host raises EndpointConnectionError."""
        config = Config(retries={'max_attempts': 0}, region_name='us-weast-1')
        endpoint = 'https://ec2.us-weast-1.amazonaws.com/'
        client = self.session.create_client('ec2', endpoint_url=endpoint,
                                            config=config)
        with self.assertRaises(EndpointConnectionError):
            client.describe_regions()

    def test_bad_status_line(self):
        """A reply that is not a valid HTTP status line raises ConnectionClosedError."""
        config = Config(retries={'max_attempts': 0}, region_name='us-weast-2')
        client = self.session.create_client('ec2', endpoint_url=self.localhost,
                                            config=config)

        class BadStatusHandler(BaseHTTPServer.BaseHTTPRequestHandler):
            # Set by run_server once the server is bound and activated.
            event = threading.Event()

            def do_POST(self):
                # Reply with bytes that are not an HTTP response.
                self.wfile.write(b'garbage')

        with background(run_server, args=(BadStatusHandler, self.port)):
            with self.assertRaises(ConnectionClosedError):
                BadStatusHandler.event.wait(timeout=60)
                client.describe_regions()


def unused_port():
    """Return a TCP port number the OS reported as free on 127.0.0.1.

    A throwaway socket is bound to port 0 so the OS chooses the port; the
    socket is closed again before the number is returned.
    """
    probe = socket.socket()
    try:
        probe.bind(('127.0.0.1', 0))
        _host, port = probe.getsockname()
    finally:
        probe.close()
    return port


class SimpleHandler(BaseHTTPServer.BaseHTTPRequestHandler):
    """Minimal handler that answers GET, POST and PUT with a fixed reply.

    Subclasses customise the reply by overriding ``status``,
    ``get_length`` and ``get_body``.
    """

    # HTTP status code sent with every response.
    status = 200

    def get_length(self):
        """Return the value advertised in the Content-Length header."""
        return 0

    def get_body(self):
        """Return the response body as bytes."""
        return b''

    def do_GET(self):
        """Send the status line, a Content-Length header and the body."""
        # The length is computed before anything is written to the client.
        content_length = str(self.get_length())
        self.send_response(self.status)
        self.send_header('Content-Length', content_length)
        self.end_headers()
        payload = self.get_body()
        self.wfile.write(payload)

    # POST and PUT are answered exactly like GET.
    do_POST = do_GET
    do_PUT = do_GET


class ProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
    """Bare-bones HTTP CONNECT proxy used to test client proxy support.

    Subclasses override ``validate_auth`` to accept or reject a request
    based on its headers.
    """

    # Maximum number of bytes relayed per recv() call.
    tunnel_chunk_size = 1024

    def _tunnel(self, client, remote):
        """Relay bytes between ``client`` and ``remote`` until one side closes.

        Both sockets are switched to non-blocking mode. The loop ends when
        either side reports end-of-stream (``recv`` returns no bytes).
        """
        client.setblocking(0)
        remote.setblocking(0)
        sockets = [client, remote]
        while True:
            # Wait for incoming data only. Idle sockets are normally
            # writable, so waiting on writability as well would make
            # select() return at once and turn this loop into a busy spin.
            readable, _, _ = select.select(sockets, [], [], 1)
            if not readable:
                continue
            # There is data to forward; now see which peers can take it.
            _, writeable, _ = select.select([], sockets, [], 1)
            if client in readable and remote in writeable:
                client_bytes = client.recv(self.tunnel_chunk_size)
                if not client_bytes:
                    break
                # NOTE(review): sendall() on a non-blocking socket may raise
                # if the peer's send buffer fills mid-write; not seen with
                # the small payloads used here -- confirm if that changes.
                remote.sendall(client_bytes)
            if remote in readable and client in writeable:
                remote_bytes = remote.recv(self.tunnel_chunk_size)
                if not remote_bytes:
                    break
                client.sendall(remote_bytes)

    def do_CONNECT(self):
        """Handle CONNECT: check auth, then tunnel to the requested host.

        Replies 401 and returns when ``validate_auth`` rejects the request.
        Otherwise replies 200, connects to the ``host:port`` given in the
        request path and relays traffic until either side closes.
        """
        if not self.validate_auth():
            self.send_response(401)
            self.end_headers()
            return

        self.send_response(200)
        self.end_headers()

        remote_host, remote_port = self.path.split(':')
        remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            remote_socket.connect((remote_host, int(remote_port)))
            self._tunnel(self.request, remote_socket)
        finally:
            # Close the upstream socket even if connecting or relaying fails.
            remote_socket.close()

    def validate_auth(self):
        """Return True if the request may use the proxy (always, by default)."""
        return True


class BackgroundTaskFailed(Exception):
    """Raised when a background thread is still running after its join timeout."""


@contextmanager
def background(target, args=(), timeout=60):
    """Run ``target(*args)`` in a daemon thread for the span of a with-block.

    The thread is started on entry and ``target`` is the value bound by
    ``as``. On exit, including exit by exception, the thread is joined for
    up to ``timeout`` seconds.

    :raises BackgroundTaskFailed: if the thread is still alive after the
        join; this replaces any exception raised inside the with-block.
    """
    worker = threading.Thread(target=target, args=args)
    # Daemon thread: a stuck task must not keep the interpreter alive.
    worker.daemon = True
    worker.start()
    try:
        yield target
    finally:
        worker.join(timeout=timeout)
        if worker.is_alive():
            raise BackgroundTaskFailed(
                'Background task did not exit in a timely manner.'
            )


def run_server(handler, port):
    """Serve exactly one request on ``port`` with ``handler``, then shut down.

    ``handler.event`` is set once the server is bound and activated, so a
    waiting caller knows it can connect. The server socket is closed on
    every exit path, including a failed bind or an error while serving.

    :param handler: request handler class; must expose an ``event``
        attribute with a ``set()`` method.
    :param port: TCP port to bind on all interfaces.
    """
    address = ('', port)
    httpd = socketserver.TCPServer(address, handler, bind_and_activate=False)
    try:
        # Must be set before server_bind() for the option to take effect.
        httpd.allow_reuse_address = True
        httpd.server_bind()
        httpd.server_activate()
        handler.event.set()
        httpd.handle_request()
    finally:
        httpd.server_close()
