import warnings

import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker

from sqlalchemy_utils import (
    aggregates,
    coercion_listener,
    i18n,
    InstrumentedList
)
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners


@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
    """Count the statements executed on ``conn``.

    Registered as a ``before_cursor_execute`` listener on every engine, so
    it runs once per statement handed to a DBAPI cursor.  The running total
    is stored on the connection object itself as ``conn.query_count``.

    Only ``conn`` is used; the remaining parameters are accepted because the
    event passes them to the listener.
    """
    # Start from 0 when the attribute does not exist yet and always add one,
    # so the very first statement on a connection is counted as well
    # (previously the first statement left the counter at 0).
    conn.query_count = getattr(conn, 'query_count', 0) + 1


# Turn every SQLAlchemy ``SAWarning`` into an exception so that a warning
# raised while this module is in use fails loudly instead of being printed.
warnings.simplefilter('error', sa.exc.SAWarning)


# Register ``coercion_listener`` for the 'mapper_configured' event on
# ``sa.orm.mapper``.
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)


def get_locale():
    """Build and return a minimal stand-in locale object.

    The returned object exposes a single class-level attribute,
    ``territories``, mapping the territory code ``'FI'`` to ``'Finland'``.
    A new ``Locale`` class (and instance) is created on every call.
    """
    class Locale():
        # The only data this stub locale carries.
        territories = {'FI': 'Finland'}

    stub_locale = Locale()
    return stub_locale


class TestCase(object):
    """Base class for tests that need a database, a session and sample models.

    ``setup_method`` builds an engine/connection from ``dns``, a fresh
    declarative base, the models defined in ``create_models`` and a session
    bound to the connection.  ``teardown_method`` undoes all of that.

    Subclasses may override ``dns``, ``create_tables`` and ``create_models``.
    """

    # Database URL passed to ``create_engine``; in-memory SQLite by default.
    dns = 'sqlite:///:memory:'
    # When true, tables are created in setup and dropped again in teardown.
    create_tables = True

    def setup_method(self, method):
        """Create engine, connection, models, tables and session for a test.

        :param method: the test method about to run (unused here).
        """
        self.engine = create_engine(self.dns)
        # self.engine.echo = True
        self.connection = self.engine.connect()
        self.Base = declarative_base()

        self.create_models()
        # Configure mappers eagerly, before any table is created.
        sa.orm.configure_mappers()
        if self.create_tables:
            self.Base.metadata.create_all(self.connection)

        Session = sessionmaker(bind=self.connection)
        self.session = Session()

        # Replace the i18n locale getter with this module's stub.
        i18n.get_locale = get_locale

    def teardown_method(self, method):
        """Release everything ``setup_method`` created.

        The steps run in the same order as before, but each later step is
        now guarded by ``finally``: if resetting aggregates, closing the
        sessions or dropping the tables raises, the composite listeners are
        still removed, the connection is still closed and the engine is
        still disposed.  The original exception propagates unchanged.

        :param method: the test method that just ran (unused here).
        """
        try:
            aggregates.manager.reset()
            # NOTE(review): newer SQLAlchemy releases deprecate
            # ``Session.close_all()`` in favour of ``close_all_sessions()``;
            # confirm against the supported SQLAlchemy versions before
            # switching.
            self.session.close_all()
            if self.create_tables:
                self.Base.metadata.drop_all(self.connection)
        finally:
            try:
                remove_composite_listeners()
            finally:
                try:
                    self.connection.close()
                finally:
                    self.engine.dispose()

    def create_models(self):
        """Define the ``User``, ``Category`` and ``Article`` models.

        The classes are declared against ``self.Base`` and published as
        ``self.User``, ``self.Category`` and ``self.Article``.
        """
        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class Category(self.Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

            # Instance level: number of loaded ``articles`` (the backref
            # declared on ``Article.category`` below).
            @hybrid_property
            def articles_count(self):
                return len(self.articles)

            # Class level: SQL expression counting matching articles.
            # ``self`` inside this function is the enclosing test case
            # (captured from ``create_models``), not a ``Category``; the
            # models are looked up on it lazily, at expression build time.
            @articles_count.expression
            def articles_count(cls):
                return (
                    sa.select([sa.func.count(self.Article.id)])
                    .where(self.Article.category_id == self.Category.id)
                    .correlate(self.Article.__table__)
                    .label('article_count')
                )

            # Plain Python property returning ``name``.
            @property
            def name_alias(self):
                return self.name

            # Property registered as a synonym for the ``name`` column.
            @synonym_for('name')
            @property
            def name_synonym(self):
                return self.name

        class Article(self.Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255), index=True)
            category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))

            # Many-to-one to ``Category``; the reverse ``articles``
            # collection uses ``InstrumentedList``.
            category = sa.orm.relationship(
                Category,
                primaryjoin=category_id == Category.id,
                backref=sa.orm.backref(
                    'articles',
                    collection_class=InstrumentedList
                )
            )

        self.User = User
        self.Category = Category
        self.Article = Article


def assert_contains(clause, query):
    """Assert that ``query`` runs and that its string form contains ``clause``.

    :param clause: text expected to appear in ``str(query)``.
    :param query: object providing ``all()`` and a string representation.
    :raises AssertionError: if ``clause`` is not found in ``str(query)``.
    """
    # Execute first, so a query that cannot run fails before the text check.
    query.all()
    rendered = str(query)
    assert clause in rendered
