# sharedpython.py, classes to aid activities in sharing a state
# @author: Miguel Angel Alvarez, miguel@laptop.org
# @author: Reinier Heeres
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
#
# Change log:
#

import pickle
import difflib
import logging
from sharedobject import DiffRec
_logger = logging.getLogger('sharedpython')

from sharedobject import SharedObject

class SharedPython(SharedObject):
    """Shared object whose value is an arbitrary picklable Python object.

    Changes are computed on the pickled representation of the value.  A
    "change array" is a list of strings of the form ``'+<pos> <text>'``
    (insert ``text`` at ``pos``) or ``'-<pos> <text>'`` (delete ``text``
    starting at ``pos``), where ``pos`` is an offset into the old pickle
    string.

    NOTE(review): change arrays are turned back into objects with
    ``pickle.loads``.  If they are received from other participants (which
    ``DiffRec.sender`` suggests), unpickling attacker-controlled data can
    execute arbitrary code -- only safe among trusted peers.
    """

    def __init__(self, name, helper, opt=None):
        # Fresh dict per instance instead of a shared mutable default.
        if opt is None:
            opt = {}
        SharedObject.__init__(self, name, helper, opt=opt)
        self._value = None
        self._picklestr = ''

    def _divide_change(self, c):
        """Separate the index and the string parts of the change string.

        ``'+12 abc'`` -> ``('+12', 'abc')``.  Only the first space separates
        the two parts, so the text itself may contain spaces.

        Raises ValueError if ``c`` contains no space.
        """
        index, sep, text = c.partition(' ')
        if not sep:
            raise ValueError("Malformed change string: %r" % (c,))
        return (index, text)

    def inverse_diff(self, diff):
        """Return a DiffRec that undoes ``diff``.

        The base class is asked first; a non-None result from it is returned
        as-is.  Otherwise the change array in ``diff.obj`` is inverted here.
        """
        d = SharedObject.inverse_diff(self, diff)
        if d is not None:
            # Use the base-class inverse when it provides one.
            return d
        obj = self._generate_inverse_diffobj(diff.obj)
        return DiffRec(diff.version_id, diff.sender, True, obj)

    def _generate_inverse_diffobj(self, changes):
        """Return the change array that undoes ``changes``.

        Every '+' becomes a '-' with the same text and vice versa.  Positions
        in ``changes`` refer to the old string, while positions in the result
        refer to the string ``changes`` produces, so each position is shifted
        by ``delta``, the net length added by the earlier changes.  Consecutive
        changes sharing an old position are mapped to the same new position.
        The result is normalized with _format_changes.
        """
        ret = []
        delta = 0
        last_num = -1
        last_n_in = -1

        for c in changes:
            n, s = self._divide_change(c)
            num = abs(int(n))
            if num == last_num:
                # Same old position as the previous change: reuse its
                # translated position rather than applying the new delta.
                n_in = last_n_in
            else:
                n_in = num + delta
            last_num = num
            last_n_in = n_in

            if c[0] == '+':
                ret.append('-%d %s' % (n_in, s))
                delta += len(s)
            elif c[0] == '-':
                ret.append('+%d %s' % (n_in, s))
                delta -= len(s)
            else:
                _logger.warning("Unknown line type: %s", c)
        return self._format_changes(ret)

    def _update_interval(self, i, cs):
        """Get the 'exclusion interval' for the change ``cs[i]``.

        No other, different edit is accepted inside this interval.  An
        insertion at position n yields ``(n+1, n+1)``; a deletion of text
        ``d`` at position n yields ``(n, n + len(d))``.
        """
        n, d = self._divide_change(cs[i])
        num = abs(int(n))
        # Decide on the sign character, not on the parsed int: int('-0') is
        # 0, so a deletion at position 0 would otherwise look like an insert.
        if n[0] == '-':
            return (num, num + len(d))
        return (num + 1, num + 1)

    def _intersect(self, ia, ib):
        """Compare two closed intervals given as 2-element ordered sequences.

        Returns:
            1  if they do not intersect and ``ia`` lies after ``ib``,
            -1 if they do not intersect and ``ia`` lies before ``ib``,
            0  if they intersect.
        """
        assert len(ia) == 2 and ia[1] >= ia[0]
        assert len(ib) == 2 and ib[1] >= ib[0]
        if ia[0] > ib[1]:
            return 1  # no intersection, ia bigger
        if ia[1] < ib[0]:
            return -1  # no intersection, ia smaller
        return 0  # intersection

    def _compatible_diffs(self, diff_a, diff_b):
        """Return True if two change arrays can be merged automatically.

        The arrays are walked in parallel (this assumes each is ordered by
        position).  As soon as the exclusion intervals of one change from
        each array overlap, the arrays are considered conflicting and False
        is returned.  This is the restrictive variant: even two identical
        changes at the same place count as a conflict.
        """
        _logger.debug("_compatible_diffs(): a=%s, b=%s", diff_a, diff_b)
        index_a = index_b = 0
        while index_a < len(diff_a) and index_b < len(diff_b):
            interval_a = self._update_interval(index_a, diff_a)
            interval_b = self._update_interval(index_b, diff_b)
            d = self._intersect(interval_a, interval_b)
            if d == 0:
                return False
            elif d > 0:
                # a's change lies entirely after b's: advance b.
                index_b += 1
            else:
                index_a += 1
        return True

    def _format_changes(self, changes):
        """Normalize a change array so insertions precede adjacent deletions.

        When an insertion ('+') directly follows a deletion ('-') and starts
        either at the deletion's position or right where the deleted text
        ends, the insertion is moved in front of the deletion and re-anchored
        at the deletion's position.  All other changes are copied unchanged.
        """
        last_sign = None
        last_num = -1
        foreseen_index = -1  # previous change's position + its text length
        res = []
        for c in changes:
            sign = c[0]
            n, s = self._divide_change(c)
            number = abs(int(n))
            if (sign == '+' and last_sign == '-'
                    and number in (foreseen_index, last_num)):
                # Put the insertion just before the preceding deletion.
                res.insert(len(res) - 1, sign + str(last_num) + ' ' + s)
            else:
                res.append(c)
            last_sign = sign
            foreseen_index = number + len(s)
            last_num = number
        return res

    def diff(self, new_object, old_object):
        """Generate a change array that turns ``old_object`` into ``new_object``.

        Both objects are pickled and the pickle strings are compared element
        by element with difflib.Differ.  Runs of consecutive insertions (or
        deletions) are merged into a single change; positions are offsets
        into the old pickle string.  The result is normalized with
        _format_changes.
        """
        _logger.debug("Diffing old:%s (type: %s), new: %s (type: %s)", old_object, type(old_object),
        new_object, type(new_object))
        old = pickle.dumps(old_object)
        new = pickle.dumps(new_object)
        ret = []
        pos = 0  # current offset into ``old``
        continuous = False  # True while inside a run of '+'/'-' lines
        for r in difflib.Differ().compare(old, new):
            tag, chunk = r[:2], r[2:]
            if tag == "+ ":
                if continuous and ret and ret[-1][0] == "+":
                    # Extend the running insertion.
                    ret[-1] += chunk
                else:
                    ret.append("+%d %s" % (pos, chunk))
                continuous = True
            elif tag == "- ":
                if continuous and ret and ret[-1][0] == "-":
                    # Extend the running deletion.
                    ret[-1] += chunk
                else:
                    ret.append("-%d %s" % (pos, chunk))
                continuous = True
                # A deleted element exists in ``old``, so it advances pos.
                pos += 1
            elif tag == "? ":
                pass  # intraline hints; not needed
            else:
                continuous = False
                pos += len(r) - 2
        return self._format_changes(ret)

    def _apply_diff_to(self, object, diffobj):
        """Apply the change array ``diffobj`` to ``object``.

        ``object`` is pickled, the changes are applied to the pickle string
        and the result is unpickled into a new object; ``object`` itself is
        not modified.

        Returns a tuple ``(new_object, inverse_changes)``, where
        ``inverse_changes`` is the change array that undoes ``diffobj``.

        Raises Exception if a deletion's text does not match the text found
        at its position.
        """
        old = pickle.dumps(object)
        _logger.debug("_apply_diff_to(): old=%r, diffobj=%s", old, diffobj)
        parts = []
        old_pos = 0
        for c in diffobj:
            index, text = self._divide_change(c)
            pos = abs(int(index))
            # Copy the untouched text between the previous change and this one.
            parts.append(old[old_pos:pos])
            if index[0] == "+":
                parts.append(text)
            elif index[0] == "-":
                found = old[pos:pos + len(text)]
                if found != text:
                    exc = "Bad delete at %d, expected %s, got %r" % (pos,
                        text, found)
                    raise Exception(exc)
                pos += len(text)
            old_pos = pos
        parts.append(old[old_pos:])
        new = "".join(parts)
        # NOTE(review): ``new`` is built from change data that may come from
        # other participants; pickle.loads on untrusted data can run code.
        res = pickle.loads(new)
        _logger.debug("_apply_diff_to(): new=%r, res=%s", new, res)
        idiff = self._generate_inverse_diffobj(diffobj)
        return (res, idiff)

    def get_version(self, versionid):
        """Return the value as it was at version ``versionid``.

        Returns the current value for the current version and None for a
        version newer than the current one.  Otherwise the older value is
        rebuilt by applying the inverse of each received diff, newest first,
        down to (but not including) index ``self._get_diff_index(versionid)``.
        """
        _logger.debug("get_version(): called with versionid = %d", versionid)
        if versionid == self._version_id:
            return self._value
        if versionid > self._version_id:
            return None
        ret = self._value
        stop = self._get_diff_index(versionid)
        for i in range(len(self._received_diffs) - 1, stop, -1):
            d = self._received_diffs[i]
            idobj = self._generate_inverse_diffobj(d.obj)
            ret = self._apply_diff_to(ret, idobj)[0]
        _logger.debug("get_version(): return value: %s", ret)
        return ret
