from functools import partial

from django.core.cache import cache, InvalidCacheBackendError
from django.core.exceptions import ImproperlyConfigured
from django.test import RequestFactory, TestCase
from django.test.utils import override_settings
from django.utils.decorators import method_decorator
from django.views.generic import View

from django_ratelimit.decorators import ratelimit
from django_ratelimit.exceptions import Ratelimited
from django_ratelimit.core import (get_usage, is_ratelimited,
                                   _split_rate, _get_ip)

from unittest import skipIf
try:
    import django_redis
except ImportError:
    django_redis = False


rf = RequestFactory()


class MockUser:
    def __init__(self, authenticated=False):
        self.pk = 1
        self.is_authenticated = authenticated


class RateParsingTests(TestCase):
    def test_simple(self):
        tests = (
            ('100/s', (100, 1)),
            ('100/10s', (100, 10)),
            ('100/10', (100, 10)),
            ('100/m', (100, 60)),
            ('400/10m', (400, 600)),
            ('1000/h', (1000, 3600)),
            ('800/d', (800, 24 * 60 * 60)),
        )

        for i, o in tests:
            assert o == _split_rate(i)


def callable_rate(group, request):
    if request.user.is_authenticated:
        return None
    return (0, 1)


def mykey(group, request):
    return request.META['REMOTE_ADDR'][::-1]


class CustomRatelimitedException(Exception):
    pass


class RatelimitTests(TestCase):
    def setUp(self):
        cache.clear()

    def test_no_key(self):
        @ratelimit(rate='1/m')
        def view(request):
            return True

        req = rf.get('/')
        with self.assertRaises(ImproperlyConfigured):
            view(req)

    def test_ip(self):
        @ratelimit(key='ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.get('/')), 'First request works.'
        assert view(rf.get('/')), 'Second request is limited'

    def test_block(self):
        @ratelimit(key='ip', rate='1/m')
        def blocked(request):
            return request.limited

        assert not blocked(rf.get('/')), 'First request works.'
        with self.assertRaises(Ratelimited):
            blocked(rf.get('/')), 'Second request is blocked.'

    def test_ratelimit_custom_string_exception_class(self):
        @ratelimit(key='ip', rate='1/m')
        def view(request):
            return request.limited

        with self.settings(
            RATELIMIT_EXCEPTION_CLASS=(
                "django_ratelimit.tests.CustomRatelimitedException"
            )
        ):
            req = rf.get("")
            assert not view(req)
            with self.assertRaises(CustomRatelimitedException):
                view(req)

    def test_ratelimit_custom_exception_class(self):
        @ratelimit(key='ip', rate='1/m')
        def view(request):
            return request.limited

        with self.settings(
            RATELIMIT_EXCEPTION_CLASS=CustomRatelimitedException
        ):
            req = rf.get("")
            assert not view(req)
            with self.assertRaises(CustomRatelimitedException):
                view(req)

    def test_method(self):
        @ratelimit(key='ip', method='POST', rate='1/m', group='a', block=False)
        def limit_post(request):
            return request.limited

        assert not limit_post(rf.post('/')), 'Do not limit first POST.'
        assert limit_post(rf.post('/')), 'Limit second POST.'
        assert not limit_post(rf.get('/')), 'Do not limit GET.'

    def test_unsafe_methods(self):
        @ratelimit(key='ip', method=ratelimit.UNSAFE, rate='0/m', block=False)
        def limit_unsafe(request):
            return request.limited

        assert not limit_unsafe(rf.get('/'))
        assert not limit_unsafe(rf.head('/'))
        assert not limit_unsafe(rf.options('/'))
        assert limit_unsafe(rf.delete('/'))
        assert limit_unsafe(rf.post('/'))
        assert limit_unsafe(rf.put('/'))
        assert limit_unsafe(rf.patch('/'))

    def test_key_get(self):
        @ratelimit(key='get:foo', rate='1/m', method='GET', block=False)
        def view(request):
            return request.limited

        assert not view(rf.get('/', {'foo': 'a'}))
        assert view(rf.get('/', {'foo': 'a'}))
        assert not view(rf.get('/', {'foo': 'b'}))
        assert view(rf.get('/', {'foo': 'b'}))

    def test_key_post(self):
        @ratelimit(key='post:foo', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.post('/', {'foo': 'a'}))
        assert view(rf.post('/', {'foo': 'a'}))
        assert not view(rf.post('/', {'foo': 'b'}))
        assert view(rf.post('/', {'foo': 'b'}))

    def test_key_header(self):
        def _req():
            req = rf.post('/')
            req.META['HTTP_X_REAL_IP'] = '1.2.3.4'
            return req

        @ratelimit(key='header:x-real-ip', rate='1/m', block=False)
        @ratelimit(key='header:x-missing-header', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(_req())
        assert view(_req())

    def test_rate(self):
        @ratelimit(key='ip', rate='2/m', block=False)
        def twice(request):
            return request.limited

        assert not twice(rf.post('/')), 'First request is not limited.'
        assert not twice(rf.post('/')), 'Second request is not limited.'
        assert twice(rf.post('/')), 'Third request is limited.'

    def test_zero_rate(self):
        @ratelimit(key='ip', rate='0/m', block=False)
        def never(request):
            return request.limited

        assert never(rf.post('/'))

    def test_none_rate(self):
        @ratelimit(key='ip', rate=None, block=False)
        def always(request):
            return request.limited

        assert not always(rf.post('/'))
        assert not always(rf.post('/'))
        assert not always(rf.post('/'))
        assert not always(rf.post('/'))
        assert not always(rf.post('/'))
        assert not always(rf.post('/'))
        assert not always(rf.post('/'))

    def test_callable_rate(self):
        def _req(auth):
            req = rf.post('/')
            req.user = MockUser(authenticated=auth)
            return req

        def get_rate(group, request):
            if request.user.is_authenticated:
                return (2, 60)
            return (1, 60)

        @ratelimit(key='user_or_ip', rate=get_rate, block=False)
        def view(request):
            return request.limited

        assert not view(_req(auth=False))
        assert view(_req(auth=False))
        assert not view(_req(auth=True))
        assert not view(_req(auth=True))
        assert view(_req(auth=True))

    def test_callable_rate_none(self):
        def _req(never_limit=False):
            req = rf.post('/')
            req.never_limit = never_limit
            return req

        get_rate = lambda g, r: None if r.never_limit else '1/m'

        @ratelimit(key='ip', rate=get_rate, block=False)
        def view(request):
            return request.limited

        assert not view(_req())
        assert view(_req())
        assert not view(_req(never_limit=True))
        assert not view(_req(never_limit=True))

    def test_callable_rate_zero(self):
        def _req(auth):
            req = rf.post('/')
            req.user = MockUser(authenticated=auth)
            return req

        def get_rate(group, request):
            if request.user.is_authenticated:
                return '1/m'
            return '0/m'

        @ratelimit(key='ip', rate=get_rate, block=False)
        def view(request):
            return request.limited

        assert view(_req(auth=False))
        assert not view(_req(auth=True))
        assert view(_req(auth=True))

    def test_callable_rate_import(self):
        def _req(auth):
            req = rf.post('/')
            req.user = MockUser(authenticated=auth)
            return req

        @ratelimit(key='user_or_ip',
                   rate='django_ratelimit.tests.callable_rate',
                   block=False)
        def view(request):
            return request.limited

        assert view(_req(auth=False))
        assert not view(_req(auth=True))

    def test_user_or_ip(self):
        """Allow custom functions to set cache keys."""

        def _req(auth):
            req = rf.post('/')
            req.user = MockUser(authenticated=auth)
            return req

        @ratelimit(key='user_or_ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(_req(auth=False))
        assert view(_req(auth=False))

        auth = rf.post('/')
        auth.user = MockUser(authenticated=True)

        assert not view(_req(auth=True))
        assert view(_req(auth=True))

    def test_callable_key_path(self):
        @ratelimit(key='django_ratelimit.tests.mykey', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.post('/'))
        assert view(rf.post('/'))

    def test_callable_key(self):
        @ratelimit(key=mykey, rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.post('/'))
        assert view(rf.post('/'))

    def test_stacked_decorator(self):
        """Allow @ratelimit to be stacked."""
        # Put the shorter one first and make sure the second one doesn't
        # reset request.limited back to False.
        @ratelimit(rate='1/m', block=False, key=lambda x, y: 'min')
        @ratelimit(rate='10/d', block=False, key=lambda x, y: 'day')
        def view(request):
            return request.limited

        assert not view(rf.post('/'))
        assert view(rf.post('/'))

    def test_stacked_methods(self):
        """Different methods should result in different counts."""
        @ratelimit(rate='1/m', key='ip', method='GET', block=False)
        @ratelimit(rate='1/m', key='ip', method='POST', block=False)
        def view(request):
            return request.limited

        assert not view(rf.get('/'))
        assert not view(rf.post('/'))
        assert view(rf.get('/'))
        assert view(rf.post('/'))

    def test_sorted_methods(self):
        """Order of the methods shouldn't matter."""
        @ratelimit(rate='1/m', key='ip', method=['GET', 'POST'],
                   group='a', block=False)
        def get_post(request):
            return request.limited

        @ratelimit(rate='1/m', key='ip', method=['POST', 'GET'],
                   group='a', block=False)
        def post_get(request):
            return request.limited

        assert not get_post(rf.get('/'))
        assert post_get(rf.get('/'))

    def test_ratelimit_full_mask_v4(self):
        @ratelimit(rate='1/m', key='ip', block=False)
        def view(request):
            return request.limited

        with self.settings(RATELIMIT_IPV4_MASK=32):
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '10.1.1.1'
            assert not view(req)
            assert view(req)

            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '10.1.1.2'
            assert not view(req)

    def test_ratelimit_full_mask_v6(self):
        @ratelimit(rate='1/m', key='ip', block=False)
        def view(request):
            return request.limited

        with self.settings(RATELIMIT_IPV6_MASK=128):
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db8::1000'
            assert not view(req)
            assert view(req)

            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db8::1001'
            assert not view(req)

    def test_ratelimit_mask_v4(self):
        @ratelimit(rate='1/m', key='ip', block=False)
        def view(request):
            return request.limited

        with self.settings(RATELIMIT_IPV4_MASK=16):
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '10.1.1.1'
            assert not view(req)
            assert view(req)

            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '10.1.0.1'
            assert view(req)

            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '192.168.1.1'
            assert not view(req)

    def test_ratelimit_mask_v6(self):
        @ratelimit(rate='1/m', key='ip', block=False)
        def view(request):
            return request.limited

        with self.settings(RATELIMIT_IPV6_MASK=64):
            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db8::1000'
            assert not view(req)
            assert view(req)

            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db8::1001'
            assert view(req)

            req = rf.get('/')
            req.META['REMOTE_ADDR'] = '2001:db9::1000'
            assert not view(req)


class FunctionsTests(TestCase):
    def setUp(self):
        cache.clear()

    def test_is_ratelimited(self):
        not_increment = partial(is_ratelimited, increment=False, rate='1/m',
                                method=is_ratelimited.ALL, key='ip', group='a')

        # Does not increment. Count still 0. Does not rate limit
        # because 0 < 1.
        assert not not_increment(rf.get('/'))

        # Does not increment. Count still 1. Not limited because 1 > 1
        # is false.
        assert not not_increment(rf.get('/'))

    def test_is_ratelimited_increment(self):
        do_increment = partial(is_ratelimited, increment=True, rate='1/m',
                               method=is_ratelimited.ALL, key='ip', group='a')

        # Increments. Does not rate limit because 0 < 1. Count now 1.
        assert not do_increment(rf.get('/'))

        # Count = 2, 2 > 1.
        assert do_increment(rf.get('/'))

    def test_get_usage(self):
        _get_usage = partial(get_usage, method=get_usage.ALL, key='ip',
                             rate='1/m', group='a')
        usage = _get_usage(rf.get('/'))

        self.assertEqual(usage['count'], 0)
        self.assertEqual(usage['limit'], 1)
        self.assertLessEqual(usage['time_left'], 60)
        self.assertFalse(usage['should_limit'])

    def test_get_usage_increment(self):
        _get_usage = partial(get_usage, method=get_usage.ALL, key='ip',
                             rate='1/m', group='a', increment=True)
        _get_usage(rf.get('/'))
        usage = _get_usage(rf.get('/'))

        self.assertEqual(usage['count'], 2)
        self.assertEqual(usage['limit'], 1)
        self.assertLessEqual(usage['time_left'], 60)
        self.assertTrue(usage['should_limit'])

    def test_not_increment_after_increment(self):
        _get_usage = partial(get_usage, method=get_usage.ALL, key='ip',
                             rate='1/m', group='a')
        _get_usage(rf.get('/'), increment=True)
        _get_usage(rf.get('/'), increment=True)
        usage = _get_usage(rf.get('/'))

        self.assertEqual(usage['count'], 2)
        self.assertEqual(usage['limit'], 1)
        self.assertLessEqual(usage['time_left'], 60)
        self.assertTrue(usage['should_limit'])

    def test_get_usage_called_without_group_or_fn(self):
        with self.assertRaises(ImproperlyConfigured):
            get_usage(rf.get('/'), key='ip')


class RatelimitCBVTests(TestCase):
    def setUp(self):
        cache.clear()

    def test_method_decorator(self):
        class TestView(View):
            @method_decorator(ratelimit(key='ip', rate='1/m', block=False))
            def post(self, request):
                return request.limited

        view = TestView.as_view()

        assert not view(rf.post('/'))
        assert view(rf.post('/'))

    def test_class_decorator(self):
        @method_decorator(ratelimit(key='ip', rate='1/m', block=False),
                          name='get')
        class TestView(View):
            def get(self, request):
                return request.limited

        view = TestView.as_view()

        assert not view(rf.get('/'))
        assert view(rf.get('/'))

    def test_wrap_view(self):
        class TestView(View):
            def get(self, request):
                return request.limited

        view = TestView.as_view()
        wrapped = ratelimit(key='ip', rate='1/m', block=False)(view)

        assert not wrapped(rf.get('/'))
        assert wrapped(rf.get('/'))

    def test_methods_counted_separately(self):
        class TestView(View):
            @method_decorator(ratelimit(key='ip', rate='1/m',
                                        method='GET', block=False))
            def get(self, request):
                return request.limited

            @method_decorator(ratelimit(key='ip', rate='1/m',
                                        method='POST', block=False))
            def post(self, request):
                return request.limited

        view = TestView.as_view()

        assert not view(rf.get('/'))
        assert view(rf.get('/'))
        assert not view(rf.post('/'))

    def test_views_counted_separately(self):
        class TestView(View):
            @method_decorator(ratelimit(key='ip', rate='1/m',
                                        method='GET', block=False))
            def get(self, request):
                return request.limited

        class AnotherTestView(View):
            @method_decorator(ratelimit(key='ip', rate='1/m',
                                        method='GET', block=False))
            def get(self, request):
                return request.limited

        test_view = TestView.as_view()
        another_view = AnotherTestView.as_view()

        assert not test_view(rf.get('/'))
        assert test_view(rf.get('/'))
        assert not another_view(rf.get('/'))


class CacheFailTests(TestCase):
    @override_settings(RATELIMIT_USE_CACHE='fake-cache')
    def test_bad_cache(self):
        @ratelimit(key='ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        with self.assertRaises(InvalidCacheBackendError):
            view(rf.post('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_limit_on_cache_connection_error(self):
        @ratelimit(key='ip', rate='10/m', block=False)
        def view(request):
            return request.limited

        assert view(rf.post('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors',
                       RATELIMIT_FAIL_OPEN=True)
    def test_fail_open_setting(self):
        @ratelimit(key='ip', rate='1/m', block=False)
        def view(request):
            return request.limited

        assert not view(rf.get('/'))
        assert not view(rf.get('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_is_ratelimited_cache_connection_error_without_increment(self):
        def not_increment(request):
            return is_ratelimited(request, increment=False,
                                  method=is_ratelimited.ALL, key='ip',
                                  rate='1/m', group='a')

        assert not not_increment(rf.get('/'))
        assert not not_increment(rf.get('/'))

    @override_settings(RATELIMIT_USE_CACHE='connection-errors')
    def test_is_ratelimited_cache_connection_error_with_increment(self):
        def do_increment(request):
            return is_ratelimited(request, increment=True,
                                  method=is_ratelimited.ALL, key='ip',
                                  rate='1/m', group='a')

        assert do_increment(rf.get('/'))
        assert do_increment(rf.get('/'))

    @skipIf(not django_redis, "django_redis is not installed")
    @override_settings(RATELIMIT_USE_CACHE='connection-errors-redis')
    def test_is_ratelimited_cache_connection_error_with_increment_redis(self):
        def do_increment(request):
            return is_ratelimited(request, increment=True,
                                  method=is_ratelimited.ALL, key='ip',
                                  rate='1/m', group='a')

        assert do_increment(rf.get('/'))
        assert do_increment(rf.get('/'))

    @override_settings(RATELIMIT_USE_CACHE='instant-expiration')
    def test_cache_timeout(self):
        @ratelimit(key='ip', rate='1/m')
        def view(request):
            return True

        assert view(rf.get('/'))
        assert view(rf.get('/'))


def my_ip(req):
    return req.META['MY_THING']


class IpMetaTests(TestCase):
    def test_default(self):
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'

        assert '1.2.3.4' == _get_ip(req)

    @override_settings(RATELIMIT_IP_META_KEY='fake')
    def test_bad_config(self):
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'

        with self.assertRaises(ImproperlyConfigured):
            _get_ip(req)

    @override_settings(RATELIMIT_IP_META_KEY='HTTP_X_CLIENT_IP')
    def test_alternate_header(self):
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'
        req.META['HTTP_X_CLIENT_IP'] = '5.6.7.8'

        assert '5.6.7.8' == _get_ip(req)

    @override_settings(RATELIMIT_IP_META_KEY='django_ratelimit.tests.my_ip')
    def test_path_to_ip_key_callable(self):
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'
        req.META['MY_THING'] = '5.6.7.8'

        assert '5.6.7.8' == _get_ip(req)

    @override_settings(RATELIMIT_IP_META_KEY=my_ip)
    def test_callable_ip_key(self):
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = '1.2.3.4'
        req.META['MY_THING'] = '5.6.7.8'

        assert '5.6.7.8' == _get_ip(req)

    def test_empty_ip(self):
        req = rf.get('/')
        req.META['REMOTE_ADDR'] = ''

        with self.assertRaises(ImproperlyConfigured):
            _get_ip(req)
