########################################################################
# File name: test_channel_binding_methods.py
# This file is part of: aiosasl
#
# LICENSE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this program.  If not, see
# <http://www.gnu.org/licenses/>.
#
########################################################################

import unittest
import unittest.mock

import OpenSSL

from aiosasl.channel_binding import (
    ChannelBindingProvider,
    StdlibTLS,
    TLSUnique,
    TLSServerEndPoint,
    parse_openssl_digest,
)


RAW_EXAMPLE_CERTS = [(b'0\x82\x01\xf90\x82\x01b\x02\x01\x010\r\x06\t*\x86H\x86\xf7\r\x01\x01\x04\x05\x000C1\x0b0\t\x06\x03U\x04\x06\x13\x02FN1\x1c0\x1a\x06\x03U\x04\n\x0c\x13Example Association1\x160\x14\x06\x03U\x04\x03\x0c\raiosasl tests0"\x18\x0f20000101000000Z\x18\x0f21000101000000Z0C1\x0b0\t\x06\x03U\x04\x06\x13\x02FN1\x1c0\x1a\x06\x03U\x04\n\x0c\x13Example Association1\x160\x14\x06\x03U\x04\x03\x0c\raiosasl tests0\x81\x9f0\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00\x03\x81\x8d\x000\x81\x89\x02\x81\x81\x00\x9d\xce/J\x9b\x89Z{~g\xda\xd8\xbfCi%g:\x9f\xbb\xcb\xe6\x08\xd0\x9c\x0b\xec1\xda\xb8\xd1w\x14v\xc83\x8dI\xc7\xf1#qz\x97\xf7\xf27\xd2\x97\x91s\x1f/i\x06f\x1a\x1ejN\xfb\x13\xf1\xe1\x8a\'\xa6\xec\x03h\xdd+\xbe\xda\xde\xba\xfd\xe7\xc4h\xc6]5d\xe5k\x97\x15\xf1\xc4\xc0\xf1\xc8\xb7\xb4\xea\x0f\xbb\x15O\xc10 3\x8d\x9a \xa8Dg\xb1\x1c\xac\xa3\xe7pY\xdb\xb7\xb8Ze\x95^\x94\x1b\xfa\xed\x02\x03\x01\x00\x010\r\x06\t*\x86H\x86\xf7\r\x01\x01\x04\x05\x00\x03\x81\x81\x00l77\x1c\x95/\xb41G\xe8l\xe6\x9e\x8aL\xbcs#?\x8diZwB\x942\xa0\x82\x1e\xa5\x08\xe5%\xd8\x93<\xae\xe7\x1fR\x03\xf7\xcfw\xf4\xf7Q\x99B\x9c\xbf\x80\xac\xd5{\r\xa7\xf2\xb1\xfa\x88\xd2\x14\xf4\xf9q(\xfa4\x17\xa9\x07V\xb4\xe8G\xb4\x93\x8b\x8b\x0b\xc7\x00\xc5/\x80\x8c\x1d\x85Uv\xd5\xa9\xe6\x17\xb1m@\xb7\x01\xd6.l\xe6\xa9/\x90\xea$NB\xfa\xa5\xe3\xe4\xf4I\xfcQ\x85\x06\xc7\xd3g\xb8\xe9\xf5i', b'0B:AA:BA:DD:CA:F4:7C:2F:5A:2A:12:67:16:B3:1C:AF:65:1F:D6:B0:BF:A2:B7:B9:9A:CD:83:F4:AB:B5:CE:CD'), (b'0\x82\x01\xf90\x82\x01b\x02\x01\x010\r\x06\t*\x86H\x86\xf7\r\x01\x01\x05\x05\x000C1\x0b0\t\x06\x03U\x04\x06\x13\x02FN1\x1c0\x1a\x06\x03U\x04\n\x0c\x13Example Association1\x160\x14\x06\x03U\x04\x03\x0c\raiosasl tests0"\x18\x0f20000101000000Z\x18\x0f21000101000000Z0C1\x0b0\t\x06\x03U\x04\x06\x13\x02FN1\x1c0\x1a\x06\x03U\x04\n\x0c\x13Example Association1\x160\x14\x06\x03U\x04\x03\x0c\raiosasl 
tests0\x81\x9f0\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00\x03\x81\x8d\x000\x81\x89\x02\x81\x81\x00\x9d\xce/J\x9b\x89Z{~g\xda\xd8\xbfCi%g:\x9f\xbb\xcb\xe6\x08\xd0\x9c\x0b\xec1\xda\xb8\xd1w\x14v\xc83\x8dI\xc7\xf1#qz\x97\xf7\xf27\xd2\x97\x91s\x1f/i\x06f\x1a\x1ejN\xfb\x13\xf1\xe1\x8a\'\xa6\xec\x03h\xdd+\xbe\xda\xde\xba\xfd\xe7\xc4h\xc6]5d\xe5k\x97\x15\xf1\xc4\xc0\xf1\xc8\xb7\xb4\xea\x0f\xbb\x15O\xc10 3\x8d\x9a \xa8Dg\xb1\x1c\xac\xa3\xe7pY\xdb\xb7\xb8Ze\x95^\x94\x1b\xfa\xed\x02\x03\x01\x00\x010\r\x06\t*\x86H\x86\xf7\r\x01\x01\x05\x05\x00\x03\x81\x81\x00\x92$\xfe\xc8\xeb~#\xd5\x0f#\x1c\xd9\xfeV\xa4z\x84\xcc\x83\xdd\xfd\xa47\xb5\xf7\xe4*\x9fJ\xd5My\x13UFT+\x95\'v\xa8\xb6\x95{2\xe4t\xbd#F\x08\xbf\xf3\xbei\xaf\x83M\xbb\x06FO\x1b\x9ew\xc81\r\xa0\xc3Z\xb8\xb7\x16\tS\xde\xe8\xe4\xdd\xed\x04\xfb\xe1\x1a(\xf0i}\x18\x10Q\x82\x10\x8eH\xe9x:\xcd\xaec]\x9c\xd9\xa7&\xf1\xa3:`\xa8(\x86+\xf05\n!\x82\xac\xfd\xd7\xe72T\x15.', b'79:45:47:63:08:2F:E0:21:E4:31:A5:99:EE:FE:34:D5:9E:1A:A9:A0:54:00:D6:FE:E8:89:62:27:A1:8E:5F:A7'), (b'0\x82\x01\xf90\x82\x01b\x02\x01\x010\r\x06\t*\x86H\x86\xf7\r\x01\x01\x0b\x05\x000C1\x0b0\t\x06\x03U\x04\x06\x13\x02FN1\x1c0\x1a\x06\x03U\x04\n\x0c\x13Example Association1\x160\x14\x06\x03U\x04\x03\x0c\raiosasl tests0"\x18\x0f20000101000000Z\x18\x0f21000101000000Z0C1\x0b0\t\x06\x03U\x04\x06\x13\x02FN1\x1c0\x1a\x06\x03U\x04\n\x0c\x13Example Association1\x160\x14\x06\x03U\x04\x03\x0c\raiosasl tests0\x81\x9f0\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00\x03\x81\x8d\x000\x81\x89\x02\x81\x81\x00\x9d\xce/J\x9b\x89Z{~g\xda\xd8\xbfCi%g:\x9f\xbb\xcb\xe6\x08\xd0\x9c\x0b\xec1\xda\xb8\xd1w\x14v\xc83\x8dI\xc7\xf1#qz\x97\xf7\xf27\xd2\x97\x91s\x1f/i\x06f\x1a\x1ejN\xfb\x13\xf1\xe1\x8a\'\xa6\xec\x03h\xdd+\xbe\xda\xde\xba\xfd\xe7\xc4h\xc6]5d\xe5k\x97\x15\xf1\xc4\xc0\xf1\xc8\xb7\xb4\xea\x0f\xbb\x15O\xc10 3\x8d\x9a 
\xa8Dg\xb1\x1c\xac\xa3\xe7pY\xdb\xb7\xb8Ze\x95^\x94\x1b\xfa\xed\x02\x03\x01\x00\x010\r\x06\t*\x86H\x86\xf7\r\x01\x01\x0b\x05\x00\x03\x81\x81\x000\xbfiz\x97\x90\xab\x8f\xca1\x1eH\xeb\xd6\xbd?\x87o\x07\xd3\xddi\xe9C\xa9\x97\x84\x87Xp&&\'\x0c\xbci\xc3\xb4?o\xfdOy\x1d\xf2\x9e\xe36\x8d(\xc2A\x0c\x106$}\xb3X;\xa7\xc3h\xd38D\xeb\x95\x8f>6\x17\xa1e\x1a\xc9\xfa\xbd\r\x00\x8c0\xa5T{n5\xe5\x9dp\x80%\x9eb\xb4\xd6\xd3;\xc7\x8c\n\x19\x9b\xaf\xab =IhK\xfc%*\xdbMkM\x8f\x1d\x05\xd8\xa9\xbd\x17RW\xfc{', b'0A:19:C9:73:FE:9C:F2:B9:DF:5D:27:CC:5A:FD:04:5E:19:97:05:99:4A:EB:91:16:FC:3F:CC:87:1B:D5:6E:2E'), (b'0\x82\x01\xf90\x82\x01b\x02\x01\x010\r\x06\t*\x86H\x86\xf7\r\x01\x01\r\x05\x000C1\x0b0\t\x06\x03U\x04\x06\x13\x02FN1\x1c0\x1a\x06\x03U\x04\n\x0c\x13Example Association1\x160\x14\x06\x03U\x04\x03\x0c\raiosasl tests0"\x18\x0f20000101000000Z\x18\x0f21000101000000Z0C1\x0b0\t\x06\x03U\x04\x06\x13\x02FN1\x1c0\x1a\x06\x03U\x04\n\x0c\x13Example Association1\x160\x14\x06\x03U\x04\x03\x0c\raiosasl tests0\x81\x9f0\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00\x03\x81\x8d\x000\x81\x89\x02\x81\x81\x00\x9d\xce/J\x9b\x89Z{~g\xda\xd8\xbfCi%g:\x9f\xbb\xcb\xe6\x08\xd0\x9c\x0b\xec1\xda\xb8\xd1w\x14v\xc83\x8dI\xc7\xf1#qz\x97\xf7\xf27\xd2\x97\x91s\x1f/i\x06f\x1a\x1ejN\xfb\x13\xf1\xe1\x8a\'\xa6\xec\x03h\xdd+\xbe\xda\xde\xba\xfd\xe7\xc4h\xc6]5d\xe5k\x97\x15\xf1\xc4\xc0\xf1\xc8\xb7\xb4\xea\x0f\xbb\x15O\xc10 3\x8d\x9a \xa8Dg\xb1\x1c\xac\xa3\xe7pY\xdb\xb7\xb8Ze\x95^\x94\x1b\xfa\xed\x02\x03\x01\x00\x010\r\x06\t*\x86H\x86\xf7\r\x01\x01\r\x05\x00\x03\x81\x81\x00\x8cf\x94\xeet\xf4\x03\xe6Tj|\xee\x1dh?\xed\x9b4\xb4\xb1\xc0J\xa6\xe1\xb8$\xf9c)\xd0[5\xed\x8d\xa3 \x9f\xfb\xedm\x904\x9a(u\xbe\x0b\xa8\\)d\xb2\x8b\xd6\xf2^\x80\xa4Z\x10\x0b\x8aN\'f\xbb\x81\x95\xc2\x99v\x96\xb2\xb5_\xed\xcc\xc1\x9a\xc1\xa7\x85\x7f\xa3s\x17\xab\x98\x91o\xdd3\xd64{\x9c\xaft3t`\xa8\'\x9b\x7f\x13\x02.\xf2Vl=3\xd5#\x8e>L\x8c\xfe\xce\x1c\xb3\x7f\x04\xe8m\xad', 
b'B1:90:A4:38:8E:0E:63:F4:0E:3B:D3:21:8F:C5:33:2E:3B:96:F7:86:14:C1:E1:37:ED:AE:BE:B1:4C:96:37:7A:E7:13:83:4D:94:B6:E8:2C:D8:A6:E0:3C:DA:AA:79:07:2A:73:85:58:C9:D1:ED:CE:53:E0:A8:F7:B1:A0:DB:62')]  # NOQA


# Fixtures for TestTLSServerEndPoint: pairs of (parsed certificate, expected
# channel binding data). Each DER blob from RAW_EXAMPLE_CERTS is loaded as a
# pyOpenSSL certificate, and its colon-separated hex digest is converted with
# parse_openssl_digest. A comprehension keeps the loop variables out of the
# module namespace.
EXAMPLE_CERTS = [
    (
        OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, der),
        parse_openssl_digest(hex_digest),
    )
    for der, hex_digest in RAW_EXAMPLE_CERTS
]


class TestStdlibTLS(unittest.TestCase):
    """Tests for the StdlibTLS channel binding provider."""

    def test_is_channel_binding_type(self):
        self.assertTrue(issubclass(StdlibTLS, ChannelBindingProvider))

    def test_cb_name(self):
        # Pin the expected wire names as literals rather than deriving them
        # from the input, so an encoding mistake cannot cancel itself out.
        cases = (
            ("tls-unique", b"tls-unique"),
            ("tls-server-end-point", b"tls-server-end-point"),
        )
        for method, expected_name in cases:
            with self.subTest(method=method):
                mock_socket = unittest.mock.Mock()
                provider = StdlibTLS(mock_socket, method)
                self.assertEqual(provider.cb_name, expected_name)

    def test_extract_cb_data(self):
        for method in ("tls-unique", "tls-server-end-point"):
            with self.subTest(method=method):
                mock_socket = unittest.mock.Mock()
                provider = StdlibTLS(mock_socket, method)
                with unittest.mock.patch.object(
                        mock_socket,
                        "get_channel_binding") as get_channel_binding:
                    get_channel_binding.return_value = b"foobar"
                    cb_data = provider.extract_cb_data()

                # The provider must ask the socket exactly once for the
                # configured binding type and return the result unchanged.
                self.assertSequenceEqual(
                    get_channel_binding.mock_calls,
                    [
                        unittest.mock.call(method)
                    ]
                )

                self.assertEqual(cb_data, b"foobar")


class TestTLSUnique(unittest.TestCase):
    """Tests for the TLSUnique channel binding provider."""

    def test_is_channel_binding_type(self):
        # TLSUnique must be usable wherever a ChannelBindingProvider is
        # expected.
        self.assertTrue(issubclass(TLSUnique, ChannelBindingProvider))

    def test_cb_name(self):
        conn = unittest.mock.Mock()
        self.assertEqual(TLSUnique(conn).cb_name, b"tls-unique")

    def test_extract_cb_data(self):
        conn = unittest.mock.Mock()
        cb_provider = TLSUnique(conn)

        # Stub get_finished() only while the binding data is extracted.
        with unittest.mock.patch.object(
                conn, "get_finished",
                return_value=b"foobar") as finished_mock:
            result = cb_provider.extract_cb_data()

        # Exactly one argument-less call; its result is passed through as-is.
        self.assertEqual(finished_mock.mock_calls, [unittest.mock.call()])
        self.assertEqual(result, b"foobar")


class TestTLSServerEndPoint(unittest.TestCase):
    """Tests for the TLSServerEndPoint channel binding provider."""

    def test_is_channel_binding_type(self):
        self.assertTrue(issubclass(TLSServerEndPoint, ChannelBindingProvider))

    def test_cb_name(self):
        mock_conn = unittest.mock.Mock()
        provider = TLSServerEndPoint(mock_conn)
        self.assertEqual(provider.cb_name, b"tls-server-end-point")

    def test_extract_cb_data(self):
        # Guard against the loop below passing vacuously on an empty fixture.
        self.assertTrue(EXAMPLE_CERTS)

        mock_conn = unittest.mock.Mock()
        provider = TLSServerEndPoint(mock_conn)

        # The signature algorithm OIDs in RAW_EXAMPLE_CERTS show the fixtures
        # cover certificates signed with MD5, SHA-1, SHA-256 and SHA-512.
        # The expected digests are 32 bytes for the first three and 64 bytes
        # for the last. This matches RFC 5929, which hashes MD5/SHA-1-signed
        # certificates with SHA-256.
        for index, (cert, hash_) in enumerate(EXAMPLE_CERTS):
            with self.subTest(cert_index=index):
                with unittest.mock.patch.object(
                        mock_conn,
                        "get_peer_certificate") as get_peer_cert:
                    get_peer_cert.return_value = cert
                    cb_data = provider.extract_cb_data()

                # Exactly one argument-less call per extraction.
                self.assertSequenceEqual(
                    get_peer_cert.mock_calls,
                    [
                        unittest.mock.call()
                    ]
                )

                self.assertEqual(cb_data, hash_)
