import json
import os
import unittest
from cryptography.hazmat.primitives.keywrap import InvalidUnwrap
from authlib.common.encoding import urlsafe_b64encode, json_b64encode, to_bytes, to_unicode
from authlib.jose import JsonWebEncryption
from authlib.jose import OctKey, OKPKey
from authlib.jose import errors
from authlib.jose.drafts import register_jwe_draft
from authlib.jose.errors import InvalidAlgorithmForMultipleRecipientsMode, DecodeError, InvalidHeaderParameterNameError
from authlib.jose.util import extract_header
from tests.util import read_file_path

register_jwe_draft(JsonWebEncryption)


class JWETest(unittest.TestCase):
    def test_not_enough_segments(self):
        s = 'a.b.c'
        jwe = JsonWebEncryption()
        self.assertRaises(
            errors.DecodeError,
            jwe.deserialize_compact,
            s, None
        )

    def test_invalid_header(self):
        jwe = JsonWebEncryption()
        public_key = read_file_path('rsa_public.pem')
        self.assertRaises(
            errors.MissingAlgorithmError,
            jwe.serialize_compact, {}, 'a', public_key
        )
        self.assertRaises(
            errors.UnsupportedAlgorithmError,
            jwe.serialize_compact, {'alg': 'invalid'}, 'a', public_key
        )
        self.assertRaises(
            errors.MissingEncryptionAlgorithmError,
            jwe.serialize_compact, {'alg': 'RSA-OAEP'}, 'a', public_key
        )
        self.assertRaises(
            errors.UnsupportedEncryptionAlgorithmError,
            jwe.serialize_compact, {'alg': 'RSA-OAEP', 'enc': 'invalid'},
            'a', public_key
        )
        self.assertRaises(
            errors.UnsupportedCompressionAlgorithmError,
            jwe.serialize_compact,
            {'alg': 'RSA-OAEP', 'enc': 'A256GCM', 'zip': 'invalid'},
            'a', public_key
        )

    def test_not_supported_alg(self):
        public_key = read_file_path('rsa_public.pem')
        private_key = read_file_path('rsa_private.pem')

        jwe = JsonWebEncryption()
        s = jwe.serialize_compact(
            {'alg': 'RSA-OAEP', 'enc': 'A256GCM'},
            'hello', public_key
        )

        jwe = JsonWebEncryption(algorithms=['RSA1_5', 'A256GCM'])
        self.assertRaises(
            errors.UnsupportedAlgorithmError,
            jwe.serialize_compact,
            {'alg': 'RSA-OAEP', 'enc': 'A256GCM'},
            'hello', public_key
        )
        self.assertRaises(
            errors.UnsupportedCompressionAlgorithmError,
            jwe.serialize_compact,
            {'alg': 'RSA1_5', 'enc': 'A256GCM', 'zip': 'DEF'},
            'hello', public_key
        )
        self.assertRaises(
            errors.UnsupportedAlgorithmError,
            jwe.deserialize_compact,
            s, private_key,
        )

        jwe = JsonWebEncryption(algorithms=['RSA-OAEP', 'A192GCM'])
        self.assertRaises(
            errors.UnsupportedEncryptionAlgorithmError,
            jwe.serialize_compact,
            {'alg': 'RSA-OAEP', 'enc': 'A256GCM'},
            'hello', public_key
        )
        self.assertRaises(
            errors.UnsupportedCompressionAlgorithmError,
            jwe.serialize_compact,
            {'alg': 'RSA-OAEP', 'enc': 'A192GCM', 'zip': 'DEF'},
            'hello', public_key
        )
        self.assertRaises(
            errors.UnsupportedEncryptionAlgorithmError,
            jwe.deserialize_compact,
            s, private_key,
        )

    def test_inappropriate_sender_key_for_serialize_compact(self):
        jwe = JsonWebEncryption()
        alice_key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis",
            "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE",
            "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg"
        }
        bob_key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ",
            "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck",
            "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw"
        }

        protected = {'alg': 'ECDH-1PU', 'enc': 'A256GCM'}
        self.assertRaises(
            ValueError,
            jwe.serialize_compact,
            protected, b'hello', bob_key
        )

        protected = {'alg': 'ECDH-ES', 'enc': 'A256GCM'}
        self.assertRaises(
            ValueError,
            jwe.serialize_compact,
            protected, b'hello', bob_key, sender_key=alice_key
        )

    def test_inappropriate_sender_key_for_deserialize_compact(self):
        jwe = JsonWebEncryption()
        alice_key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis",
            "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE",
            "d": "Hndv7ZZjs_ke8o9zXYo3iq-Yr8SewI5vrqd0pAvEPqg"
        }
        bob_key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ",
            "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck",
            "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw"
        }

        protected = {'alg': 'ECDH-1PU', 'enc': 'A256GCM'}
        data = jwe.serialize_compact(protected, b'hello', bob_key, sender_key=alice_key)
        self.assertRaises(
            ValueError,
            jwe.deserialize_compact,
            data, bob_key
        )

        protected = {'alg': 'ECDH-ES', 'enc': 'A256GCM'}
        data = jwe.serialize_compact(protected, b'hello', bob_key)
        self.assertRaises(
            ValueError,
            jwe.deserialize_compact,
            data, bob_key, sender_key=alice_key
        )

    def test_compact_rsa(self):
        jwe = JsonWebEncryption()
        s = jwe.serialize_compact(
            {'alg': 'RSA-OAEP', 'enc': 'A256GCM'},
            'hello',
            read_file_path('rsa_public.pem')
        )
        data = jwe.deserialize_compact(s, read_file_path('rsa_private.pem'))
        header, payload = data['header'], data['payload']
        self.assertEqual(payload, b'hello')
        self.assertEqual(header['alg'], 'RSA-OAEP')

    def test_with_zip_header(self):
        jwe = JsonWebEncryption()
        s = jwe.serialize_compact(
            {'alg': 'RSA-OAEP', 'enc': 'A128CBC-HS256', 'zip': 'DEF'},
            'hello',
            read_file_path('rsa_public.pem')
        )
        data = jwe.deserialize_compact(s, read_file_path('rsa_private.pem'))
        header, payload = data['header'], data['payload']
        self.assertEqual(payload, b'hello')
        self.assertEqual(header['alg'], 'RSA-OAEP')

    def test_aes_jwe(self):
        jwe = JsonWebEncryption()
        sizes = [128, 192, 256]
        _enc_choices = [
            'A128CBC-HS256', 'A192CBC-HS384', 'A256CBC-HS512',
            'A128GCM', 'A192GCM', 'A256GCM'
        ]
        for s in sizes:
            alg = 'A{}KW'.format(s)
            key = os.urandom(s // 8)
            for enc in _enc_choices:
                protected = {'alg': alg, 'enc': enc}
                data = jwe.serialize_compact(protected, b'hello', key)
                rv = jwe.deserialize_compact(data, key)
                self.assertEqual(rv['payload'], b'hello')

    def test_aes_jwe_invalid_key(self):
        jwe = JsonWebEncryption()
        protected = {'alg': 'A128KW', 'enc': 'A128GCM'}
        self.assertRaises(
            ValueError,
            jwe.serialize_compact,
            protected, b'hello', b'invalid-key'
        )

    def test_aes_gcm_jwe(self):
        jwe = JsonWebEncryption()
        sizes = [128, 192, 256]
        _enc_choices = [
            'A128CBC-HS256', 'A192CBC-HS384', 'A256CBC-HS512',
            'A128GCM', 'A192GCM', 'A256GCM'
        ]
        for s in sizes:
            alg = 'A{}GCMKW'.format(s)
            key = os.urandom(s // 8)
            for enc in _enc_choices:
                protected = {'alg': alg, 'enc': enc}
                data = jwe.serialize_compact(protected, b'hello', key)
                rv = jwe.deserialize_compact(data, key)
                self.assertEqual(rv['payload'], b'hello')

    def test_aes_gcm_jwe_invalid_key(self):
        jwe = JsonWebEncryption()
        protected = {'alg': 'A128GCMKW', 'enc': 'A128GCM'}
        self.assertRaises(
            ValueError,
            jwe.serialize_compact,
            protected, b'hello', b'invalid-key'
        )

    def test_serialize_compact_fails_if_header_contains_unknown_field_while_private_fields_restricted(self):
        jwe = JsonWebEncryption(private_headers=set())
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM",
            "foo": "bar"
        }

        self.assertRaises(
            InvalidHeaderParameterNameError,
            jwe.serialize_compact,
            protected, b'hello', key
        )

    def test_serialize_compact_allows_unknown_fields_in_header_while_private_fields_not_restricted(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM",
            "foo": "bar"
        }

        data = jwe.serialize_compact(protected, b'hello', key)
        rv = jwe.deserialize_compact(data, key)
        self.assertEqual(rv['payload'], b'hello')

    def test_serialize_json_fails_if_protected_header_contains_unknown_field_while_private_fields_restricted(self):
        jwe = JsonWebEncryption(private_headers=set())
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM",
            "foo": "bar"
        }
        header_obj = {
            "protected": protected
        }

        self.assertRaises(
            InvalidHeaderParameterNameError,
            jwe.serialize_json,
            header_obj, b'hello', key
        )

    def test_serialize_json_fails_if_unprotected_header_contains_unknown_field_while_private_fields_restricted(self):
        jwe = JsonWebEncryption(private_headers=set())
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        unprotected = {
            "foo": "bar"
        }
        header_obj = {
            "protected": protected,
            "unprotected": unprotected
        }

        self.assertRaises(
            InvalidHeaderParameterNameError,
            jwe.serialize_json,
            header_obj, b'hello', key
        )

    def test_serialize_json_fails_if_recipient_header_contains_unknown_field_while_private_fields_restricted(self):
        jwe = JsonWebEncryption(private_headers=set())
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        recipients = [
            {
                "header": {
                    "foo": "bar"
                }
            }
        ]
        header_obj = {
            "protected": protected,
            "recipients": recipients
        }

        self.assertRaises(
            InvalidHeaderParameterNameError,
            jwe.serialize_json,
            header_obj, b'hello', key
        )

    def test_serialize_json_allows_unknown_fields_in_headers_while_private_fields_not_restricted(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM",
            "foo1": "bar1"
        }
        unprotected = {
            "foo2": "bar2"
        }
        recipients = [
            {
                "header": {
                    "foo3": "bar3"
                }
            }
        ]
        header_obj = {
            "protected": protected,
            "unprotected": unprotected,
            "recipients": recipients
        }

        data = jwe.serialize_json(header_obj, b'hello', key)
        rv = jwe.deserialize_json(data, key)
        self.assertEqual(rv['payload'], b'hello')

    def test_serialize_json_ignores_additional_members_in_recipients_elements(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        recipients = [
            {
                "foo": "bar"
            }
        ]
        header_obj = {
            "protected": protected,
            "recipients": recipients
        }

        data = jwe.serialize_compact(protected, b'hello', key)
        rv = jwe.deserialize_compact(data, key)
        self.assertEqual(rv['payload'], b'hello')

    def test_deserialize_json_fails_if_protected_header_contains_unknown_field_while_private_fields_restricted(self):
        jwe = JsonWebEncryption(private_headers=set())
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        header_obj = {
            "protected": protected
        }

        data = jwe.serialize_json(header_obj, b'hello', key)

        decoded_protected = extract_header(to_bytes(data["protected"]), DecodeError)
        decoded_protected["foo"] = "bar"
        data["protected"] = to_unicode(json_b64encode(decoded_protected))

        self.assertRaises(
            InvalidHeaderParameterNameError,
            jwe.deserialize_json,
            data, key
        )

    def test_deserialize_json_fails_if_unprotected_header_contains_unknown_field_while_private_fields_restricted(self):
        jwe = JsonWebEncryption(private_headers=set())
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        header_obj = {
            "protected": protected
        }

        data = jwe.serialize_json(header_obj, b'hello', key)

        data["unprotected"] = {
            "foo": "bar"
        }

        self.assertRaises(
            InvalidHeaderParameterNameError,
            jwe.deserialize_json,
            data, key
        )

    def test_deserialize_json_fails_if_recipient_header_contains_unknown_field_while_private_fields_restricted(self):
        jwe = JsonWebEncryption(private_headers=set())
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        header_obj = {
            "protected": protected
        }

        data = jwe.serialize_json(header_obj, b'hello', key)

        data["recipients"][0]["header"] = {
            "foo": "bar"
        }

        self.assertRaises(
            InvalidHeaderParameterNameError,
            jwe.deserialize_json,
            data, key
        )

    def test_deserialize_json_allows_unknown_fields_in_headers_while_private_fields_not_restricted(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        header_obj = {
            "protected": protected
        }

        data = jwe.serialize_json(header_obj, b'hello', key)

        data["unprotected"] = {
            "foo1": "bar1"
        }
        data["recipients"][0]["header"] = {
            "foo2": "bar2"
        }

        rv = jwe.deserialize_json(data, key)
        self.assertEqual(rv['payload'], b'hello')

    def test_deserialize_json_ignores_additional_members_in_recipients_elements(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        header_obj = {
            "protected": protected
        }

        data = jwe.serialize_json(header_obj, b'hello', key)

        data["recipients"][0]["foo"] = "bar"

        data = jwe.serialize_compact(protected, b'hello', key)
        rv = jwe.deserialize_compact(data, key)
        self.assertEqual(rv['payload'], b'hello')

    def test_deserialize_json_ignores_additional_members_in_jwe_message(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {
            "alg": "ECDH-ES+A128KW",
            "enc": "A128GCM"
        }
        header_obj = {
            "protected": protected
        }

        data = jwe.serialize_json(header_obj, b'hello', key)

        data["foo"] = "bar"

        data = jwe.serialize_compact(protected, b'hello', key)
        rv = jwe.deserialize_compact(data, key)
        self.assertEqual(rv['payload'], b'hello')

    def test_ecdh_es_key_agreement_computation(self):
        # https://tools.ietf.org/html/rfc7518#appendix-C
        alice_ephemeral_key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0",
            "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps",
            "d": "0_NxaRPUMQoAJt50Gz8YiTr8gRTwyEaCumd-MToTmIo"
        }
        bob_static_key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ",
            "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck",
            "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw"
        }

        headers = {
            "alg": "ECDH-ES",
            "enc": "A128GCM",
            "apu": "QWxpY2U",
            "apv": "Qm9i",
            "epk": {
                "kty": "EC",
                "crv": "P-256",
                "x": "gI0GAILBdu7T53akrFmMyGcsF3n5dO7MmwNBHKW5SV0",
                "y": "SLW_xSffzlPWrHEVI30DHM_4egVwt3NQqeUD7nMFpps"
            }
        }

        alg = JsonWebEncryption.ALG_REGISTRY['ECDH-ES']
        enc = JsonWebEncryption.ENC_REGISTRY['A128GCM']

        alice_ephemeral_key = alg.prepare_key(alice_ephemeral_key)
        bob_static_key = alg.prepare_key(bob_static_key)

        alice_ephemeral_pubkey = alice_ephemeral_key.get_op_key('wrapKey')
        bob_static_pubkey = bob_static_key.get_op_key('wrapKey')

        # Derived key computation at Alice

        # Step-by-step methods verification
        _shared_key_at_alice = alice_ephemeral_key.exchange_shared_key(bob_static_pubkey)
        self.assertEqual(
            _shared_key_at_alice,
            bytes([158, 86, 217, 29, 129, 113, 53, 211, 114, 131, 66, 131, 191, 132, 38, 156,
                   251, 49, 110, 163, 218, 128, 106, 72, 246, 218, 167, 121, 140, 254, 144, 196])
        )

        _fixed_info_at_alice = alg.compute_fixed_info(headers, enc.key_size)
        self.assertEqual(
            _fixed_info_at_alice,
            bytes([0, 0, 0, 7, 65, 49, 50, 56, 71, 67, 77, 0, 0, 0, 5, 65,
                   108, 105, 99, 101, 0, 0, 0, 3, 66, 111, 98, 0, 0, 0, 128])
        )

        _dk_at_alice = alg.compute_derived_key(_shared_key_at_alice, _fixed_info_at_alice, enc.key_size)
        self.assertEqual(_dk_at_alice, bytes([86, 170, 141, 234, 248, 35, 109, 32, 92, 34, 40, 205, 113, 167, 16, 26]))
        self.assertEqual(urlsafe_b64encode(_dk_at_alice), b'VqqN6vgjbSBcIijNcacQGg')

        # All-in-one method verification
        dk_at_alice = alg.deliver(alice_ephemeral_key, bob_static_pubkey, headers, enc.key_size)
        self.assertEqual(dk_at_alice, bytes([86, 170, 141, 234, 248, 35, 109, 32, 92, 34, 40, 205, 113, 167, 16, 26]))
        self.assertEqual(urlsafe_b64encode(dk_at_alice), b'VqqN6vgjbSBcIijNcacQGg')

        # Derived key computation at Bob

        # Step-by-step methods verification
        _shared_key_at_bob = bob_static_key.exchange_shared_key(alice_ephemeral_pubkey)
        self.assertEqual(_shared_key_at_bob, _shared_key_at_alice)

        _fixed_info_at_bob = alg.compute_fixed_info(headers, enc.key_size)
        self.assertEqual(_fixed_info_at_bob, _fixed_info_at_alice)

        _dk_at_bob = alg.compute_derived_key(_shared_key_at_bob, _fixed_info_at_bob, enc.key_size)
        self.assertEqual(_dk_at_bob, _dk_at_alice)

        # All-in-one method verification
        dk_at_bob = alg.deliver(bob_static_key, alice_ephemeral_pubkey, headers, enc.key_size)
        self.assertEqual(dk_at_bob, dk_at_alice)

    def test_ecdh_es_jwe_in_direct_key_agreement_mode(self):
        jwe = JsonWebEncryption()
        key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ",
            "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck",
            "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw"
        }

        for enc in [
            'A128CBC-HS256',
            'A192CBC-HS384',
            'A256CBC-HS512',
            'A128GCM',
            'A192GCM',
            'A256GCM',
        ]:
            protected = {'alg': 'ECDH-ES', 'enc': enc}
            data = jwe.serialize_compact(protected, b'hello', key)
            rv = jwe.deserialize_compact(data, key)
            self.assertEqual(rv['payload'], b'hello')

    def test_ecdh_es_jwe_json_serialization_single_recipient_in_direct_key_agreement_mode(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        protected = {'alg': 'ECDH-ES', 'enc': 'A128GCM'}
        header_obj = {'protected': protected}
        data = jwe.serialize_json(header_obj, b'hello', key)
        rv = jwe.deserialize_json(data, key)
        self.assertEqual(rv['payload'], b'hello')

    def test_ecdh_es_jwe_in_key_agreement_with_key_wrapping_mode(self):
        jwe = JsonWebEncryption()
        key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ",
            "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck",
            "d": "VEmDZpDXXK8p8N0Cndsxs924q6nS1RXFASRl6BfUqdw"
        }

        for alg in [
            'ECDH-ES+A128KW',
            'ECDH-ES+A192KW',
            'ECDH-ES+A256KW',
        ]:
            for enc in [
                'A128CBC-HS256',
                'A192CBC-HS384',
                'A256CBC-HS512',
                'A128GCM',
                'A192GCM',
                'A256GCM',
            ]:
                protected = {'alg': alg, 'enc': enc}
                data = jwe.serialize_compact(protected, b'hello', key)
                rv = jwe.deserialize_compact(data, key)
                self.assertEqual(rv['payload'], b'hello')

    def test_ecdh_es_jwe_with_okp_key_in_direct_key_agreement_mode(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        for enc in [
            'A128CBC-HS256',
            'A192CBC-HS384',
            'A256CBC-HS512',
            'A128GCM',
            'A192GCM',
            'A256GCM',
        ]:
            protected = {'alg': 'ECDH-ES', 'enc': enc}
            data = jwe.serialize_compact(protected, b'hello', key)
            rv = jwe.deserialize_compact(data, key)
            self.assertEqual(rv['payload'], b'hello')

    def test_ecdh_es_jwe_with_okp_key_in_key_agreement_with_key_wrapping_mode(self):
        jwe = JsonWebEncryption()
        key = OKPKey.generate_key('X25519', is_private=True)

        for alg in [
            'ECDH-ES+A128KW',
            'ECDH-ES+A192KW',
            'ECDH-ES+A256KW',
        ]:
            for enc in [
                'A128CBC-HS256',
                'A192CBC-HS384',
                'A256CBC-HS512',
                'A128GCM',
                'A192GCM',
                'A256GCM',
            ]:
                protected = {'alg': alg, 'enc': enc}
                data = jwe.serialize_compact(protected, b'hello', key)
                rv = jwe.deserialize_compact(data, key)
                self.assertEqual(rv['payload'], b'hello')

    def test_ecdh_es_jwe_with_json_serialization_when_kid_is_not_specified(self):
        jwe = JsonWebEncryption()

        bob_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw",
            "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg"
        })
        charlie_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE",
            "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE"
        })

        protected = {
            "alg": "ECDH-ES+A256KW",
            "enc": "A256GCM",
            "apu": "QWxpY2U",
            "apv": "Qm9iIGFuZCBDaGFybGll"
        }

        unprotected = {
            "jku": "https://alice.example.com/keys.jwks"
        }

        recipients = [
            {
                "header": {
                    "kid": "bob-key-2"
                }
            },
            {
                "header": {
                    "kid": "2021-05-06"
                }
            }
        ]

        jwe_aad = b'Authenticate me too.'

        header_obj = {
            "protected": protected,
            "unprotected": unprotected,
            "recipients": recipients,
            "aad": jwe_aad
        }

        payload = b'Three is a magic number.'

        data = jwe.serialize_json(header_obj, payload, [bob_key, charlie_key])

        rv_at_bob = jwe.deserialize_json(data, bob_key)

        self.assertEqual(rv_at_bob['header']['protected'].keys(), protected.keys() | {'epk'})
        self.assertEqual(
            {k: rv_at_bob['header']['protected'][k] for k in rv_at_bob['header']['protected'].keys() - {'epk'}},
            protected
        )
        self.assertEqual(rv_at_bob['header']['unprotected'], unprotected)
        self.assertEqual(rv_at_bob['header']['recipients'], recipients)
        self.assertEqual(rv_at_bob['header']['aad'], jwe_aad)
        self.assertEqual(rv_at_bob['payload'], payload)

        rv_at_charlie = jwe.deserialize_json(data, charlie_key)

        self.assertEqual(rv_at_charlie['header']['protected'].keys(), protected.keys() | {'epk'})
        self.assertEqual(
            {k: rv_at_charlie['header']['protected'][k] for k in rv_at_charlie['header']['protected'].keys() - {'epk'}},
            protected
        )
        self.assertEqual(rv_at_charlie['header']['unprotected'], unprotected)
        self.assertEqual(rv_at_charlie['header']['recipients'], recipients)
        self.assertEqual(rv_at_charlie['header']['aad'], jwe_aad)
        self.assertEqual(rv_at_charlie['payload'], payload)

    def test_ecdh_es_jwe_with_json_serialization_when_kid_is_specified(self):
        jwe = JsonWebEncryption()

        bob_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "kid": "bob-key-2",
            "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw",
            "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg"
        })
        charlie_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "kid": "2021-05-06",
            "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE",
            "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE"
        })

        protected = {
            "alg": "ECDH-ES+A256KW",
            "enc": "A256GCM",
            "apu": "QWxpY2U",
            "apv": "Qm9iIGFuZCBDaGFybGll"
        }

        unprotected = {
            "jku": "https://alice.example.com/keys.jwks"
        }

        recipients = [
            {
                "header": {
                    "kid": "bob-key-2"
                }
            },
            {
                "header": {
                    "kid": "2021-05-06"
                }
            }
        ]

        jwe_aad = b'Authenticate me too.'

        header_obj = {
            "protected": protected,
            "unprotected": unprotected,
            "recipients": recipients,
            "aad": jwe_aad
        }

        payload = b'Three is a magic number.'

        data = jwe.serialize_json(header_obj, payload, [bob_key, charlie_key])

        rv_at_bob = jwe.deserialize_json(data, bob_key)

        self.assertEqual(rv_at_bob['header']['protected'].keys(), protected.keys() | {'epk'})
        self.assertEqual(
            {k: rv_at_bob['header']['protected'][k] for k in rv_at_bob['header']['protected'].keys() - {'epk'}},
            protected
        )
        self.assertEqual(rv_at_bob['header']['unprotected'], unprotected)
        self.assertEqual(rv_at_bob['header']['recipients'], recipients)
        self.assertEqual(rv_at_bob['header']['aad'], jwe_aad)
        self.assertEqual(rv_at_bob['payload'], payload)

        rv_at_charlie = jwe.deserialize_json(data, charlie_key)

        self.assertEqual(rv_at_charlie['header']['protected'].keys(), protected.keys() | {'epk'})
        self.assertEqual(
            {k: rv_at_charlie['header']['protected'][k] for k in rv_at_charlie['header']['protected'].keys() - {'epk'}},
            protected
        )
        self.assertEqual(rv_at_charlie['header']['unprotected'], unprotected)
        self.assertEqual(rv_at_charlie['header']['recipients'], recipients)
        self.assertEqual(rv_at_charlie['header']['aad'], jwe_aad)
        self.assertEqual(rv_at_charlie['payload'], payload)

    def test_ecdh_es_jwe_with_json_serialization_for_single_recipient(self):
        jwe = JsonWebEncryption()

        key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw",
            "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg"
        })

        protected = {
            "alg": "ECDH-ES+A256KW",
            "enc": "A256GCM",
            "apu": "QWxpY2U",
            "apv": "Qm9i"
        }

        unprotected = {
            "jku": "https://alice.example.com/keys.jwks"
        }

        recipients = [
            {
                "header": {
                    "kid": "bob-key-2"
                }
            }
        ]

        jwe_aad = b'Authenticate me too.'

        header_obj = {
            "protected": protected,
            "unprotected": unprotected,
            "recipients": recipients,
            "aad": jwe_aad
        }

        payload = b'Three is a magic number.'

        data = jwe.serialize_json(header_obj, payload, key)

        rv = jwe.deserialize_json(data, key)

        self.assertEqual(rv['header']['protected'].keys(), protected.keys() | {'epk'})
        self.assertEqual(
            {k: rv['header']['protected'][k] for k in rv['header']['protected'].keys() - {'epk'}},
            protected
        )
        self.assertEqual(rv['header']['unprotected'], unprotected)
        self.assertEqual(rv['header']['recipients'], recipients)
        self.assertEqual(rv['header']['aad'], jwe_aad)
        self.assertEqual(rv['payload'], payload)

    def test_ecdh_es_encryption_fails_json_serialization_multiple_recipients_in_direct_key_agreement_mode(self):
        jwe = JsonWebEncryption()
        bob_key = OKPKey.generate_key('X25519', is_private=True)
        charlie_key = OKPKey.generate_key('X25519', is_private=True)

        protected = {'alg': 'ECDH-ES', 'enc': 'A128GCM'}
        header_obj = {'protected': protected}
        self.assertRaises(
            InvalidAlgorithmForMultipleRecipientsMode,
            jwe.serialize_json,
            header_obj, b'hello', [bob_key, charlie_key]
        )

    def test_ecdh_es_decryption_with_public_key_fails(self):
        jwe = JsonWebEncryption()
        protected = {'alg': 'ECDH-ES', 'enc': 'A128GCM'}

        key = {
            "kty": "EC",
            "crv": "P-256",
            "x": "weNJy2HscCSM6AEDTDg04biOvhFhyyWvOHQfeF_PxMQ",
            "y": "e8lnCO-AlStT-NJVX-crhB7QRYhiix03illJOVAOyck"
        }
        data = jwe.serialize_compact(protected, b'hello', key)
        self.assertRaises(
            ValueError,
            jwe.deserialize_compact,
            data, key
        )

    def test_ecdh_es_encryption_fails_if_key_curve_is_inappropriate(self):
        jwe = JsonWebEncryption()
        protected = {'alg': 'ECDH-ES', 'enc': 'A128GCM'}

        key = OKPKey.generate_key('Ed25519', is_private=False)
        self.assertRaises(
            ValueError,
            jwe.serialize_compact,
            protected, b'hello', key
        )

    def test_ecdh_es_decryption_fails_if_key_matches_to_no_recipient(self):
        jwe = JsonWebEncryption()

        bob_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw",
            "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg"
        })
        charlie_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE",
            "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE"
        })

        protected = {
            "alg": "ECDH-ES+A256KW",
            "enc": "A256GCM",
            "apu": "QWxpY2U",
            "apv": "Qm9i"
        }

        unprotected = {
            "jku": "https://alice.example.com/keys.jwks"
        }

        recipients = [
            {
                "header": {
                    "kid": "bob-key-2"
                }
            }
        ]

        jwe_aad = b'Authenticate me too.'

        header_obj = {
            "protected": protected,
            "unprotected": unprotected,
            "recipients": recipients,
            "aad": jwe_aad
        }

        payload = b'Three is a magic number.'

        data = jwe.serialize_json(header_obj, payload, bob_key)

        self.assertRaises(
            InvalidUnwrap,
            jwe.deserialize_json,
            data, charlie_key
        )

    def test_decryption_with_json_serialization_succeeds_while_encrypted_key_for_another_recipient_is_invalid(self):
        jwe = JsonWebEncryption()

        alice_key = OKPKey.import_key({
            "kid": "Alice's key",
            "kty": "OKP",
            "crv": "X25519",
            "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4",
            "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU"
        })
        bob_key = OKPKey.import_key({
            "kid": "Bob's key",
            "kty": "OKP",
            "crv": "X25519",
            "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw",
            "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg"
        })
        charlie_key = OKPKey.import_key({
            "kid": "Charlie's key",
            "kty": "OKP",
            "crv": "X25519",
            "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE",
            "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE"
        })

        data = {
            "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1Ijoi" +
                         "UVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9L" +
                         "UCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFB" +
                         "RnFVQUZhMzlkeUJjIn19",
            "unprotected": {
                "jku": "https://alice.example.com/keys.jwks"
            },
            "recipients": [
                {
                    "header": {
                        "kid": "Bob's key"
                    },
                    "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQ" +
                                     "eU1cSl55cQ0hGezJu2N9IY0QM"  # Invalid encrypted key
                },
                {
                    "header": {
                        "kid": "Charlie's key"
                    },
                    "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8" +
                                     "fe4z3PQ2YH2afvjQ28aiCTWFE"  # Valid encrypted key
                }
            ],
            "iv": "AAECAwQFBgcICQoLDA0ODw",
            "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw",
            "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ"
        }

        rv_at_charlie = jwe.deserialize_json(data, charlie_key, sender_key=alice_key)

        self.assertEqual(rv_at_charlie.keys(), {'header', 'payload'})

        self.assertEqual(rv_at_charlie['header'].keys(), {'protected', 'unprotected', 'recipients'})

        self.assertEqual(
            rv_at_charlie['header']['protected'],
            {
                "alg": "ECDH-1PU+A128KW",
                "enc": "A256CBC-HS512",
                "apu": "QWxpY2U",
                "apv": "Qm9iIGFuZCBDaGFybGll",
                "epk": {
                    "kty": "OKP",
                    "crv": "X25519",
                    "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc"
                }
            }
        )

        self.assertEqual(
            rv_at_charlie['header']['unprotected'],
            {
                "jku": "https://alice.example.com/keys.jwks"
            }
        )

        self.assertEqual(
            rv_at_charlie['header']['recipients'],
            [
                {
                    "header": {
                        "kid": "Bob's key"
                    }
                },
                {
                    "header": {
                        "kid": "Charlie's key"
                    }
                }
            ]
        )

        self.assertEqual(rv_at_charlie['payload'], b'Three is a magic number.')

    def test_decryption_with_json_serialization_fails_if_encrypted_key_for_this_recipient_is_invalid(self):
        jwe = JsonWebEncryption()

        alice_key = OKPKey.import_key({
            "kid": "Alice's key",
            "kty": "OKP",
            "crv": "X25519",
            "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4",
            "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU"
        })
        bob_key = OKPKey.import_key({
            "kid": "Bob's key",
            "kty": "OKP",
            "crv": "X25519",
            "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw",
            "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg"
        })
        charlie_key = OKPKey.import_key({
            "kid": "Charlie's key",
            "kty": "OKP",
            "crv": "X25519",
            "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE",
            "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE"
        })

        data = {
            "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1Ijoi" +
                         "UVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9L" +
                         "UCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFB" +
                         "RnFVQUZhMzlkeUJjIn19",
            "unprotected": {
                "jku": "https://alice.example.com/keys.jwks"
            },
            "recipients": [
                {
                    "header": {
                        "kid": "Bob's key"
                    },
                    "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQ" +
                                     "eU1cSl55cQ0hGezJu2N9IY0QM"  # Invalid encrypted key
                },
                {
                    "header": {
                        "kid": "Charlie's key"
                    },
                    "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8" +
                                     "fe4z3PQ2YH2afvjQ28aiCTWFE"  # Valid encrypted key
                }
            ],
            "iv": "AAECAwQFBgcICQoLDA0ODw",
            "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw",
            "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ"
        }

        self.assertRaises(
            InvalidUnwrap,
            jwe.deserialize_json,
            data, bob_key, sender_key=alice_key
        )

    def test_dir_alg(self):
        jwe = JsonWebEncryption()
        key = OctKey.generate_key(128, is_private=True)
        protected = {'alg': 'dir', 'enc': 'A128GCM'}
        data = jwe.serialize_compact(protected, b'hello', key)
        rv = jwe.deserialize_compact(data, key)
        self.assertEqual(rv['payload'], b'hello')

        key2 = OctKey.generate_key(256, is_private=True)
        self.assertRaises(ValueError, jwe.deserialize_compact, data, key2)

        self.assertRaises(
            ValueError,
            jwe.serialize_compact,
            protected, b'hello', key2
        )

    def test_decryption_of_message_to_multiple_recipients_by_matching_key(self):
        jwe = JsonWebEncryption()

        alice_public_key_id = "did:example:123#WjKgJV7VRw3hmgU6--4v15c0Aewbcvat1BsRFTIqa5Q"
        alice_public_key = OKPKey.import_key({
            "kid": "WjKgJV7VRw3hmgU6--4v15c0Aewbcvat1BsRFTIqa5Q",
            "kty": "OKP",
            "crv": "X25519",
            "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4"
        })

        key_store = {}

        charlie_X448_key_id = "did:example:123#_TKzHv2jFIyvdTGF1Dsgwngfdg3SH6TpDv0Ta1aOEkw"
        charlie_X448_key = OKPKey.import_key({
            "kid": "_TKzHv2jFIyvdTGF1Dsgwngfdg3SH6TpDv0Ta1aOEkw",
            "kty": "OKP",
            "crv": "X448",
            "x": "M-OMugy74ksznVQ-Bp6MC_-GEPSrT8yiAtminJvw0j_UxJtpNHl_hcWMSf_Pfm_ws0vVWvAfwwA",
            "d": "VGZPkclj_7WbRaRMzBqxpzXIpc2xz1d3N1ay36UxdVLfKaP33hABBMpddTRv1f-hRsQUNvmlGOg"
        })
        key_store[charlie_X448_key_id] = charlie_X448_key

        charlie_X25519_key_id = "did:example:123#ZC2jXTO6t4R501bfCXv3RxarZyUbdP2w_psLwMuY6ec"
        charlie_X25519_key = OKPKey.import_key({
            "kid": "ZC2jXTO6t4R501bfCXv3RxarZyUbdP2w_psLwMuY6ec",
            "kty": "OKP",
            "crv": "X25519",
            "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE",
            "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE"
        })
        key_store[charlie_X25519_key_id] = charlie_X25519_key

        data = """
            {
                "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19",
                "unprotected": {
                    "jku": "https://alice.example.com/keys.jwks"
                },
                "recipients": [
                    {
                        "header": {
                            "kid": "did:example:123#_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A"
                        },
                        "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN"
                    },
                    {
                        "header": {
                            "kid": "did:example:123#ZC2jXTO6t4R501bfCXv3RxarZyUbdP2w_psLwMuY6ec"
                        },
                        "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE"
                    }
                ],
                "iv": "AAECAwQFBgcICQoLDA0ODw",
                "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw",
                "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ"
            }"""

        parsed_data = jwe.parse_json(data)

        available_key_id = next(recipient['header']['kid'] for recipient in parsed_data['recipients']
                                if recipient['header']['kid'] in key_store.keys())
        available_key = key_store[available_key_id]

        rv = jwe.deserialize_json(parsed_data, (available_key_id, available_key), sender_key=alice_public_key)

        self.assertEqual(rv.keys(), {'header', 'payload'})

        self.assertEqual(rv['header'].keys(), {'protected', 'unprotected', 'recipients'})

        self.assertEqual(
            rv['header']['protected'],
            {
                "alg": "ECDH-1PU+A128KW",
                "enc": "A256CBC-HS512",
                "apu": "QWxpY2U",
                "apv": "Qm9iIGFuZCBDaGFybGll",
                "epk": {
                    "kty": "OKP",
                    "crv": "X25519",
                    "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc"
                }
            }
        )

        self.assertEqual(
            rv['header']['unprotected'],
            {
                "jku": "https://alice.example.com/keys.jwks"
            }
        )

        self.assertEqual(
            rv['header']['recipients'],
            [
                {
                    "header": {
                        "kid": "did:example:123#_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A"
                    }
                },
                {
                    "header": {
                        "kid": "did:example:123#ZC2jXTO6t4R501bfCXv3RxarZyUbdP2w_psLwMuY6ec"
                    }
                }
            ]
        )

        self.assertEqual(rv['payload'], b'Three is a magic number.')

    def test_decryption_of_json_string(self):
        jwe = JsonWebEncryption()

        alice_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4",
            "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU"
        })
        bob_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw",
            "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg"
        })
        charlie_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "q-LsvU772uV_2sPJhfAIq-3vnKNVefNoIlvyvg1hrnE",
            "d": "Jcv8gklhMjC0b-lsk5onBbppWAx5ncNtbM63Jr9xBQE"
        })

        data = """
            {
                "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19",
                "unprotected": {
                    "jku": "https://alice.example.com/keys.jwks"
                },
                "recipients": [
                    {
                        "header": {
                            "kid": "bob-key-2"
                        },
                        "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN"
                    },
                    {
                        "header": {
                            "kid": "2021-05-06"
                        },
                        "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE"
                    }
                ],
                "iv": "AAECAwQFBgcICQoLDA0ODw",
                "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw",
                "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ"
            }"""

        rv_at_bob = jwe.deserialize_json(data, bob_key, sender_key=alice_key)

        self.assertEqual(rv_at_bob.keys(), {'header', 'payload'})

        self.assertEqual(rv_at_bob['header'].keys(), {'protected', 'unprotected', 'recipients'})

        self.assertEqual(
            rv_at_bob['header']['protected'],
            {
                "alg": "ECDH-1PU+A128KW",
                "enc": "A256CBC-HS512",
                "apu": "QWxpY2U",
                "apv": "Qm9iIGFuZCBDaGFybGll",
                "epk": {
                    "kty": "OKP",
                    "crv": "X25519",
                    "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc"
                }
            }
        )

        self.assertEqual(
            rv_at_bob['header']['unprotected'],
            {
                "jku": "https://alice.example.com/keys.jwks"
            }
        )

        self.assertEqual(
            rv_at_bob['header']['recipients'],
            [
                {
                    "header": {
                        "kid": "bob-key-2"
                    }
                },
                {
                    "header": {
                        "kid": "2021-05-06"
                    }
                }
            ]
        )

        self.assertEqual(rv_at_bob['payload'], b'Three is a magic number.')

        rv_at_charlie = jwe.deserialize_json(data, charlie_key, sender_key=alice_key)

        self.assertEqual(rv_at_charlie.keys(), {'header', 'payload'})

        self.assertEqual(rv_at_charlie['header'].keys(), {'protected', 'unprotected', 'recipients'})

        self.assertEqual(
            rv_at_charlie['header']['protected'],
            {
                "alg": "ECDH-1PU+A128KW",
                "enc": "A256CBC-HS512",
                "apu": "QWxpY2U",
                "apv": "Qm9iIGFuZCBDaGFybGll",
                "epk": {
                    "kty": "OKP",
                    "crv": "X25519",
                    "x": "k9of_cpAajy0poW5gaixXGs9nHkwg1AFqUAFa39dyBc"
                }
            }
        )

        self.assertEqual(
            rv_at_charlie['header']['unprotected'],
            {
                "jku": "https://alice.example.com/keys.jwks"
            }
        )

        self.assertEqual(
            rv_at_charlie['header']['recipients'],
            [
                {
                    "header": {
                        "kid": "bob-key-2"
                    }
                },
                {
                    "header": {
                        "kid": "2021-05-06"
                    }
                }
            ]
        )

        self.assertEqual(rv_at_charlie['payload'], b'Three is a magic number.')

    def test_parse_json(self):

        json_msg = """
            {
                "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19",
                "unprotected": {
                    "jku": "https://alice.example.com/keys.jwks"
                },
                "recipients": [
                    {
                        "header": {
                            "kid": "bob-key-2"
                        },
                        "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN"
                    },
                    {
                        "header": {
                            "kid": "2021-05-06"
                        },
                        "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE"
                    }
                ],
                "iv": "AAECAwQFBgcICQoLDA0ODw",
                "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw",
                "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ"
            }"""

        parsed_msg = JsonWebEncryption.parse_json(json_msg)

        self.assertEqual(
            parsed_msg,
            {
                "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19",
                "unprotected": {
                    "jku": "https://alice.example.com/keys.jwks"
                },
                "recipients": [
                    {
                        "header": {
                            "kid": "bob-key-2"
                        },
                        "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN"
                    },
                    {
                        "header": {
                            "kid": "2021-05-06"
                        },
                        "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE"
                    }
                ],
                "iv": "AAECAwQFBgcICQoLDA0ODw",
                "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw",
                "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ"
            }
        )

    def test_parse_json_fails_if_json_msg_is_invalid(self):

        json_msg = """
            {
                "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1IjoiUVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9LUCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFBRnFVQUZhMzlkeUJjIn19",
                "unprotected": {
                    "jku": "https://alice.example.com/keys.jwks"
                },
                "recipients": [
                    {
                        "header": {
                            "kid": "bob-key-2"
                        ,
                        "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQeU1cSl55cQ0hGezJu2N9IY0QN"
                    },
                    {
                        "header": {
                            "kid": "2021-05-06"
                        },
                        "encrypted_key": "56GVudgRLIMEElQ7DpXsijJVRSWUSDNdbWkdV3g0GUNq6hcT_GkxwnxlPIWrTXCqRpVKQC8fe4z3PQ2YH2afvjQ28aiCTWFE"
                    }
                ],
                "iv": "AAECAwQFBgcICQoLDA0ODw",
                "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFYw",
                "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ"
            }"""

        self.assertRaises(
            DecodeError,
            JsonWebEncryption.parse_json,
            json_msg
        )

    def test_decryption_fails_if_ciphertext_is_invalid(self):
        jwe = JsonWebEncryption()

        alice_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "Knbm_BcdQr7WIoz-uqit9M0wbcfEr6y-9UfIZ8QnBD4",
            "d": "i9KuFhSzEBsiv3PKVL5115OCdsqQai5nj_Flzfkw5jU"
        })
        bob_key = OKPKey.import_key({
            "kty": "OKP",
            "crv": "X25519",
            "x": "BT7aR0ItXfeDAldeeOlXL_wXqp-j5FltT0vRSG16kRw",
            "d": "1gDirl_r_Y3-qUa3WXHgEXrrEHngWThU3c9zj9A2uBg"
        })

        data = {
            "protected": "eyJhbGciOiJFQ0RILTFQVStBMTI4S1ciLCJlbmMiOiJBMjU2Q0JDLUhTNTEyIiwiYXB1Ijoi" +
                         "UVd4cFkyVSIsImFwdiI6IlFtOWlJR0Z1WkNCRGFHRnliR2xsIiwiZXBrIjp7Imt0eSI6Ik9L" +
                         "UCIsImNydiI6IlgyNTUxOSIsIngiOiJrOW9mX2NwQWFqeTBwb1c1Z2FpeFhHczluSGt3ZzFB" +
                         "RnFVQUZhMzlkeUJjIn19",
            "unprotected": {
                "jku": "https://alice.example.com/keys.jwks"
            },
            "recipients": [
                {
                    "header": {
                        "kid": "bob-key-2"
                    },
                    "encrypted_key": "pOMVA9_PtoRe7xXW1139NzzN1UhiFoio8lGto9cf0t8PyU-sjNXH8-LIRLycq8CHJQbDwvQ" +
                                     "eU1cSl55cQ0hGezJu2N9IY0QN"
                }
            ],
            "iv": "AAECAwQFBgcICQoLDA0ODw",
            "ciphertext": "Az2IWsISEMDJvyc5XRL-3-d-RgNBOGolCsxFFoUXFY",  # invalid ciphertext
            "tag": "HLb4fTlm8spGmij3RyOs2gJ4DpHM4hhVRwdF_hGb3WQ"
        }

        self.assertRaises(
            Exception,
            jwe.deserialize_json,
            data, bob_key, sender_key=alice_key
        )

    def test_generic_serialize_deserialize_for_compact_serialization(self):
        jwe = JsonWebEncryption()

        alice_key = OKPKey.generate_key('X25519', is_private=True)
        bob_key = OKPKey.generate_key('X25519', is_private=True)

        header_obj = {'alg': 'ECDH-1PU+A128KW', 'enc': 'A128CBC-HS256'}

        data = jwe.serialize(header_obj, b'hello', bob_key, sender_key=alice_key)
        self.assertIsInstance(data, bytes)

        rv = jwe.deserialize(data, bob_key, sender_key=alice_key)
        self.assertEqual(rv['payload'], b'hello')

    def test_generic_serialize_deserialize_for_json_serialization(self):
        jwe = JsonWebEncryption()

        alice_key = OKPKey.generate_key('X25519', is_private=True)
        bob_key = OKPKey.generate_key('X25519', is_private=True)

        protected = {'alg': 'ECDH-1PU+A128KW', 'enc': 'A128CBC-HS256'}
        header_obj = {'protected': protected}

        data = jwe.serialize(header_obj, b'hello', bob_key, sender_key=alice_key)
        self.assertIsInstance(data, dict)

        rv = jwe.deserialize(data, bob_key, sender_key=alice_key)
        self.assertEqual(rv['payload'], b'hello')

    def test_generic_deserialize_for_json_serialization_string(self):
        jwe = JsonWebEncryption()

        alice_key = OKPKey.generate_key('X25519', is_private=True)
        bob_key = OKPKey.generate_key('X25519', is_private=True)

        protected = {'alg': 'ECDH-1PU+A128KW', 'enc': 'A128CBC-HS256'}
        header_obj = {'protected': protected}

        data = jwe.serialize(header_obj, b'hello', bob_key, sender_key=alice_key)
        self.assertIsInstance(data, dict)

        data_as_string = json.dumps(data)

        rv = jwe.deserialize(data_as_string, bob_key, sender_key=alice_key)
        self.assertEqual(rv['payload'], b'hello')
