import time as real_time
import unittest

import jwt as jwt_lib
from mock import patch

from twilio.jwt import Jwt, JwtDecodeError


class DummyJwt(Jwt):
    """Test double for ``Jwt`` whose custom payload and headers come from the caller.

    All constructor arguments other than ``headers`` and ``payload`` are
    forwarded to ``Jwt.__init__``. When ``algorithm`` is not given (or is
    falsy), the class-level ``ALGORITHM`` is used. The ``headers`` and
    ``payload`` dicts are returned as-is by the ``_generate_headers`` /
    ``_generate_payload`` hooks, which lets tests inject arbitrary entries.
    """

    ALGORITHM = "HS256"

    def __init__(
        self,
        secret_key,
        issuer,
        subject=None,
        algorithm=None,
        nbf=Jwt.GENERATE,
        ttl=3600,
        valid_until=None,
        headers=None,
        payload=None,
    ):
        # Any falsy algorithm (None, "") falls back to this class's default.
        chosen_algorithm = algorithm if algorithm else self.ALGORITHM
        super().__init__(
            secret_key=secret_key,
            issuer=issuer,
            subject=subject,
            algorithm=chosen_algorithm,
            nbf=nbf,
            ttl=ttl,
            valid_until=valid_until,
        )
        # A missing or empty mapping is replaced with a fresh empty dict.
        self._payload = payload if payload else {}
        self._headers = headers if headers else {}

    def _generate_payload(self):
        """Return the caller-supplied custom payload entries."""
        return self._payload

    def _generate_headers(self):
        """Return the caller-supplied custom header entries."""
        return self._headers


class JwtTest(unittest.TestCase):
    """Tests for encoding ``Jwt`` objects into tokens and decoding them back."""

    def now(self):
        """Return the current wall-clock time in whole seconds since the epoch."""
        return int(real_time.time())

    def assertJwtsEqual(self, jwt, key, expected_payload=None, expected_headers=None):
        """Assert that token ``jwt`` carries exactly the expected headers and payload.

        Signature verification is disabled (``verify_signature: False``), so
        this compares header and claim content only. ``key`` is forwarded to
        ``jwt_lib.decode`` but is not used to verify the token.
        """
        expected_headers = expected_headers or {}
        expected_payload = expected_payload or {}

        decoded_payload = jwt_lib.decode(
            jwt, key, algorithms=["HS256"], options={"verify_signature": False}
        )
        decoded_headers = jwt_lib.get_unverified_header(jwt)

        self.assertEqual(expected_headers, decoded_headers)
        self.assertEqual(expected_payload, decoded_payload)

    @patch("time.time")
    def test_basic_encode(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt("secret_key", "issuer", headers={}, payload={})

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 3600, "nbf": 0},
        )

    @patch("time.time")
    def test_encode_with_subject(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt(
            "secret_key", "issuer", subject="subject", headers={}, payload={}
        )

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 3600, "nbf": 0, "sub": "subject"},
        )

    @patch("time.time")
    def test_encode_without_nbf(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt(
            "secret_key", "issuer", subject="subject", headers={}, payload={}, nbf=None
        )

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 3600, "sub": "subject"},
        )

    @patch("time.time")
    def test_encode_custom_ttl(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt("secret_key", "issuer", ttl=10, headers={}, payload={})

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 10, "nbf": 0},
        )

    @patch("time.time")
    def test_encode_ttl_added_to_current_time(self, time_mock):
        time_mock.return_value = 50.0

        jwt = DummyJwt("secret_key", "issuer", ttl=10, headers={}, payload={})

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 60, "nbf": 50},
        )

    @patch("time.time")
    def test_encode_override_ttl(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt("secret_key", "issuer", ttl=10, headers={}, payload={})

        self.assertJwtsEqual(
            jwt.to_jwt(ttl=20),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 20, "nbf": 0},
        )

    @patch("time.time")
    def test_encode_valid_until_overrides_ttl(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt(
            "secret_key", "issuer", ttl=10, valid_until=70, headers={}, payload={}
        )

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 70, "nbf": 0},
        )

    @patch("time.time")
    def test_encode_custom_nbf(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt("secret_key", "issuer", ttl=10, nbf=5, headers={}, payload={})

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 10, "nbf": 5},
        )

    @patch("time.time")
    def test_encode_with_headers(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt("secret_key", "issuer", headers={"sooper": "secret"}, payload={})

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256", "sooper": "secret"},
            expected_payload={"iss": "issuer", "exp": 3600, "nbf": 0},
        )

    @patch("time.time")
    def test_encode_with_payload(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt("secret_key", "issuer", payload={"root": "true"})

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256"},
            expected_payload={"iss": "issuer", "exp": 3600, "nbf": 0, "root": "true"},
        )

    @patch("time.time")
    def test_encode_with_payload_and_headers(self, time_mock):
        time_mock.return_value = 0.0

        jwt = DummyJwt(
            "secret_key", "issuer", headers={"yes": "oui"}, payload={"pay": "me"}
        )

        self.assertJwtsEqual(
            jwt.to_jwt(),
            "secret_key",
            expected_headers={"typ": "JWT", "alg": "HS256", "yes": "oui"},
            expected_payload={"iss": "issuer", "exp": 3600, "nbf": 0, "pay": "me"},
        )

    def test_encode_no_key_fails(self):
        jwt = DummyJwt(None, "issuer")
        self.assertRaises(ValueError, jwt.to_jwt)

    def test_encode_decode(self):
        test_start = self.now()

        jwt = DummyJwt("secret_key", "issuer", subject="hey", payload={"sick": "sick"})
        decoded_jwt = Jwt.from_jwt(jwt.to_jwt(), "secret_key")

        # Compare against test_start, not a fresh now(): the clock may tick
        # past a second boundary between encoding and this assertion, which
        # would make a now()-based lower bound fail spuriously.
        self.assertGreaterEqual(decoded_jwt.valid_until, test_start + 3600)
        self.assertGreaterEqual(decoded_jwt.nbf, test_start)
        self.assertEqual(decoded_jwt.issuer, "issuer")
        self.assertEqual(decoded_jwt.secret_key, "secret_key")
        self.assertEqual(decoded_jwt.algorithm, "HS256")
        self.assertEqual(decoded_jwt.subject, "hey")

        self.assertEqual(decoded_jwt.headers, {"typ": "JWT", "alg": "HS256"})

        # The payload also holds time-dependent claims (exp, nbf), so only
        # check that it contains these claims with these values.
        expected_claims = {"iss": "issuer", "sub": "hey", "sick": "sick"}
        payload = decoded_jwt.payload
        self.assertEqual(
            {claim: payload.get(claim) for claim in expected_claims},
            expected_claims,
        )

    def test_encode_decode_mismatched_algorithms(self):
        jwt = DummyJwt(
            "secret_key",
            "issuer",
            algorithm="HS512",
            subject="hey",
            payload={"sick": "sick"},
        )
        # Pass the correct key so the algorithm mismatch is the only reason
        # decoding can fail.
        self.assertRaises(JwtDecodeError, Jwt.from_jwt, jwt.to_jwt(), "secret_key")

    def test_decode_bad_secret(self):
        jwt = DummyJwt("secret_key", "issuer")
        self.assertRaises(JwtDecodeError, Jwt.from_jwt, jwt.to_jwt(), "letmeinplz")

    def test_decode_modified_jwt_fails(self):
        jwt = DummyJwt("secret_key", "issuer")
        example_jwt = jwt.to_jwt()
        # Corrupt the first characters of the token (the encoded header).
        example_jwt = "ABC" + example_jwt[3:]

        self.assertRaises(JwtDecodeError, Jwt.from_jwt, example_jwt, "secret_key")

    def test_decode_validates_expiration(self):
        # Expired an hour ago, so no sleep is needed to get past "exp".
        expired_jwt = DummyJwt("secret_key", "issuer", valid_until=self.now() - 3600)
        self.assertRaises(
            JwtDecodeError, Jwt.from_jwt, expired_jwt.to_jwt(), "secret_key"
        )

    def test_decode_validates_nbf(self):
        premature_jwt = DummyJwt(
            "secret_key", "issuer", nbf=self.now() + 3600
        )  # valid 1hr from now
        self.assertRaises(
            JwtDecodeError, Jwt.from_jwt, premature_jwt.to_jwt(), "secret_key"
        )

    def test_decodes_valid_jwt(self):
        expiry_time = self.now() + 1000
        example_jwt = jwt_lib.encode(
            {"hello": "world", "iss": "me", "sub": "being awesome", "exp": expiry_time},
            "secret",
            algorithm="HS256",
        )

        decoded_jwt = Jwt.from_jwt(example_jwt, "secret")
        self.assertEqual(decoded_jwt.issuer, "me")
        self.assertEqual(decoded_jwt.subject, "being awesome")
        self.assertEqual(decoded_jwt.valid_until, expiry_time)
        self.assertIn("hello", decoded_jwt.payload)
        self.assertEqual(decoded_jwt.payload["hello"], "world")

    def test_decode_allows_skip_verification(self):
        jwt = DummyJwt("secret", "issuer", payload={"get": "rekt"})
        # key=None decodes without checking the signature.
        decoded_jwt = Jwt.from_jwt(jwt.to_jwt(), key=None)
        self.assertEqual(decoded_jwt.issuer, "issuer")
        self.assertEqual(decoded_jwt.payload["get"], "rekt")
        self.assertIsNone(decoded_jwt.secret_key)
