"""Catch moment when tables are in sync on master and slave.
"""


import os
import subprocess
import sys
import time

import skytools


class TableRepair:
    """Finds row-level differences of one table between two databases.

    Both sides are dumped with COPY, the dumps are sorted with the external
    ``sort`` utility and then merged line by line.  Every difference becomes
    a SQL statement that is appended to a fix file and, when requested,
    executed on the destination side.
    """

    def __init__(self, table_name, log):
        """Remember table and logger, then zero all per-run state.

        table_name -- table name; quoted with skytools.quote_fqident() for SQL
        log -- logger object used for progress and debug messages
        """
        self.table_name = table_name
        self.fq_table_name = skytools.quote_fqident(table_name)
        self.log = log

        self.pkey_list = []
        self.common_fields = []
        self.apply_fixes = False
        self.apply_cursor = None

        self.reset()

    def reset(self):
        """Zero the counters and forget keys, fields and apply mode."""
        self.cnt_insert = 0
        self.cnt_update = 0
        self.cnt_delete = 0
        self.total_src = 0
        self.total_dst = 0
        self.pkey_list = []
        self.common_fields = []
        self.apply_fixes = False
        self.apply_cursor = None

    def do_repair(self, src_db, dst_db, where, pfx='repair', apply_fixes=False):
        """Dump, sort and compare the table on both sides, writing repair SQL.

        src_db, dst_db -- database connections (cursor() and commit() are used)
        where -- optional SQL condition limiting the compared rows
        pfx -- prefix for the names of the dump files and the fix file
        apply_fixes -- when true, each statement is also executed on dst_db
        """

        self.reset()

        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        self.apply_fixes = apply_fixes
        if apply_fixes:
            self.apply_cursor = dst_curs

        self.log.info('Checking %s', self.table_name)

        copy_tbl = self.gen_copy_tbl(src_curs, dst_curs, where)

        dump_src = "%s.%s.src" % (pfx, self.table_name)
        dump_dst = "%s.%s.dst" % (pfx, self.table_name)
        fix = "%s.%s.fix" % (pfx, self.table_name)

        self.log.info("Dumping src table: %s", self.table_name)
        self.dump_table(copy_tbl, src_curs, dump_src)
        src_db.commit()
        self.log.info("Dumping dst table: %s", self.table_name)
        self.dump_table(copy_tbl, dst_curs, dump_dst)
        dst_db.commit()

        self.log.info("Sorting src table: %s", self.table_name)
        self.do_sort(dump_src, dump_src + '.sorted')

        self.log.info("Sorting dst table: %s", self.table_name)
        self.do_sort(dump_dst, dump_dst + '.sorted')

        self.dump_compare(dump_src + ".sorted", dump_dst + ".sorted", fix)

        # only the fix file is kept; dumps are removed after a successful compare
        os.unlink(dump_src)
        os.unlink(dump_dst)
        os.unlink(dump_src + ".sorted")
        os.unlink(dump_dst + ".sorted")

        if apply_fixes:
            dst_db.commit()

    def do_sort(self, src, dst):
        """Sort file `src` into file `dst` with the external ``sort`` utility.

        Sorting runs with LANG/LC_ALL set to "C".  Raises Exception when
        ``sort`` exits with a non-zero status.
        """
        p = subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        s_ver = p.communicate()[0]
        del p

        xenv = os.environ.copy()
        xenv['LANG'] = 'C'
        xenv['LC_ALL'] = 'C'

        cmdline = ['sort', '-T', '.']
        # communicate() returns bytes, so the probe must be bytes as well;
        # a str argument raises TypeError on Python 3.
        if b"coreutils" in s_ver:
            # the buffer-size option is passed only to the coreutils sort
            cmdline += ['-S', '30%']
        cmdline += ['-o', dst, src]
        if subprocess.call(cmdline, env=xenv) != 0:
            raise Exception('sort failed')

    def gen_copy_tbl(self, src_curs, dst_curs, where):
        """Create COPY expression from common fields.

        Stores the primary key columns in self.pkey_list and the compared
        columns (keys first, then the non-key columns present on both sides)
        in self.common_fields.  Exits the process when the primary keys of
        the two sides differ.  Returns the COPY statement as a string.
        """
        self.pkey_list = skytools.get_table_pkeys(src_curs, self.table_name)
        dst_pkey = skytools.get_table_pkeys(dst_curs, self.table_name)
        if dst_pkey != self.pkey_list:
            self.log.error('pkeys do not match')
            sys.exit(1)

        src_cols = skytools.get_table_columns(src_curs, self.table_name)
        dst_cols = skytools.get_table_columns(dst_curs, self.table_name)

        # key columns go first so that sorted dump lines are ordered by key
        field_list = list(self.pkey_list)
        field_list += [f for f in src_cols
                       if f not in self.pkey_list and f in dst_cols]

        self.common_fields = field_list

        fqlist = [skytools.quote_ident(col) for col in field_list]

        tbl_expr = "select %s from %s" % (",".join(fqlist), self.fq_table_name)
        if where:
            tbl_expr += ' where ' + where
        tbl_expr = "COPY (%s) TO STDOUT" % tbl_expr

        self.log.debug("using copy expr: %s", tbl_expr)

        return tbl_expr

    def dump_table(self, copy_cmd, curs, fn):
        """Dump table to disk: run `copy_cmd` on `curs`, writing into file `fn`."""
        with open(fn, "w", 64 * 1024) as f:
            curs.copy_expert(copy_cmd, f)
            self.log.info('%s: Got %d bytes', self.table_name, f.tell())

    def get_row(self, ln):
        """Parse a dump line into a dict keyed by self.common_fields.

        Returns None for an empty line, which readline() gives at end of file.
        """
        if not ln:
            return None
        # drop the trailing newline, then split the tab-separated values
        t = ln[:-1].split('\t')
        return {name: t[i] for i, name in enumerate(self.common_fields)}

    def dump_compare(self, src_fn, dst_fn, fix):
        """Merge-compare two sorted dumps and emit repair SQL into file `fix`.

        An already existing fix file is removed first.  Row totals and
        missed insert/update/delete counters are updated on self.
        """
        self.log.info("Comparing dumps: %s", self.table_name)
        # context managers close both files even when a comparison step raises
        with open(src_fn, "r", 64 * 1024) as f1, open(dst_fn, "r", 64 * 1024) as f2:
            src_ln = f1.readline()
            dst_ln = f2.readline()
            if src_ln:
                self.total_src += 1
            if dst_ln:
                self.total_dst += 1

            if os.path.isfile(fix):
                os.unlink(fix)

            while src_ln or dst_ln:
                keep_src = keep_dst = 0
                if src_ln != dst_ln:
                    src_row = self.get_row(src_ln)
                    dst_row = self.get_row(dst_ln)

                    diff = self.cmp_keys(src_row, dst_row)
                    if diff > 0:
                        # src > dst: row exists only on dst, src must wait
                        self.got_missed_delete(dst_row, fix)
                        keep_src = 1
                    elif diff < 0:
                        # src < dst: row exists only on src, dst must wait
                        self.got_missed_insert(src_row, fix)
                        keep_dst = 1
                    else:
                        # same key on both sides, check the payload
                        if self.cmp_data(src_row, dst_row) != 0:
                            self.got_missed_update(src_row, dst_row, fix)

                if not keep_src:
                    src_ln = f1.readline()
                    if src_ln:
                        self.total_src += 1
                if not keep_dst:
                    dst_ln = f2.readline()
                    if dst_ln:
                        self.total_dst += 1

            self.log.info("finished %s: src: %d rows, dst: %d rows,"
                          " missed: %d inserts, %d updates, %d deletes",
                          self.table_name, self.total_src, self.total_dst,
                          self.cnt_insert, self.cnt_update, self.cnt_delete)

    def got_missed_insert(self, src_row, fn):
        """Create sql for missed insert (row present on src only)."""
        self.cnt_insert += 1
        fld_list = self.common_fields
        fq_list = [skytools.quote_ident(f) for f in fld_list]
        val_list = [skytools.quote_literal(skytools.unescape_copy(src_row[f]))
                    for f in fld_list]
        q = "insert into %s (%s) values (%s);" % (
            self.fq_table_name, ", ".join(fq_list), ", ".join(val_list))
        self.show_fix(q, 'insert', fn)

    def got_missed_update(self, src_row, dst_row, fn):
        """Create sql for missed update (same key, different data).

        The WHERE clause matches the primary key plus the current dst value
        of every changed column; SET assigns the src values.
        """
        self.cnt_update += 1
        fld_list = self.common_fields
        set_list = []
        whe_list = []
        for f in self.pkey_list:
            self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(src_row[f]))
        for f in fld_list:
            v1 = src_row[f]
            v2 = dst_row[f]
            if self.cmp_value(v1, v2) == 0:
                continue

            self.addeq(set_list, skytools.quote_ident(f), skytools.unescape_copy(v1))
            self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(v2))

        q = "update only %s set %s where %s;" % (
            self.fq_table_name, ", ".join(set_list), " and ".join(whe_list))
        self.show_fix(q, 'update', fn)

    def got_missed_delete(self, dst_row, fn):
        """Create sql for missed delete (row present on dst only)."""
        self.cnt_delete += 1
        whe_list = []
        for f in self.pkey_list:
            self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(dst_row[f]))
        q = "delete from only %s where %s;" % (self.fq_table_name, " and ".join(whe_list))
        self.show_fix(q, 'delete', fn)

    def show_fix(self, q, desc, fn):
        """Log the repair sql, append it to file `fn`, execute it in apply mode."""
        self.log.debug("missed %s: %s", desc, q)
        with open(fn, "a") as f:
            f.write("%s\n" % q)

        if self.apply_fixes:
            self.apply_cursor.execute(q)

    def addeq(self, dst_list, f, v):
        """Append quoted ``f = v`` assignment to dst_list."""
        dst_list.append("%s = %s" % (f, skytools.quote_literal(v)))

    def addcmp(self, dst_list, f, v):
        """Append quoted comparison to dst_list, ``is null`` when v is None."""
        if v is None:
            s = "%s is null" % f
        else:
            s = "%s = %s" % (f, skytools.quote_literal(v))
        dst_list.append(s)

    def cmp_data(self, src_row, dst_row):
        """Compare data field-by-field; returns 0 when equal, -1 otherwise."""
        for k in self.common_fields:
            if self.cmp_value(src_row[k], dst_row[k]) != 0:
                return -1
        return 0

    def cmp_value(self, v1, v2):
        """Compare single field, tolerates tz vs notz dates.

        Returns 0 when the values are equal, or differ only by a 3-character
        suffix starting with '+' on a value at least 19 characters long.
        Returns -1 otherwise.
        """
        if v1 == v2:
            return 0

        # try to work around tz vs. notz
        z1 = len(v1)
        z2 = len(v2)
        if z1 == z2 + 3 and z2 >= 19 and v1[z2] == '+':
            if v1[:-3] == v2:
                return 0
        elif z1 + 3 == z2 and z1 >= 19 and v2[z1] == '+':
            if v1 == v2[:-3]:
                return 0

        return -1

    def cmp_keys(self, src_row, dst_row):
        """Compare primary keys of the rows.

        Returns 1 if src > dst, -1 if src < dst and 0 if src == dst"""

        # None means table is done.  tag it larger than any existing row.
        if src_row is None:
            if dst_row is None:
                return 0
            return 1
        elif dst_row is None:
            return -1

        for k in self.pkey_list:
            v1 = src_row[k]
            v2 = dst_row[k]
            if v1 < v2:
                return -1
            elif v1 > v2:
                return 1
        return 0


class Syncer(skytools.DBScript):
    """Checks that tables in two databases are in sync.

    Provides the machinery for getting a pair of transactions, one on each
    side, in which a table is expected to hold the same data.
    """
    # seconds; limits both the LOCK TABLE statement and the wait for the consumer
    lock_timeout = 10
    # NOTE(review): the two limits below are not referenced inside this class
    ticker_lag_limit = 20
    consumer_lag_limit = 20

    def sync_table(self, cstr1, cstr2, queue_name, consumer_name, table_name):
        """Syncer main function.

        cstr1, cstr2 -- connect strings of the source and destination databases
        queue_name, consumer_name -- queue and consumer whose lag is watched
        table_name -- table that is locked on the source side while waiting

        Returns (src_db, dst_db) that are in transaction
        where table should be in sync.

        Exits the process when the consumer does not catch up within
        lock_timeout seconds; raises Exception when the consumer disappears.
        """

        setup_db = self.get_database('setup_db', connstr=cstr1, autocommit=1)
        lock_db = self.get_database('lock_db', connstr=cstr1)

        src_db = self.get_database('src_db', connstr=cstr1, isolation_level=skytools.I_REPEATABLE_READ)
        dst_db = self.get_database('dst_db', connstr=cstr2, isolation_level=skytools.I_REPEATABLE_READ)

        lock_curs = lock_db.cursor()
        setup_curs = setup_db.cursor()
        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        self.check_consumer(setup_curs, queue_name, consumer_name)

        # lock table in separate connection
        self.log.info('Locking %s', table_name)
        self.set_lock_timeout(lock_curs)
        lock_time = time.time()
        lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(table_name))

        # now wait until consumer has updated target table until locking
        self.log.info('Syncing %s', table_name)

        # consumer must get further than this tick
        self.force_tick(setup_curs, queue_name)

        # try to force second tick also
        self.force_tick(setup_curs, queue_name)

        # take server time
        setup_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')")
        tpos = setup_curs.fetchone()[0]
        # now wait
        while True:
            time.sleep(0.5)

            # first column turns true once (now - lag) has passed tpos
            q = "select now() - lag > timestamp %s, now(), lag from pgq.get_consumer_info(%s, %s)"
            setup_curs.execute(q, [tpos, queue_name, consumer_name])
            res = setup_curs.fetchall()
            if len(res) == 0:
                raise Exception('No such consumer: %s/%s' % (queue_name, consumer_name))
            row = res[0]
            self.log.debug("tpos=%s now=%s lag=%s ok=%s", tpos, row[1], row[2], row[0])
            if row[0]:
                break

            # limit lock time
            if time.time() > lock_time + self.lock_timeout:
                self.log.error('Consumer lagging too much, exiting')
                lock_db.rollback()
                sys.exit(1)

        # take snapshot on provider side
        src_db.commit()
        src_curs.execute("SELECT 1")

        # take snapshot on subscriber side
        dst_db.commit()
        dst_curs.execute("SELECT 1")

        # release lock
        lock_db.commit()

        self.close_database('setup_db')
        self.close_database('lock_db')

        return (src_db, dst_db)

    def set_lock_timeout(self, curs):
        """Set statement_timeout on `curs` from lock_timeout; skipped when it is not positive."""
        ms = int(1000 * self.lock_timeout)
        if ms > 0:
            q = "SET LOCAL statement_timeout = %d" % ms
            self.log.debug(q)
            curs.execute(q)

    def check_consumer(self, curs, queue_name, consumer_name):
        """ Before locking anything check if consumer is working ok.

        Exits the process when the queue or the consumer is not found, or
        when the consumer lag exceeds the ticker lag by more than 10 seconds.
        """
        self.log.info("Queue: %s Consumer: %s", queue_name, consumer_name)

        curs.execute('select current_database()')
        self.log.info('Actual db: %s', curs.fetchone()[0])

        # get ticker lag
        q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s);"
        curs.execute(q, [queue_name])
        queue_row = curs.fetchone()
        if queue_row is None:
            # fetchone() gives None when the query produced no row; report it
            # the same way as a missing consumer instead of failing on None[0]
            self.log.error('check_consumer: No such queue: %s', queue_name)
            sys.exit(1)
        ticker_lag = queue_row[0]
        self.log.info("Ticker lag: %s", ticker_lag)
        # get consumer lag
        q = "select extract(epoch from lag) from pgq.get_consumer_info(%s, %s);"
        curs.execute(q, [queue_name, consumer_name])
        res = curs.fetchall()
        if len(res) == 0:
            self.log.error('check_consumer: No such consumer: %s/%s', queue_name, consumer_name)
            sys.exit(1)
        consumer_lag = res[0][0]

        # check that lag is acceptable
        self.log.info("Consumer lag: %s", consumer_lag)
        if consumer_lag > ticker_lag + 10:
            self.log.error('Consumer lagging too much, cannot proceed')
            sys.exit(1)

    def force_tick(self, curs, queue_name):
        """ Force tick into source queue so that consumer can move on faster

        Calls pgq.force_tick() every 0.5 seconds until its result changes
        and returns the new value.  Raises Exception after more than
        10 seconds without a change, unless self.options.force is set.
        """
        q = "select pgq.force_tick(%s)"
        curs.execute(q, [queue_name])
        res = curs.fetchone()
        cur_pos = res[0]

        start = time.time()
        while True:
            time.sleep(0.5)
            curs.execute(q, [queue_name])
            res = curs.fetchone()
            if res[0] != cur_pos:
                # new pos
                return res[0]

            # dont loop more than 10 secs
            dur = time.time() - start
            if dur > 10 and not self.options.force:
                raise Exception("Ticker seems dead")


class Checker(Syncer):
    """Checks that tables in two databases are in sync.

    Config options::

        ## data_checker ##
        confdb = dbname=confdb host=confdb.service

        extra_connstr = user=marko

        # one of: compare, repair, repair-apply, compare-repair-apply
        check_type = compare

        # random params used in queries
        cluster_name =
        instance_name =
        proxy_host =
        proxy_db =

        # list of tables to be compared
        table_list = foo, bar, baz

        where_expr = (hashtext(key_user_name) & %%(max_slot)s) in (%%(slots)s)

        # gets no args
        source_query =
         select h.hostname, d.db_name
           from dba.cluster c
                join dba.cluster_host ch on (ch.key_cluster = c.id_cluster)
                join conf.host h on (h.id_host = ch.key_host)
                join dba.database d on (d.key_host = ch.key_host)
          where c.db_name = '%(cluster_name)s'
            and c.instance_name = '%(instance_name)s'
            and d.mk_db_type = 'partition'
            and d.mk_db_status = 'active'
          order by d.db_name, h.hostname


        target_query =
            select db_name, hostname, slots, max_slot
              from dba.get_cross_targets(%%(hostname)s, %%(db_name)s, '%(proxy_host)s', '%(proxy_db)s')

        consumer_query =
            select q.queue_name, c.consumer_name
              from conf.host h
              join dba.database d on (d.key_host = h.id_host)
              join dba.pgq_queue q on (q.key_database = d.id_database)
              join dba.pgq_consumer c on (c.key_queue = q.id_queue)
             where h.hostname = %%(hostname)s
               and d.db_name = %%(db_name)s
               and q.queue_name like 'xm%%%%'
    """

    # values accepted for the check_type config option
    _CHECK_TYPES = ('compare', 'repair', 'repair-apply', 'compare-repair-apply')

    def __init__(self, args):
        """Checker init: single-loop script that reads lock_timeout and table_list."""
        super(Checker, self).__init__('data_checker', args)
        self.set_single_loop(1)
        self.log.info('Checker starting %s', str(args))

        self.lock_timeout = self.cf.getfloat('lock_timeout', 10)

        self.table_list = self.cf.getlist('table_list')

    def work(self):
        """Check every configured table for every source/target database pair.

        Sources come from source_query, the queue and consumer of a source
        from consumer_query and its targets from target_query, all executed
        on confdb.  Raises Exception for an unknown check_type or when
        consumer_query returns no row for a source.
        """

        source_query = self.cf.get('source_query')
        target_query = self.cf.get('target_query')
        consumer_query = self.cf.get('consumer_query')
        where_expr = self.cf.get('where_expr')
        extra_connstr = self.cf.get('extra_connstr')

        check = self.cf.get('check_type', 'compare')
        # reject a bad check type before any table gets locked and synced
        if check not in self._CHECK_TYPES:
            raise Exception('unknown check type')

        confdb = self.get_database('confdb', autocommit=1)
        curs = confdb.cursor()

        curs.execute(source_query)
        for src_row in curs.fetchall():
            s_host = src_row['hostname']
            s_db = src_row['db_name']

            curs.execute(consumer_query, src_row)
            consumer_row = curs.fetchone()
            if consumer_row is None:
                # fetchone() gives None when the query matched nothing
                raise Exception('No consumer found for db=%s host=%s' % (s_db, s_host))
            consumer_name = consumer_row['consumer_name']
            queue_name = consumer_row['queue_name']

            curs.execute(target_query, src_row)
            for dst_row in curs.fetchall():
                d_db = dst_row['db_name']
                d_host = dst_row['hostname']

                cstr1 = "dbname=%s host=%s %s" % (s_db, s_host, extra_connstr)
                cstr2 = "dbname=%s host=%s %s" % (d_db, d_host, extra_connstr)
                where = where_expr % dst_row

                self.log.info('Source: db=%s host=%s queue=%s consumer=%s',
                              s_db, s_host, queue_name, consumer_name)
                self.log.info('Target: db=%s host=%s where=%s', d_db, d_host, where)

                for tbl in self.table_list:
                    src_db, dst_db = self.sync_table(cstr1, cstr2, queue_name, consumer_name, tbl)
                    if check == 'compare':
                        self.do_compare(tbl, src_db, dst_db, where)
                    elif check == 'compare-repair-apply':
                        # repair only the tables whose checksums differ
                        if not self.do_compare(tbl, src_db, dst_db, where):
                            self._repair_table(tbl, src_db, dst_db, where, True)
                    else:
                        # 'repair' only writes the fix file, 'repair-apply' also executes it
                        self._repair_table(tbl, src_db, dst_db, where, check == 'repair-apply')
                    self.reset()

    def _repair_table(self, tbl, src_db, dst_db, where, apply_fixes):
        """Run TableRepair for one table with 'fix.<table>' as file prefix."""
        repairer = TableRepair(tbl, self.log)
        repairer.do_repair(src_db, dst_db, where, 'fix.' + tbl, apply_fixes)

    def do_compare(self, tbl, src_db, dst_db, where):
        """Compare row count and checksum of a table on both sides.

        The query can be overridden with the compare_sql config option
        (``_TABLE_`` is replaced with the quoted table name) and the result
        format with compare_fmt.  Both transactions are committed afterwards.

        Returns True when the formatted results are equal, False otherwise.
        """

        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        self.log.info('Counting %s', tbl)

        q = "select count(1) as cnt, sum(hashtext(t.*::text)) as chksum from only _TABLE_ t where %s;" % where
        q = self.cf.get('compare_sql', q)
        q = q.replace('_TABLE_', skytools.quote_fqident(tbl))

        f = "%(cnt)d rows, checksum=%(chksum)s"
        f = self.cf.get('compare_fmt', f)

        self.log.debug("srcdb: %s", q)
        src_curs.execute(q)
        src_row = src_curs.fetchone()
        src_str = f % src_row
        self.log.info("srcdb: %s", src_str)

        self.log.debug("dstdb: %s", q)
        dst_curs.execute(q)
        dst_row = dst_curs.fetchone()
        dst_str = f % dst_row
        self.log.info("dstdb: %s", dst_str)

        src_db.commit()
        dst_db.commit()

        if src_str != dst_str:
            self.log.warning("%s: Results do not match!", tbl)
            return False

        self.log.info("%s: OK!", tbl)
        return True


if __name__ == '__main__':
    # Script entry point: build the checker from command-line args and run it.
    Checker(sys.argv[1:]).start()

