import os
import warnings

import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy_utils import (
    aggregates,
    coercion_listener,
    i18n,
    InstrumentedList
)

from sqlalchemy_utils.types.pg_composite import remove_composite_listeners


@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
    """Maintain a running ``query_count`` attribute on each connection.

    Registered for the ``before_cursor_execute`` event on every ``Engine``,
    so it fires once per cursor execution.
    """
    # A connection with no tally yet starts from -1 + 1 == 0, so the first
    # statement is recorded as 0 and every later one adds 1.
    # NOTE(review): the tally is therefore one lower than the number of
    # statements executed; preserved as-is since tests may assert on it.
    conn.query_count = getattr(conn, 'query_count', -1) + 1


# Turn every SQLAlchemy warning raised during the test run into an exception.
warnings.simplefilter('error', category=sa.exc.SAWarning)
# Hook coercion_listener into the 'mapper_configured' event for mappers.
sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)


def get_locale():
    """Build a tiny stand-in locale object.

    The returned instance only carries a ``territories`` mapping with a
    single entry (``'FI' -> 'Finland'``); a new class is created per call.
    """
    class Locale:
        territories = dict(FI='Finland')

    return Locale()


@pytest.fixture(scope='session')
def db_name():
    """Test database name; override with SQLALCHEMY_UTILS_TEST_DB."""
    fallback = 'sqlalchemy_utils_test'
    return os.getenv('SQLALCHEMY_UTILS_TEST_DB', fallback)


@pytest.fixture(scope='session')
def postgresql_db_user():
    """PostgreSQL user name; override with SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER."""
    fallback = 'postgres'
    return os.getenv('SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER', fallback)


@pytest.fixture(scope='session')
def mysql_db_user():
    """MySQL user name; override with SQLALCHEMY_UTILS_TEST_MYSQL_USER."""
    fallback = 'root'
    return os.getenv('SQLALCHEMY_UTILS_TEST_MYSQL_USER', fallback)


@pytest.fixture
def postgresql_dsn(postgresql_db_user, db_name):
    """PostgreSQL DSN on localhost for the configured user and database."""
    template = 'postgresql://%s@localhost/%s'
    return template % (postgresql_db_user, db_name)


@pytest.fixture
def mysql_dsn(mysql_db_user, db_name):
    """MySQL DSN (PyMySQL driver) on localhost for the configured user and database."""
    template = 'mysql+pymysql://%s@localhost/%s'
    return template % (mysql_db_user, db_name)


@pytest.fixture
def sqlite_memory_dsn():
    """DSN for an in-memory SQLite database (explicit ``:memory:`` form)."""
    url = 'sqlite:///:memory:'
    return url


@pytest.fixture
def sqlite_none_database_dsn():
    """SQLite DSN that names no database at all."""
    url = 'sqlite://'
    return url


@pytest.fixture
def sqlite_file_dsn(db_name):
    """DSN for a SQLite file named ``<db_name>.db`` (relative path)."""
    return 'sqlite:///%s.db' % (db_name,)


@pytest.fixture(scope='session')
def mssql_db_user():
    """MSSQL login name; override with SQLALCHEMY_UTILS_TEST_MSSQL_USER.

    Session-scoped like the PostgreSQL and MySQL user fixtures, since it only
    reads an environment variable.
    """
    return os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_USER', 'sa')


@pytest.fixture(scope='session')
def mssql_db_password():
    """MSSQL password; override with SQLALCHEMY_UTILS_TEST_MSSQL_PASSWORD.

    Session-scoped like the other credential fixtures, since it only reads an
    environment variable.
    """
    return os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_PASSWORD',
                          'Strong!Passw0rd')


@pytest.fixture(scope='session')
def mssql_db_driver():
    """ODBC driver name for MSSQL, ready to embed in a DSN query string.

    Override with SQLALCHEMY_UTILS_TEST_MSSQL_DRIVER.  Session-scoped like
    the other credential fixtures, since it only reads an environment
    variable.
    """
    driver = os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_DRIVER',
                            'ODBC Driver 17 for SQL Server')
    # Spaces are written as '+' so the name survives query-string parsing.
    return driver.replace(' ', '+')


@pytest.fixture
def mssql_dsn(mssql_db_user, mssql_db_password, mssql_db_driver, db_name):
    """MSSQL DSN via pyodbc on localhost, passing the driver as a query arg."""
    userinfo = '%s:%s' % (mssql_db_user, mssql_db_password)
    return 'mssql+pyodbc://%s@localhost/%s?driver=%s' % (
        userinfo, db_name, mssql_db_driver)


@pytest.fixture
def dsn(request):
    """Resolve the DSN for whichever backend fixture the test requested.

    Backends are checked in a fixed priority order and the first one that
    appears in the test's fixture names wins.  When none of them is
    requested (or only ``sqlite_memory_dsn`` is), the in-memory SQLite DSN
    is used.
    """
    preferred = (
        'postgresql_dsn',
        'mysql_dsn',
        'mssql_dsn',
        'sqlite_file_dsn',
    )
    requested = request.fixturenames
    for fixture_name in preferred:
        if fixture_name in requested:
            return request.getfixturevalue(fixture_name)
    return request.getfixturevalue('sqlite_memory_dsn')


@pytest.fixture
def engine(request, dsn):
    """Create an engine for the selected ``dsn``.

    The engine is disposed at test teardown, so tests that use ``engine``
    without the ``session`` fixture do not leak pooled connections.  Set
    ``engine.echo = True`` here to log emitted SQL while debugging.
    """
    engine = create_engine(dsn)
    # Engine.dispose() may safely run again after the ``session`` teardown
    # has already disposed the engine; it just recreates an empty pool.
    request.addfinalizer(engine.dispose)
    return engine


@pytest.fixture
def connection(request, engine):
    """Open a connection on ``engine`` that is closed at test teardown.

    Without the finalizer, tests that request ``connection`` but not
    ``session`` would leave the connection open.
    """
    conn = engine.connect()
    # Connection.close() is idempotent, so this is safe alongside the
    # ``session`` teardown that also closes the connection.
    request.addfinalizer(conn.close)
    return conn


@pytest.fixture
def Base():
    """Provide a fresh declarative base (and thus fresh metadata) per test."""
    declarative = declarative_base()
    return declarative


@pytest.fixture
def User(Base):
    """Declare a minimal ``User`` model on ``Base`` and return the class.

    Maps the ``user`` table with an autoincrementing integer primary key
    ``id`` and a ``name`` unicode column (max length 255).
    """
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
        name = sa.Column(sa.Unicode(255))
    return User


@pytest.fixture
def Category(Base):
    """Declare the ``Category`` model on ``Base`` and return the class.

    Besides plain columns (``id``, ``name``, ``title``) it defines hybrid
    properties (``full_name``, ``articles_count``), a plain Python property
    (``name_alias``) and a synonym (``name_synonym``), so tests can exercise
    several kinds of mapped/unmapped attributes.
    """

    class Category(Base):
        __tablename__ = 'category'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        title = sa.Column(sa.Unicode(255))

        # Instance level: "<title> <name>" built in Python.
        @hybrid_property
        def full_name(self):
            return u'%s %s' % (self.title, self.name)

        # Class/SQL level: the same value via SQL CONCAT.
        @full_name.expression
        def full_name(self):
            return sa.func.concat(self.title, ' ', self.name)

        # Instance level: size of the loaded ``articles`` collection (the
        # backref created by the Article fixture's relationship).
        @hybrid_property
        def articles_count(self):
            return len(self.articles)

        # Class/SQL level: scalar subquery counting this category's articles,
        # labelled ``article_count``.
        @articles_count.expression
        def articles_count(cls):
            # Article is declared by a later fixture, so it is resolved
            # lazily through the declarative class registry.
            # NOTE(review): ``_decl_class_registry`` is a private attribute
            # of pre-1.4 SQLAlchemy declarative bases; confirm against the
            # pinned SQLAlchemy version.
            Article = Base._decl_class_registry['Article']
            return (
                sa.select([sa.func.count(Article.id)])
                .where(Article.category_id == cls.id)
                # NOTE(review): this correlates against the article table
                # rather than ``cls``; an explicit correlate() may disable
                # auto-correlation of the category table. Confirm intended.
                .correlate(Article.__table__)
                .label('article_count')
            )

        # Plain Python property; not mapped by SQLAlchemy.
        @property
        def name_alias(self):
            return self.name

        # Synonym for the ``name`` column, exposed through a property.
        @synonym_for('name')
        @property
        def name_synonym(self):
            return self.name
    return Category


@pytest.fixture
def Article(Base, Category):
    """Declare the ``Article`` model on ``Base`` and return the class.

    Each article references a ``Category`` through ``category_id``.  The
    ``category`` relationship also installs an ``articles`` backref on
    ``Category`` whose collection class is ``InstrumentedList``.
    """
    class Article(Base):
        __tablename__ = 'article'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255), index=True)
        category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))

        category = sa.orm.relationship(
            Category,
            primaryjoin=category_id == Category.id,
            backref=sa.orm.backref(
                'articles',
                collection_class=InstrumentedList
            )
        )
    return Article


@pytest.fixture
def init_models(User, Category, Article):
    """Pull in the default model fixtures so they are declared on ``Base``.

    Returns nothing; requesting the model fixtures is enough to define the
    classes before ``session`` creates the schema.  Test modules can override
    this fixture to declare a different set of models.
    """


@pytest.fixture
def session(request, engine, connection, Base, init_models):
    """Create the model schema and return an ORM session bound to ``connection``.

    Mappers are configured and all tables on ``Base.metadata`` are created
    first.  The module-level ``get_locale`` stub is installed as
    ``i18n.get_locale``.  Teardown resets the aggregates manager, closes all
    sessions, drops the schema, removes composite listeners, and finally
    closes the connection and disposes the engine.
    """
    sa.orm.configure_mappers()
    Base.metadata.create_all(connection)
    Session = sessionmaker(bind=connection)
    session = Session()
    # NOTE(review): this override is not restored in teardown, so it stays in
    # effect for later tests in the same process.
    i18n.get_locale = get_locale

    def teardown():
        try:
            aggregates.manager.reset()
            close_all_sessions()
            Base.metadata.drop_all(connection)
        finally:
            # Always release listeners and database resources, even if an
            # earlier step (e.g. drop_all) raised, so a failing teardown does
            # not leak state into subsequent tests.
            remove_composite_listeners()
            connection.close()
            engine.dispose()

    request.addfinalizer(teardown)

    return session
