# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from packaging.version import Version

from cassandra import InvalidRequest
from cassandra.cqlengine.management import sync_table, drop_table
from tests.integration.cqlengine.base import BaseCassEngTestCase
from cassandra.cqlengine.models import Model
from uuid import uuid4
from cassandra.cqlengine import columns
from unittest import mock
from cassandra.cqlengine.connection import get_session
from tests.integration import CASSANDRA_VERSION, greaterthancass20


class TestTTLModel(Model):
    # Model declared without ``__options__``, so its table is synced with no
    # table-level default_time_to_live; TTLs in these tests are applied per
    # query via ``.ttl(...)``.
    id = columns.UUID(primary_key=True, default=uuid4)  # fresh random UUID per row
    count = columns.Integer()
    text = columns.Text(required=False)


class BaseTTLTest(BaseCassEngTestCase):
    """
    Base class that creates the TestTTLModel table once per test class
    and drops it again when the class is finished.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        sync_table(TestTTLModel)

    @classmethod
    def tearDownClass(cls):
        # Parent teardown runs first, then the table is dropped (original order).
        super().tearDownClass()
        drop_table(TestTTLModel)


class TestDefaultTTLModel(Model):
    # Table option: rows expire after 20 seconds unless a query says otherwise.
    # Tests mutate and then restore this dict, so keep it a plain class attribute.
    __options__ = {'default_time_to_live': 20}
    id = columns.UUID(primary_key=True, default=uuid4)  # fresh random UUID per row
    count = columns.Integer()
    text = columns.Text(required=False)


class BaseDefaultTTLTest(BaseCassEngTestCase):
    """
    Base class that creates both the TestDefaultTTLModel and TestTTLModel tables
    for the duration of a test class.

    Setup only runs against Cassandra 2.0 or newer. On older versions the whole
    class is skipped rather than letting its tests run against tables that were
    never created.
    """

    @classmethod
    def setUpClass(cls):
        if CASSANDRA_VERSION < Version('2.0'):
            # Skipping here marks every test in the class as skipped; unittest
            # does not call tearDownClass when setUpClass raises.
            raise unittest.SkipTest("default_time_to_live tests require Cassandra 2.0+")
        super(BaseDefaultTTLTest, cls).setUpClass()
        sync_table(TestDefaultTTLModel)
        sync_table(TestTTLModel)

    @classmethod
    def tearDownClass(cls):
        # Guard kept so a direct call on an old server is still a no-op.
        if CASSANDRA_VERSION >= Version('2.0'):
            super(BaseDefaultTTLTest, cls).tearDownClass()
            drop_table(TestDefaultTTLModel)
            drop_table(TestTTLModel)


class TTLQueryTests(BaseTTLTest):
    """
    Placeholders for queryset-level TTL tests.

    Neither test has a body yet. They are explicitly skipped so the report does
    not show them as passing while they assert nothing.
    """

    @unittest.skip("not implemented: test body makes no assertions")
    def test_update_queryset_ttl_success_case(self):
        """ tests that ttls on querysets work as expected """

    @unittest.skip("not implemented: test body makes no assertions")
    def test_select_ttl_failure(self):
        """ tests that ttls on select queries raise an exception """


class TTLModelTests(BaseTTLTest):
    """Tests for calling ``ttl()`` on the model class itself."""

    def test_ttl_included_on_create(self):
        """ tests that ttls on models work as expected """
        session = get_session()

        # Intercept the statement instead of sending it, then inspect its CQL.
        with mock.patch.object(session, 'execute') as m:
            TestTTLModel.ttl(60).create(text="hello blake")

        query = m.call_args[0][0].query_string
        self.assertIn("USING TTL", query)

    def test_queryset_is_returned_on_class(self):
        """
        ensures we get a queryset descriptor back
        """
        qs = TestTTLModel.ttl(60)
        # assertIsInstance reports the actual type on failure by itself.
        self.assertIsInstance(qs, TestTTLModel.__queryset__)


class TTLInstanceUpdateTest(BaseTTLTest):
    """Tests for ``instance.ttl(...).update(...)``."""

    def test_update_includes_ttl(self):
        session = get_session()

        row = TestTTLModel.create(text="goodbye blake")
        # Capture the UPDATE statement rather than executing it.
        with mock.patch.object(session, 'execute') as fake_execute:
            row.ttl(60).update(text="goodbye forever")

        statement = fake_execute.call_args[0][0]
        self.assertIn("USING TTL", statement.query_string)

    def test_update_syntax_valid(self):
        # Runs the real UPDATE against the server: passes as long as
        # Cassandra accepts the generated TTL syntax without raising.
        row = TestTTLModel.create(text="goodbye blake")
        row.ttl(60).update(text="goodbye forever")


class TTLInstanceTest(BaseTTLTest):
    """Tests for the ``instance.ttl(n).save()`` flow."""

    def test_instance_is_returned(self):
        """
        ensures that we properly handle the instance.ttl(60).save() scenario
        """
        record = TestTTLModel.create(text="whatever")
        record.text = "new stuff"
        # ttl() must hand back an object that carries the requested TTL.
        record = record.ttl(60)
        self.assertEqual(60, record._ttl)

    def test_ttl_is_include_with_query_on_update(self):
        session = get_session()

        record = TestTTLModel.create(text="whatever")
        record.text = "new stuff"
        record = record.ttl(60)

        # Capture the statement produced by save() instead of executing it.
        with mock.patch.object(session, 'execute') as fake_execute:
            record.save()

        statement = fake_execute.call_args[0][0]
        self.assertIn("USING TTL", statement.query_string)


class TTLBlindUpdateTest(BaseTTLTest):
    """Tests for a queryset update that targets a row by primary key only."""

    def test_ttl_included_with_blind_update(self):
        session = get_session()

        row_id = TestTTLModel.create(text="whatever").id

        # Capture the UPDATE generated by the queryset instead of running it.
        with mock.patch.object(session, 'execute') as fake_execute:
            TestTTLModel.objects(id=row_id).ttl(60).update(text="bacon")

        statement = fake_execute.call_args[0][0]
        self.assertIn("USING TTL", statement.query_string)


class TTLDefaultTest(BaseDefaultTTLTest):
    """Tests for tables declared with a ``default_time_to_live`` table option."""

    def get_default_ttl(self, table_name, keyspace='cqlengine_test'):
        """
        Read a table's ``default_time_to_live`` option from the schema tables.

        Queries ``system_schema.tables`` first. If the server rejects that query
        with ``InvalidRequest``, falls back to ``system.schema_columnfamilies``.

        :param table_name: name of the table to look up.
        :param keyspace: keyspace containing the table (``cqlengine_test`` by default).
        :return: the ``default_time_to_live`` value of the first returned row.
        """
        session = get_session()
        # Bind values through the driver instead of formatting them into the CQL.
        params = (keyspace, table_name)
        try:
            rows = session.execute("SELECT default_time_to_live FROM system_schema.tables "
                                   "WHERE keyspace_name = %s AND table_name = %s", params)
        except InvalidRequest:
            rows = session.execute("SELECT default_time_to_live FROM system.schema_columnfamilies "
                                   "WHERE keyspace_name = %s AND columnfamily_name = %s", params)
        return rows[0]['default_time_to_live']

    def test_default_ttl_not_set(self):
        session = get_session()

        o = TestTTLModel.create(text="some text")
        tid = o.id

        self.assertIsNone(o._ttl)

        default_ttl = self.get_default_ttl('test_ttlmodel')
        self.assertEqual(default_ttl, 0)

        with mock.patch.object(session, 'execute') as m:
            TestTTLModel.objects(id=tid).update(text="aligators")

        query = m.call_args[0][0].query_string
        self.assertNotIn("USING TTL", query)

    def test_default_ttl_set(self):
        session = get_session()

        o = TestDefaultTTLModel.create(text="some text on ttl")
        tid = o.id

        # Should not be set, it's handled by Cassandra
        self.assertIsNone(o._ttl)

        default_ttl = self.get_default_ttl('test_default_ttlmodel')
        self.assertEqual(default_ttl, 20)

        # Update the same model the row was created with, so the check covers
        # the default-TTL table rather than TestTTLModel.
        with mock.patch.object(session, 'execute') as m:
            TestDefaultTTLModel.objects(id=tid).update(text="aligators expired")

        # Should not be set either
        query = m.call_args[0][0].query_string
        self.assertNotIn("USING TTL", query)

    def test_default_ttl_modify(self):
        default_ttl = self.get_default_ttl('test_default_ttlmodel')
        self.assertEqual(default_ttl, 20)

        try:
            TestDefaultTTLModel.__options__ = {'default_time_to_live': 10}
            sync_table(TestDefaultTTLModel)

            default_ttl = self.get_default_ttl('test_default_ttlmodel')
            self.assertEqual(default_ttl, 10)
        finally:
            # Restore default TTL even when an assertion above fails; other
            # tests in this class expect the table option to be 20.
            TestDefaultTTLModel.__options__ = {'default_time_to_live': 20}
            sync_table(TestDefaultTTLModel)

    def test_override_default_ttl(self):
        session = get_session()
        o = TestDefaultTTLModel.create(text="some text on ttl")
        tid = o.id

        # ttl() on an instance records the TTL on that instance.
        o.ttl(3600)
        self.assertEqual(o._ttl, 3600)

        # An explicit ttl(None) on the queryset must not emit a TTL clause.
        with mock.patch.object(session, 'execute') as m:
            TestDefaultTTLModel.objects(id=tid).ttl(None).update(text="aligators expired")

        query = m.call_args[0][0].query_string
        self.assertNotIn("USING TTL", query)
