import pytest
import sqlalchemy as sa

from sqlalchemy_utils import merge_references


class TestMergeReferences(object):
    """merge_references() over a plain many-to-one foreign key.

    ``BlogPost.author_id`` references ``user.id``; merging one user into
    another must repoint every post that referenced the source user.
    """

    @pytest.fixture
    def User(self, Base):
        """Declarative ``user`` model with an integer id and a unicode name."""
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            def __repr__(self):
                return 'User(%r)' % self.name
        return User

    @pytest.fixture
    def BlogPost(self, Base, User):
        """``blog_post`` model whose ``author_id`` references ``user.id``."""
        class BlogPost(Base):
            __tablename__ = 'blog_post'
            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.Unicode(255))
            content = sa.Column(sa.UnicodeText)
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))

            author = sa.orm.relationship(User)
        return BlogPost

    @pytest.fixture
    def init_models(self, User, BlogPost):
        # Requesting the model fixtures is enough to declare them on Base.
        pass

    def test_updates_foreign_keys(self, session, User, BlogPost):
        """After merge + commit, every post is authored by the target user."""
        old_user = User(name=u'John')
        new_user = User(name=u'Jack')
        old_users_post = BlogPost(title=u'Some title', author=old_user)
        new_users_post = BlogPost(title=u'Other title', author=new_user)
        session.add_all([old_user, new_user, old_users_post, new_users_post])
        session.commit()

        merge_references(old_user, new_user)
        session.commit()

        for blog_post in (old_users_post, new_users_post):
            assert blog_post.author == new_user

    def test_object_merging_whenever_possible(self, session, User, BlogPost):
        """The merged FK value is visible on in-session posts without a commit."""
        old_user = User(name=u'John')
        new_user = User(name=u'Jack')
        old_users_post = BlogPost(title=u'Some title', author=old_user)
        new_users_post = BlogPost(title=u'Other title', author=new_user)
        session.add_all([old_user, new_user, old_users_post, new_users_post])
        session.commit()

        # Read the FK before merging so the post is loaded in the session and
        # is known to start out pointing at the source user.
        assert old_users_post.author_id == old_user.id

        merge_references(old_user, new_user)

        # Deliberately no commit: the new value must already be reflected on
        # the loaded instances.
        assert old_users_post.author_id == new_user.id
        assert new_users_post.author_id == new_user.id


class TestMergeReferencesWithManyToManyAssociations(object):
    """merge_references() across an unmapped many-to-many secondary table.

    ``team_member`` is a plain ``sa.Table`` (no mapped class) linking
    ``user`` and ``team``; merging users must move the association rows.
    """

    @pytest.fixture
    def User(self, Base):
        """Declarative ``user`` model with an integer id and a unicode name."""
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            def __repr__(self):
                return 'User(%r)' % self.name
        return User

    @pytest.fixture
    def Team(self, Base):
        """``team`` model related to ``User`` via the ``team_member`` table.

        The relationship exposes ``Team.members`` and, through the backref,
        ``User.teams``.
        """
        team_member = sa.Table(
            'team_member',
            Base.metadata,
            sa.Column(
                'user_id',
                sa.Integer,
                sa.ForeignKey('user.id', ondelete='CASCADE'),
                primary_key=True
            ),
            sa.Column(
                'team_id',
                sa.Integer,
                sa.ForeignKey('team.id', ondelete='CASCADE'),
                primary_key=True
            )
        )

        class Team(Base):
            __tablename__ = 'team'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            members = sa.orm.relationship(
                'User',
                backref='teams',
                secondary=team_member
            )
        return Team

    @pytest.fixture
    def init_models(self, User, Team):
        # Requesting the model fixtures is enough to declare them on Base.
        pass

    def test_supports_associations(self, session, User, Team):
        """The source user's team membership is transferred to the target."""
        old_user = User(name=u'John')
        new_user = User(name=u'Jack')
        team = Team(name=u'Team')
        team.members.append(old_user)
        # ``team`` is never added explicitly; presumably it reaches the
        # session via the save-update cascade along the ``teams`` backref.
        session.add_all([old_user, new_user])
        session.commit()

        merge_references(old_user, new_user)

        members = team.members
        assert old_user not in members
        assert new_user in members


class TestMergeReferencesWithManyToManyAssociationObjects(object):
    """merge_references() when the many-to-many link is a mapped class.

    ``TeamMember`` is an association object (with an extra ``role``
    column); merging users must repoint its ``user_id``.
    """

    @pytest.fixture
    def Team(self, Base):
        """``team`` model with an autoincrementing id and a unicode name."""
        class Team(Base):
            __tablename__ = 'team'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
        return Team

    @pytest.fixture
    def User(self, Base):
        """``user`` model with an autoincrementing id and a unicode name."""
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
        return User

    @pytest.fixture
    def TeamMember(self, Base, User, Team):
        """Mapped ``team_member`` association between ``User`` and ``Team``.

        Adds ``Team.members`` and ``User.memberships`` backrefs, both with
        ``all, delete-orphan`` cascade.
        """
        # Backref specs are built up front (outside the class body, so they
        # do not become attributes of the declarative class).
        members_backref = sa.orm.backref(
            'members',
            cascade='all, delete-orphan'
        )
        memberships_backref = sa.orm.backref(
            'memberships',
            cascade='all, delete-orphan'
        )

        class TeamMember(Base):
            __tablename__ = 'team_member'
            user_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(User.id, ondelete='CASCADE'),
                primary_key=True
            )
            team_id = sa.Column(
                sa.Integer,
                sa.ForeignKey(Team.id, ondelete='CASCADE'),
                primary_key=True
            )
            role = sa.Column(sa.Unicode(255))
            team = sa.orm.relationship(
                Team,
                primaryjoin=team_id == Team.id,
                backref=members_backref,
            )
            user = sa.orm.relationship(
                User,
                primaryjoin=user_id == User.id,
                backref=memberships_backref,
            )
        return TeamMember

    @pytest.fixture
    def init_models(self, User, Team, TeamMember):
        # Requesting the model fixtures is enough to declare them on Base.
        pass

    def test_supports_associations(self, session, User, Team, TeamMember):
        """The association object is repointed from source to target user."""
        old_user = User(name=u'John')
        new_user = User(name=u'Jack')
        team = Team(name=u'Team')
        team.members.append(TeamMember(user=old_user))
        session.add_all([old_user, new_user, team])
        session.commit()

        merge_references(old_user, new_user)
        session.commit()

        member_users = [membership.user for membership in team.members]
        assert old_user not in member_users
        assert new_user in member_users
