import os
import re
import shutil
import tempfile
from binascii import hexlify, unhexlify

import msgpack

import pytest

from ..crypto import bytes_to_long, num_aes_blocks
from ..key import PlaintextKey, PassphraseKey, KeyfileKey
from ..key import UnsupportedManifestError, TAMRequiredError, TAMUnsupportedSuiteError, TAMInvalid
from ..helpers import Location, StableDict
from . import BaseTestCase


class KeyTestCase(BaseTestCase):
    """Tests for PlaintextKey, PassphraseKey and KeyfileKey.

    Every test runs with ``BORG_KEYS_DIR`` pointing at a fresh temporary
    directory. The environment variables written by ``setUp`` or by the
    tests themselves are put back to their previous state in ``tearDown``.
    """

    # Environment variables written by setUp() or by individual tests.
    _key_test_env_names = ('BORG_KEYS_DIR', 'BORG_PASSPHRASE')

    class MockArgs:
        """Stand-in for the ``args`` object handed to ``create()``; only ``location`` is set."""
        # mkstemp() returns an already-open descriptor together with the path.
        # Only the path is used, so close the descriptor right away instead of
        # leaking it for the lifetime of the test process.
        _fd, _path = tempfile.mkstemp()
        os.close(_fd)
        location = Location(_path)
        del _fd, _path

    # Key file content that test_keyfile2 writes into BORG_KEYS_DIR.
    keyfile2_key_file = """
        BORG_KEY 0000000000000000000000000000000000000000000000000000000000000000
        hqppdGVyYXRpb25zzgABhqCkaGFzaNoAIMyonNI+7Cjv0qHi0AOBM6bLGxACJhfgzVD2oq
        bIS9SFqWFsZ29yaXRobaZzaGEyNTakc2FsdNoAINNK5qqJc1JWSUjACwFEWGTdM7Nd0a5l
        1uBGPEb+9XM9p3ZlcnNpb24BpGRhdGHaANAYDT5yfPpU099oBJwMomsxouKyx/OG4QIXK2
        hQCG2L2L/9PUu4WIuKvGrsXoP7syemujNfcZws5jLp2UPva4PkQhQsrF1RYDEMLh2eF9Ol
        rwtkThq1tnh7KjWMG9Ijt7/aoQtq0zDYP/xaFF8XXSJxiyP5zjH5+spB6RL0oQHvbsliSh
        /cXJq7jrqmrJ1phd6dg4SHAM/i+hubadZoS6m25OQzYAW09wZD/phG8OVa698Z5ed3HTaT
        SmrtgJL3EoOKgUI9d6BLE4dJdBqntifo""".strip()

    # Encrypted data matching the key file above. The regex drops the
    # whitespace used for layout; it is a raw string because '\W' is not a
    # valid string escape.
    keyfile2_cdata = unhexlify(re.sub(r'\W', '', """
        0055f161493fcfc16276e8c31493c4641e1eb19a79d0326fad0291e5a9c98e5933
        00000000000003e8d21eaf9b86c297a8cd56432e1915bb
        """))
    # Id passed to decrypt() together with keyfile2_cdata.
    keyfile2_id = unhexlify('c3fbf14bc001ebcc3cd86e696c13482ed071740927cd7cbe1b01b4bfcee49314')

    def setUp(self):
        """Remember the current environment and point BORG_KEYS_DIR at a new temp dir."""
        self._env_before_test = {name: os.environ.get(name) for name in self._key_test_env_names}
        self.tmppath = tempfile.mkdtemp()
        os.environ['BORG_KEYS_DIR'] = self.tmppath

    def tearDown(self):
        """Remove the temporary keys directory and restore the environment."""
        try:
            shutil.rmtree(self.tmppath)
        finally:
            # Restore even if the directory removal failed, so later tests do
            # not inherit a passphrase or a keys dir that no longer exists.
            for name, previous in self._env_before_test.items():
                if previous is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = previous

    class MockRepository:
        """Minimal repository stand-in: a fixed location and an all-zero 32-byte id."""

        class _Location:
            orig = '/some/place'

            def canonical_path(self):
                return self.orig

        _location = _Location()
        id = bytes(32)

    def _encrypt_twice(self, key):
        """Encrypt the same plaintext twice with *key* and check the nonce handling.

        Asserts that the first ciphertext carries nonce 0 and the second
        nonce 1, that the two ciphertexts differ, and that both decrypt to
        the same value. Returns the first ciphertext.
        """
        manifest = key.encrypt(b'XXX')
        self.assert_equal(key.extract_nonce(manifest), 0)
        manifest2 = key.encrypt(b'XXX')
        self.assert_not_equal(manifest, manifest2)
        self.assert_equal(key.decrypt(None, manifest), key.decrypt(None, manifest2))
        self.assert_equal(key.extract_nonce(manifest2), 1)
        return manifest

    def test_plaintext(self):
        """PlaintextKey: check the id hash of b'foo' and an encrypt/decrypt round trip."""
        key = PlaintextKey.create(None, None)
        data = b'foo'
        self.assert_equal(hexlify(key.id_hash(data)), b'2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae')
        self.assert_equal(data, key.decrypt(key.id_hash(data), key.encrypt(data)))

    def test_keyfile(self):
        """KeyfileKey: create a key, encrypt with it, then re-detect it from a manifest."""
        os.environ['BORG_PASSPHRASE'] = 'test'
        key = KeyfileKey.create(self.MockRepository(), self.MockArgs())
        self.assert_equal(bytes_to_long(key.enc_cipher.iv, 8), 0)
        manifest = self._encrypt_twice(key)
        iv = key.extract_nonce(manifest)
        key2 = KeyfileKey.detect(self.MockRepository(), manifest)
        # The detected key's IV is expected to sit past the blocks used by the manifest.
        expected_iv = iv + num_aes_blocks(len(manifest) - KeyfileKey.PAYLOAD_OVERHEAD)
        self.assert_equal(bytes_to_long(key2.enc_cipher.iv, 8), expected_iv)
        # Key data sanity check: the three keys must be pairwise distinct
        # and the chunk seed must not be zero.
        self.assert_equal(len({key2.id_key, key2.enc_key, key2.enc_hmac_key}), 3)
        self.assert_not_equal(key2.chunk_seed, 0)
        data = b'foo'
        self.assert_equal(data, key2.decrypt(key.id_hash(data), key.encrypt(data)))

    def test_keyfile2(self):
        """KeyfileKey: detect a stored key file and decrypt the matching fixture data."""
        with open(os.path.join(os.environ['BORG_KEYS_DIR'], 'keyfile'), 'w') as fd:
            fd.write(self.keyfile2_key_file)
        os.environ['BORG_PASSPHRASE'] = 'passphrase'
        key = KeyfileKey.detect(self.MockRepository(), self.keyfile2_cdata)
        self.assert_equal(key.decrypt(self.keyfile2_id, self.keyfile2_cdata), b'payload')

    def test_passphrase(self):
        """PassphraseKey: check the expected key material for passphrase 'test' and re-detection."""
        os.environ['BORG_PASSPHRASE'] = 'test'
        key = PassphraseKey.create(self.MockRepository(), None)
        self.assert_equal(bytes_to_long(key.enc_cipher.iv, 8), 0)
        # Expected key material for passphrase 'test' and MockRepository.
        self.assert_equal(hexlify(key.id_key), b'793b0717f9d8fb01c751a487e9b827897ceea62409870600013fbc6b4d8d7ca6')
        self.assert_equal(hexlify(key.enc_hmac_key), b'b885a05d329a086627412a6142aaeb9f6c54ab7950f996dd65587251f6bc0901')
        self.assert_equal(hexlify(key.enc_key), b'2ff3654c6daf7381dbbe718d2b20b4f1ea1e34caa6cc65f6bb3ac376b93fed2a')
        self.assert_equal(key.chunk_seed, -775740477)
        manifest = self._encrypt_twice(key)
        iv = key.extract_nonce(manifest)
        key2 = PassphraseKey.detect(self.MockRepository(), manifest)
        expected_iv = iv + num_aes_blocks(len(manifest) - PassphraseKey.PAYLOAD_OVERHEAD)
        self.assert_equal(bytes_to_long(key2.enc_cipher.iv, 8), expected_iv)
        # The detected key must carry the same key material as the created one.
        self.assert_equal(key.id_key, key2.id_key)
        self.assert_equal(key.enc_hmac_key, key2.enc_hmac_key)
        self.assert_equal(key.enc_key, key2.enc_key)
        self.assert_equal(key.chunk_seed, key2.chunk_seed)
        data = b'foo'
        self.assert_equal(hexlify(key.id_hash(data)), b'818217cf07d37efad3860766dcdf1d21e401650fed2d76ed1d797d3aae925990')
        self.assert_equal(data, key2.decrypt(key2.id_hash(data), key.encrypt(data)))


class TestTAM:
    """Tests for the 'tam' entry handled by ``unpack_and_verify_manifest`` and
    produced by ``pack_and_authenticate_metadata``."""

    @pytest.fixture
    def key(self, monkeypatch, tmpdir):
        """Return a newly created KeyfileKey using the passphrase 'test'.

        BORG_KEYS_DIR is pointed at pytest's per-test temporary directory so
        that key creation does not depend on whatever value an earlier test
        left in the environment (presumably the key file is written there).
        Both variables are undone by monkeypatch after the test.
        """
        monkeypatch.setenv('BORG_KEYS_DIR', str(tmpdir))
        monkeypatch.setenv('BORG_PASSPHRASE', 'test')
        return KeyfileKey.create(KeyTestCase.MockRepository(), KeyTestCase.MockArgs())

    def test_unpack_future(self, key):
        """Blobs made of 0xc1 bytes are rejected, with different errors by length."""
        blob = b'\xc1\xc1\xc1\xc1foobar'
        with pytest.raises(UnsupportedManifestError):
            key.unpack_and_verify_manifest(blob)

        blob = b'\xc1\xc1\xc1'
        with pytest.raises(msgpack.UnpackException):
            key.unpack_and_verify_manifest(blob)

    def test_missing_when_required(self, key):
        """A manifest without a 'tam' entry is rejected by a newly created key."""
        blob = msgpack.packb({})
        with pytest.raises(TAMRequiredError):
            key.unpack_and_verify_manifest(blob)

    def test_missing(self, key):
        """With tam_required off, a manifest without 'tam' unpacks but is not verified."""
        blob = msgpack.packb({})
        key.tam_required = False
        unpacked, verified = key.unpack_and_verify_manifest(blob)
        assert unpacked == {}
        assert not verified

    def test_unknown_type_when_required(self, key):
        """An unknown 'tam' type is rejected by a newly created key."""
        blob = msgpack.packb({
            'tam': {
                'type': 'HMAC_VOLLBIT',
            },
        })
        with pytest.raises(TAMUnsupportedSuiteError):
            key.unpack_and_verify_manifest(blob)

    def test_unknown_type(self, key):
        """With tam_required off, an unknown 'tam' type is dropped and not verified."""
        blob = msgpack.packb({
            'tam': {
                'type': 'HMAC_VOLLBIT',
            },
        })
        key.tam_required = False
        unpacked, verified = key.unpack_and_verify_manifest(blob)
        assert unpacked == {}
        assert not verified

    @pytest.mark.parametrize('tam, exc', (
        ({}, TAMUnsupportedSuiteError),
        ({'type': b'\xff'}, TAMUnsupportedSuiteError),
        (None, TAMInvalid),
        (1234, TAMInvalid),
    ))
    def test_invalid(self, key, tam, exc):
        """Malformed 'tam' values raise the exception paired with them in the parameters."""
        blob = msgpack.packb({
            'tam': tam,
        })
        with pytest.raises(exc):
            key.unpack_and_verify_manifest(blob)

    @pytest.mark.parametrize('hmac, salt', (
        ({}, bytes(64)),
        (bytes(64), {}),
        (None, bytes(64)),
        (bytes(64), None),
    ))
    def test_wrong_types(self, key, hmac, salt):
        """A 'tam' whose hmac or salt is a dict or is absent raises TAMInvalid.

        A parameter value of ``None`` means the entry is left out of the
        'tam' dict altogether.
        """
        tam = {'type': 'HKDF_HMAC_SHA512'}
        if hmac is not None:
            tam['hmac'] = hmac
        if salt is not None:
            tam['salt'] = salt
        blob = msgpack.packb({'tam': tam})
        with pytest.raises(TAMInvalid):
            key.unpack_and_verify_manifest(blob)

    def test_round_trip(self, key):
        """Metadata packed by the key verifies and comes back without the 'tam' entry."""
        data = {'foo': 'bar'}
        blob = key.pack_and_authenticate_metadata(data)
        # 0x82 should be the msgpack header of a two-entry map.
        assert blob.startswith(b'\x82')

        unpacked = msgpack.unpackb(blob)
        assert unpacked[b'tam'][b'type'] == b'HKDF_HMAC_SHA512'

        unpacked, verified = key.unpack_and_verify_manifest(blob)
        assert verified
        assert unpacked[b'foo'] == b'bar'
        assert b'tam' not in unpacked

    @pytest.mark.parametrize('which', (b'hmac', b'salt'))
    def test_tampered(self, key, which):
        """Zeroing the second half of the hmac or the salt makes verification fail."""
        data = {'foo': 'bar'}
        blob = key.pack_and_authenticate_metadata(data)
        assert blob.startswith(b'\x82')

        unpacked = msgpack.unpackb(blob, object_hook=StableDict)
        assert len(unpacked[b'tam'][which]) == 64
        # Keep the first 32 bytes and replace the rest with zeros, so the
        # field keeps its length.
        unpacked[b'tam'][which] = unpacked[b'tam'][which][0:32] + bytes(32)
        assert len(unpacked[b'tam'][which]) == 64
        blob = msgpack.packb(unpacked)

        with pytest.raises(TAMInvalid):
            key.unpack_and_verify_manifest(blob)
