# sharedobject.py, classes to aid activities in sharing a state
# Reinier Heeres, reinier@heeres.eu
# Miguel Alvarez, miguel@laptop.org
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
#
# Change log:
#   2007-07-14: rwh, the big merge. Old function are kept with _nc
#               (non-contiguous) appended
#   2007-07-07: miguel, conflict resolution added
#   2007-06-21: rwh, first version

import copy
import pickle
import base64
import difflib
import time

import logging
_logger = logging.getLogger('sharinghelper')

class DiffRec:
    """One entry in a shared object's change history.

    version_id  -- the version this change leads up to
    sender      -- identifier of the participant that produced the change
    incremental -- whether obj is a diff (True) or a whole value (False)
    obj         -- the payload: a diff object or a complete value
    """

    def __init__(self, versionid, sender, incremental, objstr):
        self.version_id, self.sender = versionid, sender
        self.incremental, self.obj = incremental, objstr

    def is_newer(self, v2):
        """Return True if this record orders after v2.

        A higher version wins; for equal versions the larger sender wins,
        which gives every participant the same total order.
        """
        if self.version_id != v2.version_id:
            return self.version_id > v2.version_id
        return self.sender > v2.sender

    def __str__(self):
        # Rendered as "[<version>]<payload>".
        prefix = "[%d]" % self.version_id
        return prefix + str(self.obj)

    def __repr__(self):
        return str(self)

class SharedObject:
    """Base class for shared objects, able to share python objects the
    dumb way.

    The object keeps a short change history: _received_diffs holds DiffRec
    records ordered by version (a record's version_id is the version it
    leads *up to*), and _inverse_diffs holds, at the same indices, what is
    needed to undo each one (a raw inverse diff on the insert_diff() path,
    a DiffRec wrapping it on the *_nc path).

    The base diff(), apply_diff_to() and _compatible_diffs() are
    placeholders (diff() returns None, apply_diff_to() leaves the object
    unchanged, every pair of diffs is compatible); presumably subclasses
    override them to provide real incremental behaviour.

    Options read from ``opt``: 'incremental' (default encoding mode) and the
    callbacks 'changed', 'locked', 'locklost' and 'unlocked'.
    """

    def __init__(self, name, helper, opt=None):
        """name   -- identifier passed to the helper when sending/locking
        helper -- transport object; this class uses its SendObject,
                  LockObject, UnlockObject, tube_connected and _own_bus_name
        opt    -- dict of options/callbacks (a new empty dict if omitted)
        """
        self._name = name
        # A fresh dict per instance: a shared mutable default would leak
        # options between objects created without one.
        self._options = {} if opt is None else opt
        self._helper = helper
        self._value = None
        self._version_id = 0
        self._received_diffs = []
        self._inverse_diffs = []
        self._pending_diffs = {}    # out-of-order updates, keyed by version
        self._cached_versions = 8   # history cap used by insert_diff_nc()
        self._locked = False
        self._locked_by = None
        self._locked_time = 0

    def encode(self, obj):
        """Serialise obj for transmission (pickle, then base64)."""
        return base64.b64encode(pickle.dumps(obj))

    def decode(self, obj):
        """Inverse of encode().

        NOTE(review): process_update() feeds this data that arrived from
        another participant; pickle.loads on untrusted input can execute
        arbitrary code.
        """
        return pickle.loads(base64.b64decode(obj))

    def _should_encode_incremental(self, incremental):
        """Decide whether to encode an object incrementally.

        If incremental is None the 'incremental' option decides (only a
        literal True enables it); otherwise the given value is returned.
        """
        if incremental is None:
            return self._options.get('incremental') is True
        return incremental

    def changed_nc(self, diffobj, incremental):
        """Old (non-contiguous) change path: bump the version, record
        diffobj via insert_diff_nc() and send it if the tube is connected.

        Errors are logged, not raised; the 'changed' callback always runs.
        """
        self._version_id += 1
        try:
            diff = DiffRec(self._version_id, self._helper._own_bus_name,
                           incremental, diffobj)
            # insert_diff_nc() only records the change; the value has already
            # been updated by the caller and the version bumped above.
            self.insert_diff_nc(diff)
            enc = self.encode(diff.obj)
            if self._helper.tube_connected():
                self._helper.SendObject(self._name, self._version_id, incremental, enc)
        except Exception as inst:
            _logger.error('changed(): %s', inst)

        self.do_changed_callback()

    def changed(self, diffobj, incremental=False):
        """Record and broadcast a local change.

        diffobj is inserted into the history as version self._version_id + 1
        through insert_diff(), which applies it to the current value.  When
        incremental is False the whole resulting value is sent, otherwise
        only diffobj.  Errors are logged, not raised; the 'changed' callback
        always runs.
        """
        try:
            self.insert_diff(DiffRec(self._version_id + 1,
                                     self._helper._own_bus_name, incremental, diffobj))
            if not incremental:
                enc = self.encode(self._value)
            else:
                enc = self.encode(diffobj)
            self._helper.SendObject(self._name, self._version_id, incremental, enc)
            _logger.debug("Modification sent")
        except Exception as inst:
            _logger.error('changed(): %s, currval: %s', inst, self._value)
        self.do_changed_callback()

    def set_value_nc(self, v, incremental=None):
        """Set the value of this object through the old (*_nc) path.

        Specifying incremental (a boolean) forces incremental or full
        encoding; if None, the default comes from self._options.  In
        incremental mode nothing is sent when diff() returns None.
        """
        incremental = self._should_encode_incremental(incremental)
        _logger.debug('Setting value of %s to %s, incremental=%s', self._name, v, incremental)
        if incremental:
            old = copy.deepcopy(self._value)
            self._value = v
            d = self.diff(v, old)
            _logger.debug('set_value(): generated diff:\n%r', d)
            if d is not None:
                # _value already holds v, so record the diff through the
                # non-applying *_nc path; changed() would apply it again.
                self.changed_nc(d, True)
        else:
            self._value = v
            self.changed_nc(v, False)

    def set_value(self, v, incremental=False):
        """Set the value by computing diff(v, current) and passing it to
        changed(), which applies it through insert_diff() and sends it."""
        incremental = self._should_encode_incremental(incremental)
        _logger.debug('Setting value of %s to %s, incremental=%s', self._name, v, incremental)
        old = copy.deepcopy(self._value)
        d = self.diff(v, old)
        _logger.debug("Diff= %s", d)
        self.changed(d, incremental)

    def get_value(self):
        """Return the current value."""
        return self._value

    def do_changed_callback(self):
        """Invoke the 'changed' option callback, if any, with the value."""
        if 'changed' in self._options:
            self._options['changed'](self._value)

    def output_diff_stack(self):
        """Log every record of the change history at debug level."""
        _logger.debug('Diff stack:')
        for rec in self._received_diffs:
            _logger.debug('\tv:%d, inc:%s, sender:%s', rec.version_id,
                          rec.incremental, rec.sender)

    def get_version(self, versionid):
        """Return the value the object had at version versionid, or None.

        Starting from the current value, the inverse of every recorded diff
        newer than versionid is applied, newest first.  None is returned for
        negative or future versions, for versions older than the retained
        history, or when the inverse list is shorter than the history.  A
        None inverse is treated as "nothing to undo" and skipped.
        """
        if versionid == self._version_id:
            return self._value
        if versionid > self._version_id or versionid < 0:
            return None
        # The oldest record leads from version_id - 1, so that version is
        # the furthest back we can rewind.
        if not self._received_diffs or \
                versionid < self._received_diffs[0].version_id - 1:
            return None

        obj = self._value
        for i in range(len(self._received_diffs) - 1, -1, -1):
            if self._received_diffs[i].version_id <= versionid:
                break
            if i >= len(self._inverse_diffs):
                return None
            inverse = self._inverse_diffs[i]
            if inverse is None:
                continue
            obj, _ = self._apply_diff_to(obj, inverse)
        return obj

    def inverse_diff(self, diff):
        """Return the stored inverse of a diff, or None if it is unknown.

        diff may be a record with a version_id attribute, or a version id.
        """
        versionid = getattr(diff, 'version_id', diff)
        index = self._get_diff_index(versionid)
        if 0 <= index < len(self._inverse_diffs):
            return self._inverse_diffs[index]
        return None

    def _get_diff_index(self, vi):
        """Index of the newest history record with version_id vi, or -1."""
        for i in range(len(self._received_diffs) - 1, -1, -1):
            if self._received_diffs[i].version_id == vi:
                return i
        return -1

    def insert_diff_nc(self, d):
        """Old (non-contiguous) insertion: place d in version order without
        applying it.

        Returns the insertion index, or -1 if d is older than the oldest
        cached record.  The history is trimmed to _cached_versions entries,
        keeping _inverse_diffs index-aligned with _received_diffs.
        """
        if self._received_diffs and self._received_diffs[0].is_newer(d):
            return -1

        if len(self._received_diffs) >= self._cached_versions:
            del self._received_diffs[0]
            if self._inverse_diffs:
                del self._inverse_diffs[0]

        i = len(self._received_diffs) - 1
        while i > -1 and not d.is_newer(self._received_diffs[i]):
            i -= 1

        self._received_diffs.insert(i + 1, d)
        # Placeholder; process_update_nc() fills in .obj when it applies d.
        self._inverse_diffs.insert(i + 1, DiffRec(d.version_id, d.sender, d.incremental, None))

        return i + 1

    def insert_diff(self, recv_diff, old=None):
        """Place a compatible incremental change at its correct point in the
        history and recompute the current value.

        recv_diff -- record; its version_id is the version it leads up to
        old       -- value at version recv_diff.version_id - 1; looked up
                     with get_version() when None

        Existing records at that version which sort before recv_diff (see
        DiffRec.is_newer) stay in place; the ones after it are renumbered
        one version up and replayed on top of it.  Returns True, also when
        recv_diff is already in the history.
        """
        _logger.debug("insert_diff(): Current versionid:%s, recv_diff:%s", self._version_id, recv_diff)
        if not recv_diff.incremental:
            _logger.debug("insert_diff(): Applying non-incremental diff")
        if recv_diff in self._received_diffs:
            _logger.debug("Dupe")
            return True
        recvi = recv_diff.version_id
        index = self._get_diff_index(recvi)
        if old is None:
            old = self.get_version(recvi - 1)
            # TODO: when the parent version cannot be found the diff should
            # probably be discarded and signalled rather than applied to None.

        _logger.debug("insert_diff(): old=%s", old)
        if index < 0:
            diffs = []
        else:
            diffs = self._received_diffs[index:]
        skipped = 0
        while diffs and recv_diff.is_newer(diffs[0]):
            # This record stays before recv_diff, so its effect must be part
            # of the base recv_diff is applied to.
            old, _ = self._apply_diff_to(old, diffs[0].obj)
            diffs = diffs[1:]
            index += 1
            skipped += 1
        recv_diff.version_id += skipped
        for d in diffs:
            d.version_id += 1
        res = [recv_diff] + diffs

        ret = old
        invdiffs = []
        for d in res:
            _logger.debug("[insert_diff] Applying %s to %s", d.obj, ret)
            ret, inverse = self._apply_diff_to(ret, d.obj)
            invdiffs.append(inverse)
        self._value = ret
        if index >= 0:
            self._received_diffs = self._received_diffs[:index] + res
            self._inverse_diffs = self._inverse_diffs[:index] + invdiffs
            _logger.info("insert_diff(): recv_diff inserted at %d", index)
        else:
            position = len(self._received_diffs)
            self._received_diffs = self._received_diffs + res
            self._inverse_diffs = self._inverse_diffs + invdiffs
            _logger.debug("insert_diff(): recv_diff inserted at %d", position)
        self._version_id += 1
        if not self._version_id == self._received_diffs[-1].version_id:
            _logger.debug("Version inconsistency: expected %s, got %s", self._received_diffs[-1], self._version_id)
            self._version_id = self._received_diffs[-1].version_id
        _logger.debug("new version_id: %s, actualized received_diffs: %s", self._version_id, self._received_diffs)
        return True

    def process_update_nc(self, versionid, incremental, objstr, sender, force=False):
        """Old (non-contiguous) update processing:
            - undo all newer diffs
            - apply the just-added one
            - redo all newer diffs
        Returns False if the update is older than the cached history.
        """
        obj = self.decode(objstr)
        d = DiffRec(versionid, sender, incremental, obj)

        i = self.insert_diff_nc(d)
        _logger.debug('Inserted diff at position %d', i)
        if i == -1:
            return False

        j = len(self._received_diffs) - 1        # Don't include ourselves
        while j > i:
            if not self._received_diffs[j].incremental:
                break        # not necessary to continue beyond replacement object
            self.apply_diff(self._inverse_diffs[j].obj)
            j -= 1

        while j < len(self._received_diffs):
            if not self._received_diffs[j].incremental:
                self._value = self._received_diffs[j].obj
            else:
                self._inverse_diffs[j].obj = self.apply_diff(self._received_diffs[j].obj)
            self._version_id = self._received_diffs[j].version_id
            j += 1

        self.do_changed_callback()

        return True

    def _record_replacement(self, rec, newval):
        """Install newval as the whole current value and append rec (whose
        obj is the diff leading to newval) to the history."""
        prev = self._value
        self._value = newval
        self._received_diffs.append(rec)
        # Keep _inverse_diffs index-aligned with _received_diffs.
        self._inverse_diffs.append(self.diff(prev, newval) if prev is not None else None)
        self._version_id = rec.version_id

    def _apply_pending_diffs(self):
        """Apply buffered out-of-order updates that now directly follow the
        current version."""
        while self._version_id + 1 in self._pending_diffs:
            pd = self._pending_diffs.pop(self._version_id + 1)
            if pd.incremental:
                # insert_diff() records pd and advances self._version_id.
                self.insert_diff(pd, self._value)
            else:
                newval = pd.obj
                # Diff against the value *before* it is replaced.
                pd.obj = self.diff(newval, self._value)
                pd.incremental = True
                self._record_replacement(pd, newval)

    def process_update(self, versionid, incremental, objstr, sender, force=False):
        """Process an update:
            - Undo all newer diffs to get to the common version
            - Check for compatibility between the received update and the
              'combined' newer diffs
            - If compatible, insert the received one
            - If not, report a conflict (rejected changes are not stored yet)
        Returns True if there is no conflict (including duplicates and
        buffered out-of-order diffs) and False if there was one.
        """
        if versionid == self._version_id:
            _logger.debug("Maybe dupe? our versionid:%d, received:%d", self._version_id, versionid)
            a = len(self._received_diffs) > 0
            # Only inspect the last record when there is one.
            b = a and sender == self._received_diffs[-1].sender
            _logger.debug("Two tests: %s, %s", a, b)
            if a and b:
                return True  # Dupe
        # NOTE(review): objstr comes from another participant; see decode().
        obj = self.decode(objstr)
        _logger.debug("process_update: version: %s, obj: %s", versionid, obj)
        if versionid > self._version_id + 1 and incremental and not force:
            # Disordered diffs. Store for later use.
            self._pending_diffs[versionid] = DiffRec(versionid, sender, incremental, obj)
            # TODO: request the missing versions self._version_id+1 .. versionid-1.
            _logger.debug("Disordered diffs. returning")
            return True
        old = self.get_version(versionid - 1)  # supposed common ancestor
        if not incremental:  # we compute the incremental version manually
            obj2 = self.diff(obj, old)
        else:
            obj2 = obj
        db = DiffRec(versionid, sender, True, obj2)
        if versionid > self._version_id or force:  # expected version number
            if not incremental:
                _logger.debug("Update in expected range, non incremental")
                self._record_replacement(db, obj)
            else:
                _logger.debug("Update in expected range, incremental")
                self.insert_diff(db, old)
                self._apply_pending_diffs()
            _logger.debug("Updated value: %s", self._value)
            self.do_changed_callback()
            return True

        _logger.debug("Conflicting versionids: mine:%d, his:%d", self._version_id, versionid)
        obj_da = self.diff(self._value, old)
        if self._received_diffs:
            sender_a = self._received_diffs[-1].sender
        else:
            sender_a = None
        da = DiffRec(self._version_id, sender_a, True, obj_da)
        _logger.debug("Verifying compatibility")
        if not self._compatible_diffs(da.obj, db.obj):
            # TODO: keep the winning change and store the other as rejected.
            _logger.debug("incompatible diff")
            return False
        _logger.debug("compatible diff")
        if self.insert_diff(db, old):
            # insert_diff() renumbers the history and updates self._version_id
            _logger.debug("changed without error")
            self.do_changed_callback()
            return True
        _logger.debug("Error with insertion")
        return False

    def _compatible_diffs(self, da, db):
        """Whether two diffs from a common ancestor can both be applied.
        The base implementation accepts every pair."""
        return True

    def diff(self, new, old):
        """Return a diff turning old into new; the base has no diff (None)."""
        return None

    def apply_diff_to(self, obj, diffobj):
        """Apply diffobj to obj, returning (new_obj, inverse_diff).
        The base implementation returns obj unchanged and diffobj."""
        return (obj, diffobj)

    def _apply_diff_to(self, obj, diffobj):
        """Internal hook used by the history code; defaults to
        apply_diff_to().  Subclasses may override either one."""
        return self.apply_diff_to(obj, diffobj)

    def apply_diff(self, diffobj):
        """Apply diffobj to the current value and return its inverse."""
        (newobj, idiff) = self.apply_diff_to(self._value, diffobj)
        self._value = newobj
        return idiff

    def is_locked(self):
        return self._locked

    def lock(self):
        """Take the lock locally (if free) and announce it via the helper."""
        if not self._locked:
            self._locked = True
            self._locked_by = "me"
            self._locked_time = time.time()
            self._helper.LockObject(self._name, self._locked_time)

    def receive_lock(self, sender, when):
        """Handle a lock announcement from sender taken at time when.

        It wins if the object is unlocked or the current lock is later; the
        'locklost' callback fires when an existing lock is overridden.
        """
        if not self._locked or self._locked_time > when:
            if self._locked and 'locklost' in self._options:
                self._options['locklost']()
            self._locked = True
            self._locked_by = sender
            self._locked_time = when
            if 'locked' in self._options:
                self._options['locked'](sender)

    def unlock(self):
        """Ask the helper to release the lock, if we hold it."""
        # Compare by value: ``is`` on strings depends on interning.
        if self._locked and self._locked_by == 'me':
            self._helper.UnlockObject(self._name)

    def receive_unlock(self, sender):
        """Clear the local lock state and fire the 'unlocked' callback."""
        if self._locked:
            self._locked = False
            self._locked_by = None
            self._locked_time = 0
            if 'unlocked' in self._options:
                self._options['unlocked'](sender)