import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy import util
from sqlalchemy.ext import instrumentation
from sqlalchemy.orm import attributes
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import events
from sqlalchemy.orm.attributes import del_attribute
from sqlalchemy.orm.attributes import get_attribute
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.orm.instrumentation import manager_of_class
from sqlalchemy.orm.instrumentation import register_class
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import ne_
from sqlalchemy.testing.util import decorator


@decorator
def modifies_instrumentation_finders(fn, *args, **kw):
    pristine = instrumentation.instrumentation_finders[:]
    try:
        fn(*args, **kw)
    finally:
        del instrumentation.instrumentation_finders[:]
        instrumentation.instrumentation_finders.extend(pristine)


class _ExtBase(object):
    @classmethod
    def teardown_class(cls):
        instrumentation._reinstall_default_lookups()


class MyTypesManager(instrumentation.InstrumentationManager):
    def instrument_attribute(self, class_, key, attr):
        pass

    def install_descriptor(self, class_, key, attr):
        pass

    def uninstall_descriptor(self, class_, key):
        pass

    def instrument_collection_class(self, class_, key, collection_class):
        return MyListLike

    def get_instance_dict(self, class_, instance):
        return instance._goofy_dict

    def initialize_instance_dict(self, class_, instance):
        instance.__dict__["_goofy_dict"] = {}

    def install_state(self, class_, instance, state):
        instance.__dict__["_my_state"] = state

    def state_getter(self, class_):
        return lambda instance: instance.__dict__["_my_state"]


class MyListLike(list):
    # add @appender, @remover decorators as needed
    _sa_iterator = list.__iter__
    _sa_linker = None
    _sa_converter = None

    def _sa_appender(self, item, _sa_initiator=None):
        if _sa_initiator is not False:
            self._sa_adapter.fire_append_event(item, _sa_initiator)
        list.append(self, item)

    append = _sa_appender

    def _sa_remover(self, item, _sa_initiator=None):
        self._sa_adapter.fire_pre_remove_event(_sa_initiator)
        if _sa_initiator is not False:
            self._sa_adapter.fire_remove_event(item, _sa_initiator)
        list.remove(self, item)

    remove = _sa_remover


MyBaseClass, MyClass = None, None


class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
    @classmethod
    def setup_class(cls):
        global MyBaseClass, MyClass

        class MyBaseClass(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager
            )

        class MyClass(object):

            # This proves that a staticmethod will work here; don't
            # flatten this back to a class assignment!
            def __sa_instrumentation_manager__(cls):
                return MyTypesManager(cls)

            __sa_instrumentation_manager__ = staticmethod(
                __sa_instrumentation_manager__
            )

            # This proves SA can handle a class with non-string dict keys
            if not util.pypy and not util.jython:
                locals()[42] = 99  # Don't remove this line!

            def __init__(self, **kwargs):
                for k in kwargs:
                    setattr(self, k, kwargs[k])

            def __getattr__(self, key):
                if is_instrumented(self, key):
                    return get_attribute(self, key)
                else:
                    try:
                        return self._goofy_dict[key]
                    except KeyError:
                        raise AttributeError(key)

            def __setattr__(self, key, value):
                if is_instrumented(self, key):
                    set_attribute(self, key, value)
                else:
                    self._goofy_dict[key] = value

            def __hasattr__(self, key):
                if is_instrumented(self, key):
                    return True
                else:
                    return key in self._goofy_dict

            def __delattr__(self, key):
                if is_instrumented(self, key):
                    del_attribute(self, key)
                else:
                    del self._goofy_dict[key]

    def teardown(self):
        clear_mappers()

    def test_instance_dict(self):
        class User(MyClass):
            pass

        register_class(User)
        attributes.register_attribute(
            User, "user_id", uselist=False, useobject=False
        )
        attributes.register_attribute(
            User, "user_name", uselist=False, useobject=False
        )
        attributes.register_attribute(
            User, "email_address", uselist=False, useobject=False
        )

        u = User()
        u.user_id = 7
        u.user_name = "john"
        u.email_address = "lala@123.com"
        eq_(
            u.__dict__,
            {
                "_my_state": u._my_state,
                "_goofy_dict": {
                    "user_id": 7,
                    "user_name": "john",
                    "email_address": "lala@123.com",
                },
            },
        )

    def test_basic(self):
        for base in (object, MyBaseClass, MyClass):

            class User(base):
                pass

            register_class(User)
            attributes.register_attribute(
                User, "user_id", uselist=False, useobject=False
            )
            attributes.register_attribute(
                User, "user_name", uselist=False, useobject=False
            )
            attributes.register_attribute(
                User, "email_address", uselist=False, useobject=False
            )

            u = User()
            u.user_id = 7
            u.user_name = "john"
            u.email_address = "lala@123.com"

            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")
            attributes.instance_state(u)._commit_all(
                attributes.instance_dict(u)
            )
            eq_(u.user_id, 7)
            eq_(u.user_name, "john")
            eq_(u.email_address, "lala@123.com")

            u.user_name = "heythere"
            u.email_address = "foo@bar.com"
            eq_(u.user_id, 7)
            eq_(u.user_name, "heythere")
            eq_(u.email_address, "foo@bar.com")

    def test_deferred(self):
        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            data = {"a": "this is a", "b": 12}

            def loader(state, keys):
                for k in keys:
                    state.dict[k] = data[k]
                return attributes.ATTR_WAS_SET

            manager = register_class(Foo)
            manager.deferred_scalar_loader = loader
            attributes.register_attribute(
                Foo, "a", uselist=False, useobject=False
            )
            attributes.register_attribute(
                Foo, "b", uselist=False, useobject=False
            )

            if base is object:
                assert Foo not in (
                    instrumentation._instrumentation_factory._state_finders
                )
            else:
                assert Foo in (
                    instrumentation._instrumentation_factory._state_finders
                )

            f = Foo()
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            f.a = "this is some new a"
            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            f.a = "this is another new a"
            eq_(f.a, "this is another new a")
            eq_(f.b, 12)

            attributes.instance_state(f)._expire(
                attributes.instance_dict(f), set()
            )
            eq_(f.a, "this is a")
            eq_(f.b, 12)

            del f.a
            eq_(f.a, None)
            eq_(f.b, 12)

            attributes.instance_state(f)._commit_all(
                attributes.instance_dict(f)
            )
            eq_(f.a, None)
            eq_(f.b, 12)

    def test_inheritance(self):
        """tests that attributes are polymorphic"""

        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            class Bar(Foo):
                pass

            register_class(Foo)
            register_class(Bar)

            def func1(state, passive):
                return "this is the foo attr"

            def func2(state, passive):
                return "this is the bar attr"

            def func3(state, passive):
                return "this is the shared attr"

            attributes.register_attribute(
                Foo, "element", uselist=False, callable_=func1, useobject=True
            )
            attributes.register_attribute(
                Foo, "element2", uselist=False, callable_=func3, useobject=True
            )
            attributes.register_attribute(
                Bar, "element", uselist=False, callable_=func2, useobject=True
            )

            x = Foo()
            y = Bar()
            assert x.element == "this is the foo attr"
            assert y.element == "this is the bar attr", y.element
            assert x.element2 == "this is the shared attr"
            assert y.element2 == "this is the shared attr"

    def test_collection_with_backref(self):
        for base in (object, MyBaseClass, MyClass):

            class Post(base):
                pass

            class Blog(base):
                pass

            register_class(Post)
            register_class(Blog)
            attributes.register_attribute(
                Post,
                "blog",
                uselist=False,
                backref="posts",
                trackparent=True,
                useobject=True,
            )
            attributes.register_attribute(
                Blog,
                "posts",
                uselist=True,
                backref="blog",
                trackparent=True,
                useobject=True,
            )
            b = Blog()
            (p1, p2, p3) = (Post(), Post(), Post())
            b.posts.append(p1)
            b.posts.append(p2)
            b.posts.append(p3)
            self.assert_(b.posts == [p1, p2, p3])
            self.assert_(p2.blog is b)

            p3.blog = None
            self.assert_(b.posts == [p1, p2])
            p4 = Post()
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            p4.blog = b
            p4.blog = b
            self.assert_(b.posts == [p1, p2, p4])

            # assert no failure removing None
            p5 = Post()
            p5.blog = None
            del p5.blog

    def test_history(self):
        for base in (object, MyBaseClass, MyClass):

            class Foo(base):
                pass

            class Bar(base):
                pass

            register_class(Foo)
            register_class(Bar)
            attributes.register_attribute(
                Foo, "name", uselist=False, useobject=False
            )
            attributes.register_attribute(
                Foo, "bars", uselist=True, trackparent=True, useobject=True
            )
            attributes.register_attribute(
                Bar, "name", uselist=False, useobject=False
            )

            f1 = Foo()
            f1.name = "f1"

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                (["f1"], (), ()),
            )

            b1 = Bar()
            b1.name = "b1"
            f1.bars.append(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b1], [], []),
            )

            attributes.instance_state(f1)._commit_all(
                attributes.instance_dict(f1)
            )
            attributes.instance_state(b1)._commit_all(
                attributes.instance_dict(b1)
            )

            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                ((), ["f1"], ()),
            )
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ((), [b1], ()),
            )

            f1.name = "f1mod"
            b2 = Bar()
            b2.name = "b2"
            f1.bars.append(b2)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "name"
                ),
                (["f1mod"], (), ["f1"]),
            )
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b2], [b1], []),
            )
            f1.bars.remove(b1)
            eq_(
                attributes.get_state_history(
                    attributes.instance_state(f1), "bars"
                ),
                ([b2], [], [b1]),
            )

    def test_null_instrumentation(self):
        class Foo(MyBaseClass):
            pass

        register_class(Foo)
        attributes.register_attribute(
            Foo, "name", uselist=False, useobject=False
        )
        attributes.register_attribute(
            Foo, "bars", uselist=True, trackparent=True, useobject=True
        )

        assert Foo.name == attributes.manager_of_class(Foo)["name"]
        assert Foo.bars == attributes.manager_of_class(Foo)["bars"]

    def test_alternate_finders(self):
        """Ensure the generic finder front-end deals with edge cases."""

        class Unknown(object):
            pass

        class Known(MyBaseClass):
            pass

        register_class(Known)
        k, u = Known(), Unknown()

        assert instrumentation.manager_of_class(Unknown) is None
        assert instrumentation.manager_of_class(Known) is not None
        assert instrumentation.manager_of_class(None) is None

        assert attributes.instance_state(k) is not None
        assert_raises((AttributeError, KeyError), attributes.instance_state, u)
        assert_raises(
            (AttributeError, KeyError), attributes.instance_state, None
        )

    def test_unmapped_not_type_error(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            "Class object expected, got '5'.",
            class_mapper,
            5,
        )

    def test_unmapped_not_type_error_iter_ok(self):
        """extension version of the same test in test_mapper.

        fixes #3408
        """
        assert_raises_message(
            sa.exc.ArgumentError,
            r"Class object expected, got '\(5, 6\)'.",
            class_mapper,
            (5, 6),
        )


class FinderTest(_ExtBase, fixtures.ORMTest):
    def test_standard(self):
        class A(object):
            pass

        register_class(A)

        eq_(type(manager_of_class(A)), instrumentation.ClassManager)

    def test_nativeext_interfaceexact(self):
        class A(object):
            __sa_instrumentation_manager__ = (
                instrumentation.InstrumentationManager
            )

        register_class(A)
        ne_(type(manager_of_class(A)), instrumentation.ClassManager)

    def test_nativeext_submanager(self):
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            __sa_instrumentation_manager__ = Mine

        register_class(A)
        eq_(type(manager_of_class(A)), Mine)

    @modifies_instrumentation_finders
    def test_customfinder_greedy(self):
        class Mine(instrumentation.ClassManager):
            pass

        class A(object):
            pass

        def find(cls):
            return Mine

        instrumentation.instrumentation_finders.insert(0, find)
        register_class(A)
        eq_(type(manager_of_class(A)), Mine)

    @modifies_instrumentation_finders
    def test_customfinder_pass(self):
        class A(object):
            pass

        def find(cls):
            return None

        instrumentation.instrumentation_finders.insert(0, find)
        register_class(A)

        eq_(type(manager_of_class(A)), instrumentation.ClassManager)


class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
    def test_none(self):
        class A(object):
            pass

        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(object):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        register_class(B)

        class C(object):
            __sa_instrumentation_manager__ = instrumentation.ClassManager

        register_class(C)

    def test_single_down(self):
        class A(object):
            pass

        register_class(A)

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B,
        )

    def test_single_up(self):
        class A(object):
            pass

        # delay registration

        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class B(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        register_class(B)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            A,
        )

    def test_diamond_b1(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )

    def test_diamond_b2(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        register_class(B2)
        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )

    def test_diamond_c_b(self):
        def mgr_factory(cls):
            return instrumentation.ClassManager(cls)

        class A(object):
            pass

        class B1(A):
            pass

        class B2(A):
            __sa_instrumentation_manager__ = staticmethod(mgr_factory)

        class C(object):
            pass

        register_class(C)

        assert_raises_message(
            TypeError,
            "multiple instrumentation implementations",
            register_class,
            B1,
        )


class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):

    """Allow custom Events implementations."""

    @modifies_instrumentation_finders
    def test_subclassed(self):
        class MyEvents(events.InstanceEvents):
            pass

        class MyClassManager(instrumentation.ClassManager):
            dispatch = event.dispatcher(MyEvents)

        instrumentation.instrumentation_finders.insert(
            0, lambda cls: MyClassManager
        )

        class A(object):
            pass

        register_class(A)
        manager = instrumentation.manager_of_class(A)
        assert issubclass(manager.dispatch._events, MyEvents)
