
"""Catch moment when tables are in sync on master and slave.
"""

import sys, time, skytools
from londiste.handler import build_handler, load_handler_modules

from londiste.util import find_copy_source

class ATable:
    """One table entry built from a row of the table-list query.

    Attributes set by the constructor:
      table_name  - value of the row's 'table_name' column
      dest_table  - row's 'dest_table', or table_name when that is empty
      merge_state - value of the row's 'merge_state' column
      table_attrs - row's 'table_attrs' decoded with skytools.db_urldecode()
      plugin      - handler object returned by build_handler()
    """

    def __init__(self, row):
        name = row['table_name']
        self.table_name = name
        self.dest_table = row['dest_table'] or name
        self.merge_state = row['merge_state']

        # a missing attribute string is decoded as an empty one
        decoded = skytools.db_urldecode(row['table_attrs'] or '')
        self.table_attrs = decoded

        # the raw 'dest_table' column (possibly empty) is handed to the
        # handler, not the defaulted self.dest_table
        handler_spec = decoded.get('handler', '')
        self.plugin = build_handler(name, handler_spec, row['dest_table'])

class Syncer(skytools.DBScript):
    """Walks tables in primary key order and checks if data matches."""

    bad_tables = 0

    provider_info = None

    def __init__(self, args):
        """Initialise the script and read queue/consumer settings.

        The compat option names (pgq_queue_name, pgq_consumer_id) are read
        first; the current names (queue_name, consumer_name) are consulted
        only when the compat value is empty.  consumer_name falls back to
        the job name.
        """
        skytools.DBScript.__init__(self, 'londiste3', args)
        self.set_single_loop(1)

        cf = self.cf

        # compat names
        compat_queue = cf.get("pgq_queue_name", '')
        compat_consumer = cf.get('pgq_consumer_id', '')

        # good names, used only when the compat ones gave nothing
        self.queue_name = compat_queue or cf.get("queue_name")
        self.consumer_name = compat_consumer or cf.get('consumer_name', self.job_name)

        # seconds; also converted to milliseconds in set_lock_timeout()
        self.lock_timeout = cf.getfloat('lock_timeout', 10)

        # suffix the pidfile name (presumably to avoid clashing with the
        # pidfile of the main job -- confirm)
        if self.pidfile:
            self.pidfile = self.pidfile + ".repair"

        load_handler_modules(cf)

    def set_lock_timeout(self, curs):
        """Apply self.lock_timeout to the current transaction on curs.

        The timeout (seconds) is converted to whole milliseconds and set
        with SET LOCAL statement_timeout.  Nothing is executed when the
        converted value is zero or negative.
        """
        timeout_ms = int(self.lock_timeout * 1000)
        if timeout_ms <= 0:
            return

        sql = "SET LOCAL statement_timeout = %d" % timeout_ms
        self.log.debug(sql)
        curs.execute(sql)

    def init_optparse(self, p=None):
        """Add the --force switch to the base class option parser.

        Returns the parser obtained from skytools.DBScript.init_optparse().
        """
        parser = skytools.DBScript.init_optparse(self, p)
        parser.add_option("--force", action="store_true", help="ignore lag")
        return parser

    def get_provider_info(self, setup_curs):
        """Query node info for self.queue_name through setup_curs.

        Logs the node name and type and returns the first result row
        (ret_code, ret_note, node_name, node_type, worker_name).
        """
        sql = ("select ret_code, ret_note, node_name, node_type, worker_name"
               " from pgq_node.get_node_info(%s)")
        rows = self.exec_cmd(setup_curs, sql, [self.queue_name])
        node = rows[0]
        self.log.info('Provider: %s (%s)', node['node_name'], node['node_type'])
        return node

    def check_consumer(self, setup_db, dst_db):
        """Before locking anything, wait until the consumer lag looks sane.

        Each round reads the consumer's completed_tick through dst_db and,
        through a cursor on setup_db, the queue's ticker_lag and the age
        (now() - tick_time) of that completed tick.  The method returns as
        soon as the tick age is less than ticker_lag + 5 seconds; otherwise
        it sleeps one second and tries again.  When the tick is not found
        at all it waits 10 seconds before retrying.

        Reads self.provider_info for the warning text, so that attribute
        must be set before this method is called.
        """
        setup_curs = setup_db.cursor()

        # queries do not change between rounds, build them once
        state_q = "select * from pgq_node.get_consumer_state(%s, %s)"
        ticker_q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s)"
        tick_age_q = "select extract(epoch from (now() - t.tick_time)) as lag"\
            " from pgq.tick t, pgq.queue q"\
            " where q.queue_name = %s"\
            "   and t.tick_queue = q.queue_id"\
            "   and t.tick_id = %s"
        lag_msg = 'Consumer lag: %s, ticker_lag %s, too big difference, waiting'

        rounds = 0
        while 1:
            res = self.exec_cmd(dst_db, state_q, [self.queue_name, self.consumer_name])
            completed_tick = res[0]['completed_tick']

            setup_curs.execute(ticker_q, [self.queue_name])
            ticker_lag = setup_curs.fetchone()[0]

            setup_curs.execute(tick_age_q, [self.queue_name, completed_tick])
            res = setup_curs.fetchall()

            if not res:
                # %s rather than %d: the message must format even if the
                # tick value is not an integer (e.g. NULL)
                self.log.warning('Consumer completed_tick (%s) does not exist on provider (%s), too big lag?',
                                 completed_tick, self.provider_info['node_name'])
                self.sleep(10)
                continue

            consumer_lag = res[0][0]
            if consumer_lag < ticker_lag + 5:
                break

            # warn on every 30th round, keep the rest at debug level
            if rounds % 30 == 0:
                self.log.warning(lag_msg, consumer_lag, ticker_lag)
            else:
                self.log.debug(lag_msg, consumer_lag, ticker_lag)
            rounds += 1
            time.sleep(1)

    def get_tables(self, db):
        """Load info about the locally registered tables of the queue.

        Runs londiste.get_table_list() for self.queue_name on db, keeps the
        rows flagged local and commits db once they are fetched.

        Returns tuple of (dict(name->ATable), namelist); namelist keeps the
        order in which the rows were returned.
        """
        sql = ("select table_name, merge_state, dest_table, table_attrs"
               " from londiste.get_table_list(%s) where local")
        cursor = db.cursor()
        cursor.execute(sql, [self.queue_name])
        fetched = cursor.fetchall()
        db.commit()

        tables = [ATable(r) for r in fetched]
        names = [t.table_name for t in tables]
        by_name = dict(zip(names, tables))
        return by_name, names

    def work(self):
        """Syncer main function.

        Processes the tables named on the command line (every argument
        after the first two) or, when none are given, all locally
        registered tables.  Tables that are not subscribed locally or whose
        merge_state is not 'ok' are skipped with a warning.

        Always terminates through sys.exit(): the status is the number of
        bad tables counted in self.bad_tables, capped at 255.
        """

        # 'SELECT 1' and COPY must use same snapshot, so change isolation level.
        dst_db = self.get_database('db', isolation_level = skytools.I_REPEATABLE_READ)
        pnode, ploc = self.get_provider_location(dst_db)

        dst_tables, names = self.get_tables(dst_db)

        if len(self.args) > 2:
            tlist = self.args[2:]
        else:
            tlist = names

        for tbl in tlist:
            tbl = skytools.fq_name(tbl)
            if tbl not in dst_tables:
                self.log.warning('Table not subscribed: %s', tbl)
                continue
            t2 = dst_tables[tbl]
            if t2.merge_state != 'ok':
                self.log.warning('Table %s not synced yet, no point', tbl)
                continue

            # the provider found for one table is the starting point for
            # the lookup of the next one
            pnode, ploc, wname = find_copy_source(self, self.queue_name, tbl, pnode, ploc)
            self.log.info('%s: Using node %s as provider', tbl, pnode)

            if wname is None:
                wname = self.consumer_name
            self.downstream_worker_name = wname

            self.process_one_table(tbl, t2, dst_db, pnode, ploc)

        # signal caller about bad tables.  Exit statuses are truncated to
        # 8 bits on POSIX, so an uncapped count of 256 would be reported
        # as 0 and look like success.
        sys.exit(min(self.bad_tables, 255))

    def process_one_table(self, tbl, t2, dst_db, provider_node, provider_loc):
        """Open the provider connections for one table and check it.

        tbl          - fully qualified table name
        t2           - ATable describing the table on the subscriber
        dst_db       - subscriber connection
        provider_node - provider node name (not used by this method)
        provider_loc - connect string of the provider

        Three connections to provider_loc are opened: 'lock_db',
        'setup_db' (autocommit) and 'provider_db' (repeatable read).  The
        table is skipped with a warning when it is missing on the provider
        or its merge_state there is not 'ok'.  Also stores the provider's
        node info in self.provider_info.

        The three provider connections are closed on every exit path,
        including the early returns and exceptions, so nothing stays cached
        under those names for the next table.
        """

        lock_db = self.get_database('lock_db', connstr = provider_loc, profile = 'remote')
        setup_db = self.get_database('setup_db', autocommit = 1, connstr = provider_loc, profile = 'remote')

        src_db = self.get_database('provider_db', connstr = provider_loc, profile = 'remote',
                                   isolation_level = skytools.I_REPEATABLE_READ)

        try:
            setup_curs = setup_db.cursor()

            # provider node info
            self.provider_info = self.get_provider_info(setup_curs)

            src_tables, _ = self.get_tables(src_db)
            if tbl not in src_tables:
                self.log.warning('Table not available on provider: %s', tbl)
                return
            t1 = src_tables[tbl]

            if t1.merge_state != 'ok':
                self.log.warning('Table %s not ready yet on provider', tbl)
                return

            # consumer lag pre-check is currently disabled:
            #self.check_consumer(setup_db, dst_db)

            self.check_table(t1, t2, lock_db, src_db, dst_db, setup_db)
            lock_db.commit()
            src_db.commit()
            dst_db.commit()
        finally:
            self.close_database('setup_db')
            self.close_database('lock_db')
            self.close_database('provider_db')

    def force_tick(self, setup_curs, wait=True):
        """Request a tick on the queue and optionally wait for a new one.

        Calls pgq.force_tick() for self.queue_name through setup_curs.
        With wait false the value returned by that first call is returned
        immediately.  Otherwise the call is repeated every 0.5 seconds
        until it returns a different value, and that new value is returned.

        There is no hard timeout: if the value never changes this loops
        forever.  To make such a stall visible a warning is logged every
        10 seconds while waiting.
        """
        q = "select pgq.force_tick(%s)"
        setup_curs.execute(q, [self.queue_name])
        cur_pos = setup_curs.fetchone()[0]
        if not wait:
            return cur_pos

        warn_interval = 10      # seconds between "still waiting" warnings
        last_report = time.time()
        while 1:
            time.sleep(0.5)
            setup_curs.execute(q, [self.queue_name])
            new_pos = setup_curs.fetchone()[0]
            if new_pos != cur_pos:
                # new pos
                return new_pos

            now = time.time()
            if now - last_report >= warn_interval:
                self.log.warning('No new tick on queue %s after position %s, still waiting',
                                 self.queue_name, cur_pos)
                last_report = now

    def check_table(self, t1, t2, lock_db, src_db, dst_db, setup_db):
        """Get transaction to same state, then process.

        t1 / t2   - ATable for the provider / subscriber side
        lock_db   - provider connection handed to the lock/unlock helpers
        src_db    - provider connection used for the comparison
        dst_db    - subscriber connection used for the comparison
        setup_db  - provider connection handed to the lock/unlock helpers

        Returns without doing anything (after a warning) when the table is
        missing on either side.  Otherwise changes are held back (table
        lock for a 'root' provider, worker pause for any other node type),
        a fresh transaction is started on both sides while they are held
        back, the hold is released and process_sync() is called.  A true
        result from process_sync() increments self.bad_tables.
        """

        src_tbl = t1.dest_table
        dst_tbl = t2.dest_table

        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        if not skytools.exists_table(src_curs, src_tbl):
            self.log.warning("Table %s does not exist on provider side", src_tbl)
            return
        if not skytools.exists_table(dst_curs, dst_tbl):
            self.log.warning("Table %s does not exist on subscriber side", dst_tbl)
            return

        # decide once, so that lock and unlock always use the same variant
        is_root = self.provider_info['node_type'] == 'root'

        # lock table against changes
        try:
            if is_root:
                self.lock_table_root(lock_db, setup_db, dst_db, src_tbl, dst_tbl)
            else:
                self.lock_table_branch(lock_db, setup_db, dst_db, src_tbl, dst_tbl)

            # take snapshot on provider side: end the old transaction and
            # run a statement so a new one begins now
            src_db.commit()
            src_curs.execute("SELECT 1")

            # take snapshot on subscriber side
            dst_db.commit()
            dst_curs.execute("SELECT 1")
        finally:
            # release lock, also when locking or snapshotting failed
            if is_root:
                self.unlock_table_root(lock_db, setup_db)
            else:
                self.unlock_table_branch(lock_db, setup_db)

        # do work
        if self.process_sync(t1, t2, src_db, dst_db):
            self.bad_tables += 1

        # done
        src_db.commit()
        dst_db.commit()

    def lock_table_root(self, lock_db, setup_db, dst_db, src_tbl, dst_tbl):
        """Lock src_tbl on the provider and wait for the consumer to catch up.

        The table is locked IN SHARE MODE through lock_db after applying
        the configured timeout.  Two ticks are then forced through
        setup_db, and node info is polled on dst_db every 0.5 seconds
        until worker_last_tick has passed the first forced tick.

        If that takes longer than self.lock_timeout seconds (measured from
        just before the LOCK statement) and --force was not given, lock_db
        is rolled back and the process exits with status 1.
        """
        tick_curs = setup_db.cursor()
        locker = lock_db.cursor()

        # the lock lives in its own connection / transaction
        self.log.info('Locking %s', src_tbl)
        lock_db.commit()
        self.set_lock_timeout(locker)
        started = time.time()
        locker.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(src_tbl))

        # from here on wait for the consumer to apply everything that was
        # written before the lock
        self.log.info('Syncing %s', dst_tbl)

        # consumer must get further than this tick
        tick_id = self.force_tick(tick_curs)
        # try to force second tick also
        self.force_tick(tick_curs)

        node_q = "select * from pgq_node.get_node_info(%s)"
        while 1:
            time.sleep(0.5)

            rows = self.exec_cmd(dst_db, node_q, [self.queue_name])
            if rows[0]['worker_last_tick'] > tick_id:
                return

            # limit lock time
            timed_out = time.time() > started + self.lock_timeout
            if timed_out and not self.options.force:
                self.log.error('Consumer lagging too much, exiting')
                lock_db.rollback()
                sys.exit(1)

    def unlock_table_root(self, lock_db, setup_db):
        """Release the table lock by committing lock_db.

        setup_db is accepted for signature parity with
        unlock_table_branch() and is not used.
        """
        lock_db.commit()

    def lock_table_branch(self, lock_db, setup_db, dst_db, src_tbl, dst_tbl):
        """Pause the provider's worker and wait for the consumer to catch up.

        No table lock is taken here.  The worker named in
        self.provider_info is paused through setup_db and its previous
        paused flag is remembered in self.old_worker_paused for
        unlock_table_branch().  One tick is then requested (without
        waiting for it) and node info is polled on dst_db every 0.5
        seconds until worker_last_tick has passed the returned tick value.

        If that takes longer than self.lock_timeout seconds (measured from
        before the pause) and --force was not given, lock_db is rolled
        back and the process exits with status 1.  src_tbl is not used.
        """
        setup_curs = setup_db.cursor()

        lock_time = time.time()
        self.old_worker_paused = self.pause_consumer(setup_curs, self.provider_info['worker_name'])

        self.log.info('Syncing %s', dst_tbl)

        # consumer must get further than this tick
        tick_id = self.force_tick(setup_curs, False)

        # now wait
        q = "select * from pgq_node.get_node_info(%s)"
        while 1:
            time.sleep(0.5)

            res = self.exec_cmd(dst_db, q, [self.queue_name])
            last_tick = res[0]['worker_last_tick']
            if last_tick > tick_id:
                break

            # limit lock time
            if time.time() > lock_time + self.lock_timeout and not self.options.force:
                self.log.error('Consumer lagging too much, exiting')
                lock_db.rollback()
                sys.exit(1)

    def unlock_table_branch(self, lock_db, setup_db):
        """Resume the provider's worker unless it was paused beforehand.

        lock_db is not used.

        NOTE(review): self.old_worker_paused is assigned only in
        lock_table_branch(); if pausing failed before that assignment this
        method raises AttributeError -- confirm whether a default is wanted.
        """
        if not self.old_worker_paused:
            # the worker was running before it got paused, start it again
            curs = setup_db.cursor()
            self.resume_consumer(curs, self.provider_info['worker_name'])

    def process_sync(self, t1, t2, src_db, dst_db):
        """Hook that does the actual work; must be overridden.

        It gets 2 connections in state where tbl should be in same state.
        check_table() counts the table as bad when the return value is
        true.

        Raises NotImplementedError (an Exception subclass) in this base
        implementation.
        """
        raise NotImplementedError('process_sync not implemented')

    def get_provider_location(self, dst_db):
        """Return (provider_node, provider_location) for the queue.

        Both values come from the first row of pgq_node.get_node_info()
        executed on dst_db.
        """
        q = "select * from pgq_node.get_node_info(%s)"
        rows = self.exec_cmd(dst_db, q, [self.queue_name])
        info = rows[0]
        return (info['provider_node'], info['provider_location'])

    def pause_consumer(self, curs, cons_name):
        """Pause consumer cons_name; returns its previous paused flag."""
        self.log.info("Pausing upstream worker: %s", cons_name)
        previous = self.set_pause_flag(curs, cons_name, True)
        return previous

    def resume_consumer(self, curs, cons_name):
        """Unpause consumer cons_name; returns its previous paused flag."""
        self.log.info("Resuming upstream worker: %s", cons_name)
        previous = self.set_pause_flag(curs, cons_name, False)
        return previous

    def set_pause_flag(self, curs, cons_name, flag):
        """Set the paused flag of a consumer and wait until it is applied.

        Reads the consumer state, calls pgq_node.set_consumer_paused() with
        flag, then polls the consumer state (sleeping 0.5 seconds between
        polls) until its 'uptodate' column is true.  The wait has no
        timeout.

        Returns the 'paused' value read before the change.
        """
        state_q = "select * from pgq_node.get_consumer_state(%s, %s)"
        set_q = "select * from pgq_node.set_consumer_paused(%s, %s, %s)"

        before = self.exec_cmd(curs, state_q, [self.queue_name, cons_name])
        previous = before[0]['paused']

        self.exec_cmd(curs, set_q, [self.queue_name, cons_name, flag])

        # poll first, sleep only when the consumer has not reacted yet
        while not self.exec_cmd(curs, state_q, [self.queue_name, cons_name])[0]['uptodate']:
            time.sleep(0.5)
        return previous
