
"""Repair data on subscriber.

Walks tables by primary key and searches for missing inserts/updates/deletes.
"""

import sys, os, skytools, subprocess

from londiste.syncer import Syncer

__all__ = ['Repairer']

def unescape(s):
    """Decode COPY-style escapes in *s* by delegating to skytools.unescape_copy."""
    decoded = skytools.unescape_copy(s)
    return decoded

class Repairer(Syncer):
    """Walks tables in primary key order and checks if data matches.

    For each table pair the source and destination rows are dumped to
    files, sorted and compared line by line.  Differences are either
    appended to a ``fix.<table>.sql`` file or, when --apply is given,
    executed directly through a separate cursor.
    """

    # Number of repair statements generated for the current table;
    # reset at the start of dump_compare().
    cnt_insert = 0
    cnt_update = 0
    cnt_delete = 0
    # Number of lines read from the sorted source / destination dumps.
    total_src = 0
    total_dst = 0
    # Primary key column names; filled in by load_common_columns().
    pkey_list = []
    # Columns present in both tables, primary key columns first.
    common_fields = []
    # Cursor used to execute fixes directly; stays None unless --apply is given.
    apply_curs = None

    def init_optparse(self, p=None):
        """Add the --apply switch to the option parser built by the parent class.

        Returns the parser object produced by the parent's init_optparse().
        """
        parser = super(Repairer, self).init_optparse(p)
        parser.add_option("--apply", action="store_true", help="apply fixes")
        return parser

    def process_sync(self, t1, t2, src_db, dst_db):
        """Dump, sort and compare one source/destination table pair.

        t1, t2 -- table descriptors; their ``dest_table`` attribute names the
                  source and destination table respectively, and
                  ``t2.plugin.get_copy_condition()`` supplies the WHERE
                  condition used for both dumps.
        src_db, dst_db -- open database connections for the two sides.

        Work files named ``<dst_tbl>.src``, ``<dst_tbl>.dst`` and their
        ``.sorted`` variants are created in the current directory and are
        removed again when the method finishes, whether it succeeds or fails.
        """
        if self.options.apply:
            apply_db = self.get_database('db', cache='applydb', autocommit=1)
            self.apply_curs = apply_db.cursor()
            # NOTE(review): presumably keeps the applied fixes from firing
            # replication triggers on the subscriber -- confirm.
            self.apply_curs.execute("set session_replication_role = 'replica'")

        src_tbl = t1.dest_table
        dst_tbl = t2.dest_table

        src_curs = src_db.cursor()
        dst_curs = dst_db.cursor()

        self.log.info('Checking %s', dst_tbl)

        self.common_fields = []
        self.fq_common_fields = []
        self.pkey_list = []
        self.load_common_columns(src_tbl, dst_tbl, src_curs, dst_curs)

        dump_src = dst_tbl + ".src"
        dump_dst = dst_tbl + ".dst"
        dump_src_sorted = dump_src + ".sorted"
        dump_dst_sorted = dump_dst + ".sorted"

        # the same row filter is applied on both sides
        dst_where = t2.plugin.get_copy_condition(src_curs, dst_curs)
        src_where = dst_where

        try:
            self.log.info("Dumping src table: %s", src_tbl)
            self.dump_table(src_tbl, src_curs, dump_src, src_where)
            src_db.commit()
            self.log.info("Dumping dst table: %s", dst_tbl)
            self.dump_table(dst_tbl, dst_curs, dump_dst, dst_where)
            dst_db.commit()

            self.log.info("Sorting src table: %s", dump_src)
            self.do_sort(dump_src, dump_src_sorted)
            self.log.info("Sorting dst table: %s", dump_dst)
            self.do_sort(dump_dst, dump_dst_sorted)

            self.dump_compare(dst_tbl, dump_src_sorted, dump_dst_sorted)
        finally:
            # Do not leave (potentially huge) work files behind when a step
            # fails; some of them may not have been created yet.
            for fn in (dump_src, dump_dst, dump_src_sorted, dump_dst_sorted):
                if os.path.isfile(fn):
                    os.unlink(fn)

    def do_sort(self, src, dst):
        """ Sort contents of src file, write them to dst file. """

        p = subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        s_ver = p.communicate()[0]
        del p

        xenv = os.environ.copy()
        xenv['LANG'] = 'C'
        xenv['LC_ALL'] = 'C'

        cmdline = ['sort', '-T', '.']
        if s_ver.find("coreutils") > 0:
            cmdline.append('-S')
            cmdline.append('30%')
        cmdline.append('-o')
        cmdline.append(dst)
        cmdline.append(src)
        p = subprocess.Popen(cmdline, env = xenv)
        if p.wait() != 0:
            raise Exception('sort failed')

    def load_common_columns(self, src_tbl, dst_tbl, src_curs, dst_curs):
        """Work out which columns to compare, primary key columns first.

        Stores the source table's primary key columns in ``self.pkey_list``,
        the columns present in both tables in ``self.common_fields`` and
        their quoted names in ``self.fq_common_fields``.  Logs an error and
        exits the process with status 1 if the primary keys differ.
        """
        src_pkeys = skytools.get_table_pkeys(src_curs, src_tbl)
        self.pkey_list = src_pkeys
        if skytools.get_table_pkeys(dst_curs, dst_tbl) != src_pkeys:
            self.log.error('pkeys do not match')
            sys.exit(1)

        src_cols = skytools.get_table_columns(src_curs, src_tbl)
        dst_cols = skytools.get_table_columns(dst_curs, dst_tbl)

        # non-key columns keep the source table's column order
        shared = [c for c in src_cols if c not in src_pkeys and c in dst_cols]
        self.common_fields = list(src_pkeys) + shared
        self.fq_common_fields = [skytools.quote_ident(c) for c in self.common_fields]

        self.log.debug("using columns: %s", ",".join(self.fq_common_fields))

    def dump_table(self, tbl, curs, fn, whr):
        """Dump the common columns of *tbl* into file *fn* via COPY.

        tbl  -- table name, quoted with skytools.quote_fqident().
        curs -- cursor providing copy_expert().
        fn   -- output file path; overwritten.
        whr  -- SQL condition restricting the rows; an empty (or missing)
                condition means all rows.
        """
        cols = ','.join(self.fq_common_fields)
        if not whr:
            whr = 'true'
        q = "copy (SELECT %s FROM %s WHERE %s) to stdout" % (cols, skytools.quote_fqident(tbl), whr)
        self.log.debug("Query: %s", q)
        # 'with' guarantees the file is closed even if the COPY fails
        with open(fn, "w", 64*1024) as f:
            curs.copy_expert(q, f)
            size = f.tell()
        self.log.info('%s: Got %d bytes', tbl, size)

    def get_row(self, ln):
        """Parse a row into dict."""
        if not ln:
            return None
        t = ln[:-1].split('\t')
        row = {}
        for i in range(len(self.common_fields)):
            row[self.common_fields[i]] = t[i]
        return row

    def dump_compare(self, tbl, src_fn, dst_fn):
        """ Compare two table dumps, create sql file to fix target table
            or apply changes to target table directly.
        """
        self.log.info("Comparing dumps: %s", tbl)
        self.cnt_insert = 0
        self.cnt_update = 0
        self.cnt_delete = 0
        self.total_src = 0
        self.total_dst = 0
        f1 = open(src_fn, "r", 64*1024)
        f2 = open(dst_fn, "r", 64*1024)
        src_ln = f1.readline()
        dst_ln = f2.readline()
        if src_ln: self.total_src += 1
        if dst_ln: self.total_dst += 1

        fix = "fix.%s.sql" % tbl
        if os.path.isfile(fix):
            os.unlink(fix)

        while src_ln or dst_ln:
            keep_src = keep_dst = 0
            if src_ln != dst_ln:
                src_row = self.get_row(src_ln)
                dst_row = self.get_row(dst_ln)

                diff = self.cmp_keys(src_row, dst_row)
                if diff > 0:
                    # src > dst
                    self.got_missed_delete(tbl, dst_row)
                    keep_src = 1
                elif diff < 0:
                    # src < dst
                    self.got_missed_insert(tbl, src_row)
                    keep_dst = 1
                else:
                    if self.cmp_data(src_row, dst_row) != 0:
                        self.got_missed_update(tbl, src_row, dst_row)

            if not keep_src:
                src_ln = f1.readline()
                if src_ln: self.total_src += 1
            if not keep_dst:
                dst_ln = f2.readline()
                if dst_ln: self.total_dst += 1

        self.log.info("finished %s: src: %d rows, dst: %d rows,"
                " missed: %d inserts, %d updates, %d deletes",
                tbl, self.total_src, self.total_dst,
                self.cnt_insert, self.cnt_update, self.cnt_delete)

    def got_missed_insert(self, tbl, src_row):
        """Create sql for a row that exists in source but not in destination.

        tbl     -- destination table name.
        src_row -- dict of raw (still COPY-escaped) source values keyed by
                   column name.

        Increments ``cnt_insert`` and hands the statement to show_fix().
        """
        self.cnt_insert += 1
        fq_list = [skytools.quote_ident(f) for f in self.common_fields]
        val_list = [skytools.quote_literal(unescape(src_row[f]))
                    for f in self.common_fields]
        # quote the table name the same way got_missed_delete() does, so
        # names that need quoting produce valid sql
        q = "insert into %s (%s) values (%s);" % (
                skytools.quote_fqident(tbl), ", ".join(fq_list), ", ".join(val_list))
        self.show_fix(tbl, q, 'insert')

    def got_missed_update(self, tbl, src_row, dst_row):
        """Create sql for a row whose data differs between source and destination.

        tbl     -- destination table name.
        src_row -- dict of raw source values keyed by column name.
        dst_row -- dict of raw destination values keyed by column name.

        Only columns that cmp_value() reports as different are put into the
        SET list.  The WHERE clause matches on the primary key plus the
        dumped destination values of the changed columns.  Increments
        ``cnt_update`` and hands the statement to show_fix().
        """
        self.cnt_update += 1
        set_list = []
        whe_list = []
        for f in self.pkey_list:
            self.addcmp(whe_list, skytools.quote_ident(f), unescape(src_row[f]))
        for f in self.common_fields:
            v1 = src_row[f]
            v2 = dst_row[f]
            if self.cmp_value(v1, v2) == 0:
                continue

            fq = skytools.quote_ident(f)
            self.addeq(set_list, fq, unescape(v1))
            self.addcmp(whe_list, fq, unescape(v2))

        # quote the table name the same way got_missed_delete() does, so
        # names that need quoting produce valid sql
        q = "update only %s set %s where %s;" % (
                skytools.quote_fqident(tbl), ", ".join(set_list), " and ".join(whe_list))
        self.show_fix(tbl, q, 'update')

    def got_missed_delete(self, tbl, dst_row):
        """Create sql for a row that exists in destination but not in source.

        The row is identified by its primary key columns only.  Increments
        ``cnt_delete`` and hands the statement to show_fix().
        """
        self.cnt_delete += 1
        conditions = []
        for key in self.pkey_list:
            key_value = unescape(dst_row[key])
            self.addcmp(conditions, skytools.quote_ident(key), key_value)
        where_sql = " and ".join(conditions)
        q = "delete from only %s where %s;" % (skytools.quote_fqident(tbl), where_sql)
        self.show_fix(tbl, q, 'delete')

    def show_fix(self, tbl, q, desc):
        """Print/write/apply repair sql."""
        self.log.debug("missed %s: %s", desc, q)
        if self.apply_curs:
            self.apply_curs.execute(q)
        else:
            fn = "fix.%s.sql" % tbl
            open(fn, "a").write("%s\n" % q)

    def addeq(self, list, f, v):
        """Append a ``<f> = <quoted v>`` assignment to *list*.

        *f* is used as given; *v* is quoted with skytools.quote_literal().
        """
        list.append("%s = %s" % (f, skytools.quote_literal(v)))

    def addcmp(self, list, f, v):
        """Add quoted comparison."""
        if v is None:
            s = "%s is null" % f
        else:
            vq = skytools.quote_literal(v)
            s = "%s = %s" % (f, vq)
        list.append(s)

    def cmp_data(self, src_row, dst_row):
        """Compare data field-by-field."""
        for k in self.common_fields:
            v1 = src_row[k]
            v2 = dst_row[k]
            if self.cmp_value(v1, v2) != 0:
                return -1
        return 0

    def cmp_value(self, v1, v2):
        """Compare single field, tolerates tz vs notz dates."""
        if v1 == v2:
            return 0

        # try to work around tz vs. notz
        z1 = len(v1)
        z2 = len(v2)
        if z1 == z2 + 3 and z2 >= 19 and v1[z2] == '+':
            v1 = v1[:-3]
            if v1 == v2:
                return 0
        elif z1 + 3 == z2 and z1 >= 19 and v2[z1] == '+':
            v2 = v2[:-3]
            if v1 == v2:
                return 0

        return -1

    def cmp_keys(self, src_row, dst_row):
        """Compare primary keys of the rows.

        Returns 1 if src > dst, -1 if src < dst and 0 if src == dst"""

        # None means table is done.  tag it larger than any existing row.
        if src_row is None:
            if dst_row is None:
                return 0
            return 1
        elif dst_row is None:
            return -1

        for k in self.pkey_list:
            v1 = src_row[k]
            v2 = dst_row[k]
            if v1 < v2:
                return -1
            elif v1 > v2:
                return 1
        return 0
