import datetime
import os
from sqlalchemy import *
from sqlalchemy import event
from sqlalchemy import sql, util
from sqlalchemy.orm import *
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.sql import operators
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing import eq_

# TODO: ShardTest can be turned into a base for further subclasses




class ShardTest(object):
    __skip_if__ = (lambda: util.win32,)
    __requires__ = 'sqlite',

    schema = None

    def setUp(self):
        global db1, db2, db3, db4, weather_locations, weather_reports

        db1, db2, db3, db4 = self._init_dbs()

        meta = MetaData()
        ids = Table('ids', meta,
            Column('nextid', Integer, nullable=False))

        def id_generator(ctx):
            # in reality, might want to use a separate transaction for this.

            c = db1.contextual_connect()
            nextid = c.execute(ids.select(for_update=True)).scalar()
            c.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1}))
            return nextid

        weather_locations = Table("weather_locations", meta,
                Column('id', Integer, primary_key=True, default=id_generator),
                Column('continent', String(30), nullable=False),
                Column('city', String(50), nullable=False),
                schema=self.schema
            )

        weather_reports = Table(
            'weather_reports',
            meta,
            Column('id', Integer, primary_key=True),
            Column('location_id', Integer,
                   ForeignKey(weather_locations.c.id)),
            Column('temperature', Float),
            Column('report_time', DateTime,
                   default=datetime.datetime.now),
            schema=self.schema
            )

        for db in (db1, db2, db3, db4):
            meta.create_all(db)

        db1.execute(ids.insert(), nextid=1)

        self.setup_session()
        self.setup_mappers()


    @classmethod
    def setup_session(cls):
        global create_session
        shard_lookup = {
            'North America': 'north_america',
            'Asia': 'asia',
            'Europe': 'europe',
            'South America': 'south_america',
            }

        def shard_chooser(mapper, instance, clause=None):
            if isinstance(instance, WeatherLocation):
                return shard_lookup[instance.continent]
            else:
                return shard_chooser(mapper, instance.location)

        def id_chooser(query, ident):
            return ['north_america', 'asia', 'europe', 'south_america']

        def query_chooser(query):
            ids = []


            class FindContinent(sql.ClauseVisitor):

                def visit_binary(self, binary):
                    if binary.left.shares_lineage(
                            weather_locations.c.continent):
                        if binary.operator == operators.eq:
                            ids.append(shard_lookup[binary.right.value])
                        elif binary.operator == operators.in_op:
                            for bind in binary.right.clauses:
                                ids.append(shard_lookup[bind.value])

            if query._criterion is not None:
                FindContinent().traverse(query._criterion)
            if len(ids) == 0:
                return ['north_america', 'asia', 'europe',
                        'south_america']
            else:
                return ids

        create_session = sessionmaker(class_=ShardedSession,
                autoflush=True, autocommit=False)
        create_session.configure(shards={
            'north_america': db1,
            'asia': db2,
            'europe': db3,
            'south_america': db4,
            }, shard_chooser=shard_chooser, id_chooser=id_chooser,
                query_chooser=query_chooser)


    @classmethod
    def setup_mappers(cls):
        global WeatherLocation, Report

        class WeatherLocation(object):
            def __init__(self, continent, city):
                self.continent = continent
                self.city = city

        class Report(object):
            def __init__(self, temperature):
                self.temperature = temperature

        mapper(WeatherLocation, weather_locations, properties={
            'reports': relationship(Report, backref='location'),
            'city': deferred(weather_locations.c.city),
        })

        mapper(Report, weather_reports)

    def _fixture_data(self):
        tokyo = WeatherLocation('Asia', 'Tokyo')
        newyork = WeatherLocation('North America', 'New York')
        toronto = WeatherLocation('North America', 'Toronto')
        london = WeatherLocation('Europe', 'London')
        dublin = WeatherLocation('Europe', 'Dublin')
        brasilia = WeatherLocation('South America', 'Brasila')
        quito = WeatherLocation('South America', 'Quito')
        tokyo.reports.append(Report(80.0))
        newyork.reports.append(Report(75))
        quito.reports.append(Report(85))
        sess = create_session()
        for c in [
            tokyo,
            newyork,
            toronto,
            london,
            dublin,
            brasilia,
            quito,
            ]:
            sess.add(c)
        sess.commit()
        sess.close()
        return sess

    def test_roundtrip(self):
        sess = self._fixture_data()
        tokyo = sess.query(WeatherLocation).filter_by(city="Tokyo").one()
        tokyo.city  # reload 'city' attribute on tokyo
        sess.expunge_all()
        eq_(db2.execute(weather_locations.select()).fetchall(), [(1,
            'Asia', 'Tokyo')])
        eq_(db1.execute(weather_locations.select()).fetchall(), [(2,
            'North America', 'New York'), (3, 'North America', 'Toronto'
            )])
        eq_(sess.execute(weather_locations.select(), shard_id='asia'
            ).fetchall(), [(1, 'Asia', 'Tokyo')])
        t = sess.query(WeatherLocation).get(tokyo.id)
        eq_(t.city, tokyo.city)
        eq_(t.reports[0].temperature, 80.0)
        north_american_cities = \
            sess.query(WeatherLocation).filter(WeatherLocation.continent
                == 'North America')
        eq_(set([c.city for c in north_american_cities]),
            set(['New York', 'Toronto']))
        asia_and_europe = \
            sess.query(WeatherLocation).filter(
                WeatherLocation.continent.in_(['Europe', 'Asia']))
        eq_(set([c.city for c in asia_and_europe]), set(['Tokyo',
            'London', 'Dublin']))

    def test_shard_id_event(self):
        canary = []
        def load(instance, ctx):
            canary.append(ctx.attributes["shard_id"])

        event.listen(WeatherLocation, "load", load)
        sess = self._fixture_data()

        tokyo = sess.query(WeatherLocation).\
                    filter_by(city="Tokyo").set_shard("asia").one()

        sess.query(WeatherLocation).all()
        eq_(
            canary,
            ['asia', 'north_america', 'north_america',
            'europe', 'europe', 'south_america',
            'south_america']
        )

class DistinctEngineShardTest(ShardTest, fixtures.TestBase):

    def _init_dbs(self):
        db1 = testing_engine('sqlite:///shard1.db',
                            options=dict(pool_threadlocal=True))
        db2 = testing_engine('sqlite:///shard2.db')
        db3 = testing_engine('sqlite:///shard3.db')
        db4 = testing_engine('sqlite:///shard4.db')

        return db1, db2, db3, db4

    def tearDown(self):
        clear_mappers()

        for db in (db1, db2, db3, db4):
            db.connect().invalidate()
        for i in range(1, 5):
            os.remove("shard%d.db" % i)

class AttachedFileShardTest(ShardTest, fixtures.TestBase):
    schema = "changeme"

    def _init_dbs(self):
        db1 = testing_engine('sqlite://', options={"execution_options":
                                            {"shard_id": "shard1"}})
        db2 = db1.execution_options(shard_id="shard2")
        db3 = db1.execution_options(shard_id="shard3")
        db4 = db1.execution_options(shard_id="shard4")

        import re
        @event.listens_for(db1, "before_cursor_execute", retval=True)
        def _switch_shard(conn, cursor, stmt, params, context, executemany):
            shard_id = conn._execution_options['shard_id']
            # because SQLite can't just give us a "use" statement, we have
            # to use the schema hack to locate table names
            if shard_id:
                stmt = re.sub(r"\"?changeme\"?\.", shard_id + "_", stmt)

            return stmt, params

        return db1, db2, db3, db4


