from datetime import timedelta
from importlib import reload
from unittest.mock import MagicMock, patch

from django.conf import settings
from django.contrib.auth import get_user_model
from django.core import exceptions as django_exceptions
from django.test import TestCase, override_settings
from rest_framework import exceptions as drf_exceptions

from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.serializers import (
    TokenBlacklistSerializer,
    TokenObtainPairSerializer,
    TokenObtainSerializer,
    TokenObtainSlidingSerializer,
    TokenRefreshSerializer,
    TokenRefreshSlidingSerializer,
    TokenVerifySerializer,
)
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.token_blacklist.models import (
    BlacklistedToken,
    OutstandingToken,
)
from rest_framework_simplejwt.tokens import AccessToken, RefreshToken, SlidingToken
from rest_framework_simplejwt.utils import (
    aware_utcnow,
    datetime_from_epoch,
    datetime_to_epoch,
)

from .utils import override_api_settings

User = get_user_model()


class TestTokenObtainSerializer(TestCase):
    def setUp(self):
        self.username = "test_user"
        self.password = "test_password"

        self.user = User.objects.create_user(
            username=self.username,
            password=self.password,
        )

    def test_it_should_not_validate_if_any_fields_missing(self):
        s = TokenObtainSerializer(data={})
        self.assertFalse(s.is_valid())
        self.assertIn(s.username_field, s.errors)
        self.assertIn("password", s.errors)

        s = TokenObtainSerializer(
            data={
                TokenObtainSerializer.username_field: "oieanrst",
            }
        )
        self.assertFalse(s.is_valid())
        self.assertIn("password", s.errors)

        s = TokenObtainSerializer(
            data={
                "password": "oieanrst",
            }
        )
        self.assertFalse(s.is_valid())
        self.assertIn(s.username_field, s.errors)

    def test_it_should_not_validate_if_user_not_found(self):
        s = TokenObtainSerializer(
            context=MagicMock(),
            data={
                TokenObtainSerializer.username_field: "missing",
                "password": "pass",
            },
        )

        with self.assertRaises(drf_exceptions.AuthenticationFailed):
            s.is_valid()

    def test_it_should_pass_validate_if_request_not_in_context(self):
        s = TokenObtainSerializer(
            context={},
            data={
                "username": self.username,
                "password": self.password,
            },
        )

        s.is_valid()

    def test_it_should_raise_if_user_not_active(self):
        self.user.is_active = False
        self.user.save()

        s = TokenObtainSerializer(
            context=MagicMock(),
            data={
                TokenObtainSerializer.username_field: self.username,
                "password": self.password,
            },
        )

        with self.assertRaises(drf_exceptions.AuthenticationFailed):
            s.is_valid()

    @override_settings(
        AUTHENTICATION_BACKENDS=[
            "django.contrib.auth.backends.AllowAllUsersModelBackend",
            "django.contrib.auth.backends.ModelBackend",
        ]
    )
    @override_api_settings(
        CHECK_USER_IS_ACTIVE=False,
    )
    def test_it_should_validate_if_user_inactive_but_rule_allows(self):
        self.user.is_active = False
        self.user.save()

        s = TokenObtainSerializer(
            context=MagicMock(),
            data={
                TokenObtainSerializer.username_field: self.username,
                "password": self.password,
            },
        )

        self.assertTrue(s.is_valid())


class TestTokenObtainSlidingSerializer(TestCase):
    def setUp(self):
        self.username = "test_user"
        self.password = "test_password"

        self.user = User.objects.create_user(
            username=self.username,
            password=self.password,
        )

    def test_it_should_produce_a_json_web_token_when_valid(self):
        s = TokenObtainSlidingSerializer(
            context=MagicMock(),
            data={
                TokenObtainSlidingSerializer.username_field: self.username,
                "password": self.password,
            },
        )

        self.assertTrue(s.is_valid())
        self.assertIn("token", s.validated_data)

        # Expecting token type claim to be correct for sliding token.  If this
        # is the case, instantiating a `SlidingToken` instance with encoded
        # token should not raise an exception.
        SlidingToken(s.validated_data["token"])


class TestTokenObtainPairSerializer(TestCase):
    def setUp(self):
        self.username = "test_user"
        self.password = "test_password"

        self.user = User.objects.create_user(
            username=self.username,
            password=self.password,
        )

    def test_it_should_produce_a_json_web_token_when_valid(self):
        s = TokenObtainPairSerializer(
            context=MagicMock(),
            data={
                TokenObtainPairSerializer.username_field: self.username,
                "password": self.password,
            },
        )

        self.assertTrue(s.is_valid())
        self.assertIn("access", s.validated_data)
        self.assertIn("refresh", s.validated_data)

        # Expecting token type claim to be correct for both tokens.  If this is
        # the case, instantiating appropriate token subclass instances with
        # encoded tokens should not raise an exception.
        AccessToken(s.validated_data["access"])
        RefreshToken(s.validated_data["refresh"])


class TestTokenRefreshSlidingSerializer(TestCase):
    def test_it_should_not_validate_if_token_invalid(self):
        token = SlidingToken()
        del token["exp"]

        s = TokenRefreshSlidingSerializer(data={"token": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("has no 'exp' claim", e.exception.args[0])

        token.set_exp(lifetime=-timedelta(days=1))

        s = TokenRefreshSlidingSerializer(data={"token": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("expired", e.exception.args[0])

    def test_it_should_raise_token_error_if_token_has_no_refresh_exp_claim(self):
        token = SlidingToken()
        del token[api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM]

        s = TokenRefreshSlidingSerializer(data={"token": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn(
            f"has no '{api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM}' claim",
            e.exception.args[0],
        )

    def test_it_should_raise_token_error_if_token_has_refresh_period_expired(self):
        token = SlidingToken()
        token.set_exp(
            api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM, lifetime=-timedelta(days=1)
        )

        s = TokenRefreshSlidingSerializer(data={"token": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn(
            "'{}' claim has expired".format(
                api_settings.SLIDING_TOKEN_REFRESH_EXP_CLAIM
            ),
            e.exception.args[0],
        )

    def test_it_should_raise_token_error_if_token_has_wrong_type(self):
        token = SlidingToken()
        token[api_settings.TOKEN_TYPE_CLAIM] = "wrong_type"

        s = TokenRefreshSlidingSerializer(data={"token": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("wrong type", e.exception.args[0])

    def test_it_should_update_token_exp_claim_if_everything_ok(self):
        old_token = SlidingToken()

        lifetime = api_settings.SLIDING_TOKEN_LIFETIME - timedelta(seconds=1)
        old_exp = old_token.current_time + lifetime

        old_token.set_exp(lifetime=lifetime)

        # Serializer validates
        s = TokenRefreshSlidingSerializer(data={"token": str(old_token)})
        self.assertTrue(s.is_valid())

        # Expiration claim has moved into future
        new_token = SlidingToken(s.validated_data["token"])
        new_exp = datetime_from_epoch(new_token["exp"])

        self.assertTrue(old_exp < new_exp)


class TestTokenRefreshSerializer(TestCase):
    def setUp(self):
        self.username = "test_user"
        self.password = "test_password"

        self.user = User.objects.create_user(
            username=self.username,
            password=self.password,
        )

    def test_it_should_raise_error_for_deleted_users(self):
        refresh = RefreshToken.for_user(self.user)
        self.user.delete()

        s = TokenRefreshSerializer(data={"refresh": str(refresh)})

        with self.assertRaises(django_exceptions.ObjectDoesNotExist) as e:
            s.is_valid()

        self.assertIn("does not exist", str(e.exception))

    def test_it_should_raise_error_for_inactive_users(self):
        refresh = RefreshToken.for_user(self.user)
        self.user.is_active = False
        self.user.save()

        s = TokenRefreshSerializer(data={"refresh": str(refresh)})

        with self.assertRaises(drf_exceptions.AuthenticationFailed) as e:
            s.is_valid()

        self.assertIn("No active account", e.exception.args[0])

    def test_it_should_return_access_token_for_active_users(self):
        refresh = RefreshToken.for_user(self.user)

        s = TokenRefreshSerializer(data={"refresh": str(refresh)})

        now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2
        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
            fake_aware_utcnow.return_value = now
            s.is_valid()

        access = AccessToken(s.validated_data["access"])

        self.assertEqual(
            access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
        )

    def test_it_should_raise_token_error_if_token_invalid(self):
        token = RefreshToken()
        del token["exp"]

        s = TokenRefreshSerializer(data={"refresh": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("has no 'exp' claim", e.exception.args[0])

        token.set_exp(lifetime=-timedelta(days=1))

        s = TokenRefreshSerializer(data={"refresh": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("expired", e.exception.args[0])

    def test_it_should_raise_token_error_if_token_has_wrong_type(self):
        token = RefreshToken()
        token[api_settings.TOKEN_TYPE_CLAIM] = "wrong_type"

        s = TokenRefreshSerializer(data={"refresh": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("wrong type", e.exception.args[0])

    def test_it_should_return_access_token_if_everything_ok(self):
        refresh = RefreshToken()
        refresh["test_claim"] = "arst"

        # Serializer validates
        s = TokenRefreshSerializer(data={"refresh": str(refresh)})

        now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
            fake_aware_utcnow.return_value = now
            self.assertTrue(s.is_valid())

        access = AccessToken(s.validated_data["access"])

        self.assertEqual(refresh["test_claim"], access["test_claim"])
        self.assertEqual(
            access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
        )

    @override_api_settings(
        ROTATE_REFRESH_TOKENS=True,
        BLACKLIST_AFTER_ROTATION=False,
    )
    def test_it_should_return_refresh_token_if_tokens_should_be_rotated(self):
        refresh = RefreshToken()

        refresh["test_claim"] = "arst"

        old_jti = refresh["jti"]
        old_exp = refresh["exp"]

        # Serializer validates
        ser = TokenRefreshSerializer(data={"refresh": str(refresh)})

        now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
            fake_aware_utcnow.return_value = now
            self.assertTrue(ser.is_valid())

        access = AccessToken(ser.validated_data["access"])
        new_refresh = RefreshToken(ser.validated_data["refresh"])

        self.assertEqual(refresh["test_claim"], access["test_claim"])
        self.assertEqual(refresh["test_claim"], new_refresh["test_claim"])

        self.assertNotEqual(old_jti, new_refresh["jti"])
        self.assertNotEqual(old_exp, new_refresh["exp"])

        self.assertEqual(
            access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
        )
        self.assertEqual(
            new_refresh["exp"],
            datetime_to_epoch(now + api_settings.REFRESH_TOKEN_LIFETIME),
        )

    @override_api_settings(
        ROTATE_REFRESH_TOKENS=True,
        BLACKLIST_AFTER_ROTATION=True,
    )
    def test_it_should_blacklist_refresh_token_if_tokens_should_be_rotated_and_blacklisted(
        self,
    ):
        self.assertEqual(OutstandingToken.objects.count(), 0)
        self.assertEqual(BlacklistedToken.objects.count(), 0)

        refresh = RefreshToken()

        refresh["test_claim"] = "arst"

        old_jti = refresh["jti"]
        old_exp = refresh["exp"]

        # Serializer validates
        ser = TokenRefreshSerializer(data={"refresh": str(refresh)})

        now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
            fake_aware_utcnow.return_value = now
            self.assertTrue(ser.is_valid())

        access = AccessToken(ser.validated_data["access"])
        new_refresh = RefreshToken(ser.validated_data["refresh"])

        self.assertEqual(refresh["test_claim"], access["test_claim"])
        self.assertEqual(refresh["test_claim"], new_refresh["test_claim"])

        self.assertNotEqual(old_jti, new_refresh["jti"])
        self.assertNotEqual(old_exp, new_refresh["exp"])

        self.assertEqual(
            access["exp"], datetime_to_epoch(now + api_settings.ACCESS_TOKEN_LIFETIME)
        )
        self.assertEqual(
            new_refresh["exp"],
            datetime_to_epoch(now + api_settings.REFRESH_TOKEN_LIFETIME),
        )

        self.assertEqual(OutstandingToken.objects.count(), 2)
        self.assertEqual(BlacklistedToken.objects.count(), 1)

        # Assert old refresh token is blacklisted
        self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti)

    @override_api_settings(
        ROTATE_REFRESH_TOKENS=True,
        BLACKLIST_AFTER_ROTATION=True,
    )
    def test_blacklist_app_not_installed_should_pass(self):
        from rest_framework_simplejwt import serializers, tokens

        # Remove blacklist app
        new_apps = list(settings.INSTALLED_APPS)
        new_apps.remove("rest_framework_simplejwt.token_blacklist")

        with self.settings(INSTALLED_APPS=tuple(new_apps)):
            # Reload module that blacklist app not installed
            reload(tokens)
            reload(serializers)

            refresh = tokens.RefreshToken()

            # Serializer validates
            ser = serializers.TokenRefreshSerializer(data={"refresh": str(refresh)})
            ser.validate({"refresh": str(refresh)})

        # Restore origin module without mock
        reload(tokens)
        reload(serializers)


class TestTokenVerifySerializer(TestCase):
    def test_it_should_raise_token_error_if_token_invalid(self):
        token = RefreshToken()
        del token["exp"]

        s = TokenVerifySerializer(data={"token": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("has no 'exp' claim", e.exception.args[0])

        token.set_exp(lifetime=-timedelta(days=1))

        s = TokenVerifySerializer(data={"token": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("expired", e.exception.args[0])

    def test_it_should_not_raise_token_error_if_token_has_wrong_type(self):
        token = RefreshToken()
        token[api_settings.TOKEN_TYPE_CLAIM] = "wrong_type"

        s = TokenVerifySerializer(data={"token": str(token)})

        self.assertTrue(s.is_valid())

    def test_it_should_return_given_token_if_everything_ok(self):
        refresh = RefreshToken()
        refresh["test_claim"] = "arst"

        # Serializer validates
        s = TokenVerifySerializer(data={"token": str(refresh)})

        now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
            fake_aware_utcnow.return_value = now
            self.assertTrue(s.is_valid())

        self.assertEqual(len(s.validated_data), 0)


class TestTokenBlacklistSerializer(TestCase):
    def test_it_should_raise_token_error_if_token_invalid(self):
        token = RefreshToken()
        del token["exp"]

        s = TokenBlacklistSerializer(data={"refresh": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("has no 'exp' claim", e.exception.args[0])

        token.set_exp(lifetime=-timedelta(days=1))

        s = TokenBlacklistSerializer(data={"refresh": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("expired", e.exception.args[0])

    def test_it_should_raise_token_error_if_token_has_wrong_type(self):
        token = RefreshToken()
        token[api_settings.TOKEN_TYPE_CLAIM] = "wrong_type"

        s = TokenBlacklistSerializer(data={"refresh": str(token)})

        with self.assertRaises(TokenError) as e:
            s.is_valid()

        self.assertIn("wrong type", e.exception.args[0])

    def test_it_should_return_nothing_if_everything_ok(self):
        refresh = RefreshToken()
        refresh["test_claim"] = "arst"

        # Serializer validates
        s = TokenBlacklistSerializer(data={"refresh": str(refresh)})

        now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
            fake_aware_utcnow.return_value = now
            self.assertTrue(s.is_valid())

        self.assertDictEqual(s.validated_data, {})

    def test_it_should_blacklist_refresh_token_if_everything_ok(self):
        self.assertEqual(OutstandingToken.objects.count(), 0)
        self.assertEqual(BlacklistedToken.objects.count(), 0)

        refresh = RefreshToken()

        refresh["test_claim"] = "arst"

        old_jti = refresh["jti"]

        # Serializer validates
        ser = TokenBlacklistSerializer(data={"refresh": str(refresh)})

        now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

        with patch("rest_framework_simplejwt.tokens.aware_utcnow") as fake_aware_utcnow:
            fake_aware_utcnow.return_value = now
            self.assertTrue(ser.is_valid())

        self.assertEqual(OutstandingToken.objects.count(), 1)
        self.assertEqual(BlacklistedToken.objects.count(), 1)

        # Assert old refresh token is blacklisted
        self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti)

    def test_blacklist_app_not_installed_should_pass(self):
        from rest_framework_simplejwt import serializers, tokens

        # Remove blacklist app
        new_apps = list(settings.INSTALLED_APPS)
        new_apps.remove("rest_framework_simplejwt.token_blacklist")

        with self.settings(INSTALLED_APPS=tuple(new_apps)):
            # Reload module that blacklist app not installed
            reload(tokens)
            reload(serializers)

            refresh = tokens.RefreshToken()

            # Serializer validates
            ser = serializers.TokenBlacklistSerializer(data={"refresh": str(refresh)})
            ser.validate({"refresh": str(refresh)})

        # Restore origin module without mock
        reload(tokens)
        reload(serializers)
