import asyncio
import logging
import os
import ssl
import unittest
import unittest.mock

from .client import *
from .exceptions import ConnectionClosed, InvalidHandshake
from .http import USER_AGENT, read_response
from .server import *


# Avoid displaying stack traces at the ERROR logging level: only records at
# CRITICAL and above pass the root logger configured here.
logging.basicConfig(level=logging.CRITICAL)

# PEM file expected next to this module. SSLClientServerTests loads it both as
# the server's certificate chain and as the client's trusted certificates, and
# is skipped entirely when the file does not exist.
testcert = os.path.join(os.path.dirname(__file__), 'testcert.pem')


@asyncio.coroutine
def handler(ws, path):
    """Server-side connection handler shared by the tests in this module.

    The request path selects what is sent back to the client:

    * ``/attributes``: ``repr`` of the ``(host, port, secure)`` tuple;
    * ``/headers``: ``str`` of the request headers, then of the response
      headers, as two separate messages;
    * ``/raw_headers``: ``repr`` of the raw request headers, then of the raw
      response headers, as two separate messages;
    * ``/subprotocol``: ``repr`` of the negotiated subprotocol;
    * anything else: one message is received and echoed back unchanged.
    """
    if path == '/attributes':
        attributes = (ws.host, ws.port, ws.secure)
        yield from ws.send(repr(attributes))
        return

    if path == '/headers':
        # Each attribute is read right before it is sent, request first.
        for attr_name in ('request_headers', 'response_headers'):
            yield from ws.send(str(getattr(ws, attr_name)))
        return

    if path == '/raw_headers':
        for attr_name in ('raw_request_headers', 'raw_response_headers'):
            yield from ws.send(repr(getattr(ws, attr_name)))
        return

    if path == '/subprotocol':
        yield from ws.send(repr(ws.subprotocol))
        return

    # Default behaviour: echo a single message.
    message = yield from ws.recv()
    yield from ws.send(message)


class ClientServerTests(unittest.TestCase):
    """End-to-end tests running a client against a server on localhost:8642.

    Each test drives a private event loop synchronously with
    ``run_until_complete``. The server always uses the module-level
    ``handler`` coroutine, whose reply depends on the request path.

    NOTE(review): tests call ``stop_client()`` / ``stop_server()`` on their
    last lines rather than in a cleanup hook, so a failing assertion leaves
    the server un-stopped when ``tearDown`` closes the loop; presumably the
    port then stays bound for the following tests -- confirm.
    """

    # Expected value of the ``secure`` attribute of both protocol objects;
    # compared in test_protocol_attributes.
    secure = False

    def setUp(self):
        """Create a fresh event loop and install it as the current one."""
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        """Close the event loop created in setUp."""
        self.loop.close()

    def run_loop_once(self):
        # Process callbacks scheduled with call_soon by appending a callback
        # to stop the event loop then running it until it hits that callback.
        self.loop.call_soon(self.loop.stop)
        self.loop.run_forever()

    def start_server(self, **kwds):
        """Start serving ``handler`` on localhost:8642; store as self.server.

        Keyword arguments are forwarded unchanged to ``serve``.
        """
        server = serve(handler, 'localhost', 8642, **kwds)
        self.server = self.loop.run_until_complete(server)

    def start_client(self, path='', **kwds):
        """Connect to ``ws://localhost:8642/<path>``; store as self.client.

        Keyword arguments are forwarded unchanged to ``connect``.
        """
        client = connect('ws://localhost:8642/' + path, **kwds)
        self.client = self.loop.run_until_complete(client)

    def stop_client(self):
        """Wait up to one second for the client's worker task to finish.

        This only waits; it does not request a close itself. The test fails
        with "Client failed to stop" if the task is still running after the
        timeout.
        """
        try:
            self.loop.run_until_complete(
                asyncio.wait_for(self.client.worker_task, timeout=1))
        except asyncio.TimeoutError:                # pragma: no cover
            self.fail("Client failed to stop")

    def stop_server(self):
        """Close the server and wait up to one second for it to shut down.

        The test fails with "Server failed to stop" on timeout.
        """
        self.server.close()
        try:
            self.loop.run_until_complete(
                asyncio.wait_for(self.server.wait_closed(), timeout=1))
        except asyncio.TimeoutError:                # pragma: no cover
            self.fail("Server failed to stop")

    def test_basic(self):
        # Default path: the handler echoes the single message it receives.
        self.start_server()
        self.start_client()
        self.loop.run_until_complete(self.client.send("Hello!"))
        reply = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(reply, "Hello!")
        self.stop_client()
        self.stop_server()

    def test_server_close_while_client_connected(self):
        # The server must be able to stop within the stop_server() timeout
        # even though the client never closed its side.
        self.start_server()
        self.start_client()
        self.stop_server()

    def test_explicit_event_loop(self):
        # Same as test_basic, but the loop is passed explicitly to both
        # serve() and connect() through the ``loop`` keyword argument.
        self.start_server(loop=self.loop)
        self.start_client(loop=self.loop)
        self.loop.run_until_complete(self.client.send("Hello!"))
        reply = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(reply, "Hello!")
        self.stop_client()
        self.stop_server()

    def test_protocol_attributes(self):
        # Both ends must expose the same (host, port, secure) triple. The
        # server side is observed through the repr() that the handler sends
        # for the '/attributes' path.
        self.start_server()
        self.start_client('attributes')
        expected_attrs = ('localhost', 8642, self.secure)
        client_attrs = (self.client.host, self.client.port, self.client.secure)
        self.assertEqual(client_attrs, expected_attrs)
        server_attrs = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_attrs, repr(expected_attrs))
        self.stop_client()
        self.stop_server()

    def test_protocol_headers(self):
        # The client and the server must see the same request and response
        # headers. The handler sends str(request_headers) and then
        # str(response_headers) for the '/headers' path, hence two recv().
        self.start_server()
        self.start_client('headers')
        client_req = self.client.request_headers
        client_resp = self.client.response_headers
        self.assertEqual(client_req['User-Agent'], USER_AGENT)
        self.assertEqual(client_resp['Server'], USER_AGENT)
        server_req = self.loop.run_until_complete(self.client.recv())
        server_resp = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_req, str(client_req))
        self.assertEqual(server_resp, str(client_resp))
        self.stop_client()
        self.stop_server()

    def test_protocol_raw_headers(self):
        # Same check as test_protocol_headers for the raw_* attributes, which
        # are compared through repr() and converted with dict() for lookup.
        self.start_server()
        self.start_client('raw_headers')
        client_req = self.client.raw_request_headers
        client_resp = self.client.raw_response_headers
        self.assertEqual(dict(client_req)['User-Agent'], USER_AGENT)
        self.assertEqual(dict(client_resp)['Server'], USER_AGENT)
        server_req = self.loop.run_until_complete(self.client.recv())
        server_resp = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_req, repr(client_req))
        self.assertEqual(server_resp, repr(client_resp))
        self.stop_client()
        self.stop_server()

    def test_protocol_custom_request_headers_dict(self):
        # Client-side extra_headers given as a dict must show up in the raw
        # request headers seen by the server (first message from the handler).
        self.start_server()
        self.start_client('raw_headers', extra_headers={'X-Spam': 'Eggs'})
        req_headers = self.loop.run_until_complete(self.client.recv())
        # Consume the second message (raw response headers); it is not checked.
        self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", req_headers)
        self.stop_client()
        self.stop_server()

    def test_protocol_custom_request_headers_list(self):
        # Same as above with extra_headers given as a list of pairs.
        self.start_server()
        self.start_client('raw_headers', extra_headers=[('X-Spam', 'Eggs')])
        req_headers = self.loop.run_until_complete(self.client.recv())
        # Consume the second message (raw response headers); it is not checked.
        self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", req_headers)
        self.stop_client()
        self.stop_server()

    def test_protocol_custom_response_headers_callable_dict(self):
        # Server-side extra_headers given as a two-argument callable returning
        # a dict must show up in the raw response headers (second message).
        self.start_server(extra_headers=lambda p, r: {'X-Spam': 'Eggs'})
        self.start_client('raw_headers')
        # Discard the first message (raw request headers).
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)
        self.stop_client()
        self.stop_server()

    def test_protocol_custom_response_headers_callable_list(self):
        # Same as above with the callable returning a list of pairs.
        self.start_server(extra_headers=lambda p, r: [('X-Spam', 'Eggs')])
        self.start_client('raw_headers')
        # Discard the first message (raw request headers).
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)
        self.stop_client()
        self.stop_server()

    def test_protocol_custom_response_headers_dict(self):
        # Server-side extra_headers given directly as a dict.
        self.start_server(extra_headers={'X-Spam': 'Eggs'})
        self.start_client('raw_headers')
        # Discard the first message (raw request headers).
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)
        self.stop_client()
        self.stop_server()

    def test_protocol_custom_response_headers_list(self):
        # Server-side extra_headers given directly as a list of pairs.
        self.start_server(extra_headers=[('X-Spam', 'Eggs')])
        self.start_client('raw_headers')
        # Discard the first message (raw request headers).
        self.loop.run_until_complete(self.client.recv())
        resp_headers = self.loop.run_until_complete(self.client.recv())
        self.assertIn("('X-Spam', 'Eggs')", resp_headers)
        self.stop_client()
        self.stop_server()

    def test_no_subprotocol(self):
        # Neither side configures subprotocols: both must report None. The
        # server side is observed through the handler's '/subprotocol' reply.
        self.start_server()
        self.start_client('subprotocol')
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr(None))
        self.assertEqual(self.client.subprotocol, None)
        self.stop_client()
        self.stop_server()

    def test_subprotocol_found(self):
        # 'chat' is the only subprotocol present in both lists, so both sides
        # must end up with it.
        self.start_server(subprotocols=['superchat', 'chat'])
        self.start_client('subprotocol', subprotocols=['otherchat', 'chat'])
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr('chat'))
        self.assertEqual(self.client.subprotocol, 'chat')
        self.stop_client()
        self.stop_server()

    def test_subprotocol_not_found(self):
        # The two lists have nothing in common: the connection is still
        # established, with no subprotocol on either side.
        self.start_server(subprotocols=['superchat'])
        self.start_client('subprotocol', subprotocols=['otherchat'])
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr(None))
        self.assertEqual(self.client.subprotocol, None)
        self.stop_client()
        self.stop_server()

    def test_subprotocol_not_offered(self):
        # Only the client lists subprotocols: none is selected.
        self.start_server()
        self.start_client('subprotocol', subprotocols=['otherchat', 'chat'])
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr(None))
        self.assertEqual(self.client.subprotocol, None)
        self.stop_client()
        self.stop_server()

    def test_subprotocol_not_requested(self):
        # Only the server lists subprotocols: none is selected.
        self.start_server(subprotocols=['superchat', 'chat'])
        self.start_client('subprotocol')
        server_subprotocol = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(server_subprotocol, repr(None))
        self.assertEqual(self.client.subprotocol, None)
        self.stop_client()
        self.stop_server()

    @unittest.mock.patch.object(
        WebSocketServerProtocol, 'select_subprotocol', autospec=True)
    def test_subprotocol_error(self, _select_subprotocol):
        # Force the server to pick 'superchat', which the client (offering
        # only 'otherchat') did not ask for: the client must reject the
        # handshake.
        _select_subprotocol.return_value = 'superchat'

        self.start_server(subprotocols=['superchat'])
        with self.assertRaises(InvalidHandshake):
            self.start_client('subprotocol', subprotocols=['otherchat'])
        # Let callbacks already scheduled on the loop run before stopping.
        self.run_loop_once()
        self.stop_server()

    @unittest.mock.patch('websockets.server.read_request')
    def test_server_receives_malformed_request(self, _read_request):
        # The server fails to read the request (ValueError): the client must
        # see the failure as InvalidHandshake.
        _read_request.side_effect = ValueError("read_request failed")

        self.start_server()
        with self.assertRaises(InvalidHandshake):
            self.start_client()
        self.stop_server()

    @unittest.mock.patch('websockets.client.read_response')
    def test_client_receives_malformed_response(self, _read_response):
        # The client fails to read the response (ValueError): connect() must
        # surface it as InvalidHandshake.
        _read_response.side_effect = ValueError("read_response failed")

        self.start_server()
        with self.assertRaises(InvalidHandshake):
            self.start_client()
        # Let callbacks already scheduled on the loop run before stopping.
        self.run_loop_once()
        self.stop_server()

    @unittest.mock.patch('websockets.client.build_request')
    def test_client_sends_invalid_handshake_request(self, _build_request):
        # Replace the client's request builder with one that sets no header
        # at all and returns a fixed value; the handshake must fail.
        def wrong_build_request(set_header):
            return '42'
        _build_request.side_effect = wrong_build_request

        self.start_server()
        with self.assertRaises(InvalidHandshake):
            self.start_client()
        self.stop_server()

    @unittest.mock.patch('websockets.server.build_response')
    def test_server_sends_invalid_handshake_response(self, _build_response):
        # Make the server build its response from the key '42' instead of the
        # key supplied by the caller; the client must reject the handshake.
        # NOTE(review): ``build_response`` is not imported explicitly in this
        # module; it only resolves here if one of the wildcard imports
        # exports it. Otherwise this helper raises NameError inside the
        # server and the test may pass for the wrong reason -- confirm.
        def wrong_build_response(set_header, key):
            return build_response(set_header, '42')
        _build_response.side_effect = wrong_build_response

        self.start_server()
        with self.assertRaises(InvalidHandshake):
            self.start_client()
        self.stop_server()

    @unittest.mock.patch('websockets.client.read_response')
    def test_server_does_not_switch_protocols(self, _read_response):
        # Read the genuine response, then substitute status code 400 before
        # the client sees it; the client must treat that as a failed
        # handshake.
        @asyncio.coroutine
        def wrong_read_response(stream):
            code, headers = yield from read_response(stream)
            return 400, headers
        _read_response.side_effect = wrong_read_response

        self.start_server()
        with self.assertRaises(InvalidHandshake):
            self.start_client()
        # Let callbacks already scheduled on the loop run before stopping.
        self.run_loop_once()
        self.stop_server()

    @unittest.mock.patch('websockets.server.WebSocketServerProtocol.send')
    def test_server_handler_crashes(self, send):
        # The handler's echo calls send(), which is patched to raise, so the
        # handler crashes after receiving the client's message.
        send.side_effect = ValueError("send failed")

        self.start_server()
        self.start_client()
        self.loop.run_until_complete(self.client.send("Hello!"))
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.client.recv())
        self.stop_client()
        self.stop_server()

        # Connection ends with an unexpected error.
        self.assertEqual(self.client.close_code, 1011)

    @unittest.mock.patch('websockets.server.WebSocketServerProtocol.close')
    def test_server_close_crashes(self, close):
        # The echo itself succeeds; only the server protocol's close() is
        # patched to raise.
        close.side_effect = ValueError("close failed")

        self.start_server()
        self.start_client()
        self.loop.run_until_complete(self.client.send("Hello!"))
        reply = self.loop.run_until_complete(self.client.recv())
        self.assertEqual(reply, "Hello!")
        self.stop_client()
        self.stop_server()

        # Connection ends with an abnormal closure.
        self.assertEqual(self.client.close_code, 1006)

    @unittest.mock.patch.object(WebSocketClientProtocol, 'handshake')
    def test_client_closes_connection_before_handshake(self, handshake):
        self.start_server()
        self.start_client()
        # We have mocked the handshake() method to prevent the client from
        # performing the opening handshake. Force it to close the connection.
        self.loop.run_until_complete(self.client.close_connection(force=True))
        self.stop_client()
        # The server should stop properly anyway. It used to hang because the
        # worker handling the connection was waiting for the opening handshake.
        self.stop_server()

    @unittest.mock.patch('websockets.server.read_request')
    def test_server_shuts_down_during_opening_handshake(self, _read_request):
        # Simulate a shutdown racing with a new connection: reading the
        # request is cancelled while the server is flagged as closing.
        _read_request.side_effect = asyncio.CancelledError

        self.start_server()
        self.server.closing = True
        with self.assertRaises(InvalidHandshake) as raised:
            self.start_client()
        self.stop_server()

        # Opening handshake fails with 503 Service Unavailable
        self.assertEqual(str(raised.exception), "Bad status code: 503")

    def test_server_shuts_down_during_connection_handling(self):
        self.start_server()
        self.start_client()

        # Close the server while the handler is still waiting for the
        # client's message; the pending client recv() must then fail.
        self.server.close()
        with self.assertRaises(ConnectionClosed):
            self.loop.run_until_complete(self.client.recv())
        self.stop_client()
        # stop_server() calls close() a second time before waiting.
        self.stop_server()

        # Websocket connection terminates with 1001 Going Away.
        self.assertEqual(self.client.close_code, 1001)


@unittest.skipUnless(os.path.exists(testcert), "test certificate is missing")
class SSLClientServerTests(ClientServerTests):
    """Re-run every ClientServerTests case over TLS (``wss://``).

    Both ends use ``testcert``: the server presents it as its certificate
    chain and the client trusts it for verification. The whole class is
    skipped when the certificate file is missing.
    """

    secure = True

    @property
    def server_context(self):
        """Return a new server-side SSLContext presenting ``testcert``."""
        # PROTOCOL_SSLv23 negotiates the highest protocol version supported
        # by both peers. Pinning PROTOCOL_TLSv1 forced TLS 1.0, which is
        # deprecated and refused by many current OpenSSL configurations.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        ssl_context.load_cert_chain(testcert)
        return ssl_context

    @property
    def client_context(self):
        """Return a new client-side SSLContext that requires a certificate
        verifiable against ``testcert``."""
        # Same protocol selection as the server context so the two ends can
        # always agree on a version.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        ssl_context.load_verify_locations(testcert)
        ssl_context.verify_mode = ssl.CERT_REQUIRED
        return ssl_context

    def start_server(self, *args, **kwds):
        """Start the TLS server on localhost:8642; store it as self.server.

        Any ``ssl`` keyword argument is replaced by ``server_context``.
        Positional arguments are accepted for signature compatibility but
        are not used.
        """
        kwds['ssl'] = self.server_context
        server = serve(handler, 'localhost', 8642, **kwds)
        self.server = self.loop.run_until_complete(server)

    def start_client(self, path='', **kwds):
        """Connect to ``wss://localhost:8642/<path>``; store as self.client.

        Any ``ssl`` keyword argument is replaced by ``client_context``.
        """
        kwds['ssl'] = self.client_context
        client = connect('wss://localhost:8642/' + path, **kwds)
        self.client = self.loop.run_until_complete(client)

    def test_ws_uri_is_rejected(self):
        # Passing an SSL context together with a plain ws:// URI is
        # inconsistent and must be refused with ValueError.
        self.start_server()
        client = connect('ws://localhost:8642/', ssl=self.client_context)
        with self.assertRaises(ValueError):
            self.loop.run_until_complete(client)
        self.stop_server()


class ClientServerOriginTests(unittest.TestCase):
    """Check the server's ``origins`` allow-list against the client origin."""

    def setUp(self):
        # Every test gets its own event loop, installed as the current one.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        self.loop.close()

    def test_checking_origin_succeeds(self):
        # The client announces an origin that is on the server's list.
        run = self.loop.run_until_complete
        starting = serve(handler, 'localhost', 8642,
                         origins=['http://localhost'])
        server = run(starting)
        connecting = connect('ws://localhost:8642/',
                             origin='http://localhost')
        client = run(connecting)

        # The handler echoes on the default path.
        run(client.send("Hello!"))
        reply = run(client.recv())
        self.assertEqual(reply, "Hello!")

        run(client.close())
        server.close()
        run(server.wait_closed())

    def test_checking_origin_fails(self):
        # The client announces an origin that is not on the server's list;
        # the handshake must be refused with a 403.
        run = self.loop.run_until_complete
        starting = serve(handler, 'localhost', 8642,
                         origins=['http://localhost'])
        server = run(starting)
        with self.assertRaisesRegex(InvalidHandshake, "Bad status code: 403"):
            connecting = connect('ws://localhost:8642/',
                                 origin='http://otherhost')
            run(connecting)

        server.close()
        run(server.wait_closed())

    def test_checking_lack_of_origin_succeeds(self):
        # An empty string in ``origins`` admits clients sending no origin.
        run = self.loop.run_until_complete
        starting = serve(handler, 'localhost', 8642, origins=[''])
        server = run(starting)
        connecting = connect('ws://localhost:8642/')
        client = run(connecting)

        run(client.send("Hello!"))
        reply = run(client.recv())
        self.assertEqual(reply, "Hello!")

        run(client.close())
        server.close()
        run(server.wait_closed())


# Optional tests: ClientServerContextManager lives in the ``py35`` subpackage.
# The import is allowed to fail with SyntaxError or ImportError (presumably
# because that module relies on syntax only newer interpreters accept --
# confirm), in which case the tests are silently left out.
try:
    from .py35.client_server import ClientServerContextManager
except (SyntaxError, ImportError):                          # pragma: no cover
    pass
else:
    # Combine the imported class with TestCase so that unittest discovers
    # and runs the test methods it provides.
    class ClientServerContextManagerTests(ClientServerContextManager,
                                          unittest.TestCase):
        pass
