#!/usr/bin/python3
################################################################################
#
#       This file is part of ghmm (Graph Algorithm Toolbox) 
#
#	file:   HMMEditingContext.py
#	author: Janne Grunau
#
#       Copyright (C) 2007, Janne Grunau
#                                   
#       Contact: schliep@molgen.mpg.de
#
#       Information: http://ghmm.org
#
#       This library is free software; you can redistribute it and/or
#       modify it under the terms of the GNU Library General Public
#       License as published by the Free Software Foundation; either
#       version 2 of the License, or (at your option) any later version.
#
#       This library is distributed in the hope that it will be useful,
#       but WITHOUT ANY WARRANTY; without even the implied warranty of
#       MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#       Library General Public License for more details.
#
#       You should have received a copy of the GNU Library General Public
#       License along with this library; if not, write to the Free
#       Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#
################################################################################

import ghmm
import ghmmwrapper
from ObjectHMM import *


class HMMEditingContext(object):
    """Editable wrapper around an ObjectHMM graph.

    The wrapped graph lives in ``self.HMM``; the ghmm model-type bit flags
    live in ``self.modeltype``.  States and transitions can be accessed with
    container syntax (``ctx[vid]``, ``ctx[name]``, ``ctx[tail, head]``), and
    :meth:`finalize` converts the edited graph back into a ``ghmm`` HMM.
    """

    def __init__(self, parameter, mtype=0):
        """Create an editing context from *parameter*.

        *parameter* may be
          - a ``ghmm.EmissionDomain`` whose ``CDataType`` is ``"int"``
            (discrete HMM over that alphabet) or ``"double"`` (continuous
            HMM with ``ContinuousMixtureDistribution`` emissions);
          - a class object, used as the emission class of a continuous HMM;
          - an existing ``ghmm.HMM``, whose C model is copied into the graph;
          - a ``str`` filename, loaded with :meth:`load`.

        *mtype* holds extra model-type flags.  In the first two cases it is
        OR-ed with the discrete/continuous flag.  For an existing HMM or a
        file, the model type is taken from the source instead.

        Raises ``ghmm.NoValidCDataType`` for an emission domain with an
        unsupported C data type, and ``ghmm.UnknownInputType`` for any other
        kind of *parameter*.
        """
        self.modeltype = mtype
        # check whether the parameter is an EmissionDomain
        if isinstance(parameter, ghmm.EmissionDomain):
            if parameter.CDataType == "int":
                self.HMM = ObjectHMM(State, Transition, alphabet=parameter, etype=0)
                self.modeltype |= ghmmwrapper.kDiscreteHMM
            elif parameter.CDataType == "double":
                self.HMM = ObjectHMM(State, Transition, emissionClass=ContinuousMixtureDistribution, etype=1)
                self.modeltype |= ghmmwrapper.kContinuousHMM
            else:
                raise ghmm.NoValidCDataType("C data type " + str(parameter.CDataType) + " invalid.")
            self.HMM.initHMM(self.modeltype)
        # Distribution parameter (a class object, not an instance)
        elif type(parameter) is type:
            self.HMM = ObjectHMM(State, Transition, emissionClass=parameter, etype=1)
            self.modeltype |= ghmmwrapper.kContinuousHMM
            self.HMM.initHMM(self.modeltype)
        # existing hidden markov model
        elif isinstance(parameter, ghmm.HMM):
            hmm = parameter
            self.HMM = ObjectHMM(State, Transition)
            if isinstance(hmm, ghmm.DiscreteEmissionHMM):
                self.modeltype = hmm.cmodel.model_type
                if hmm.cmodel.alphabet is None:
                    # Fill in a missing C alphabet from the wrapped HMM's own
                    # emission domain before the graph is built from the C
                    # model (this context has no emission domain attribute).
                    hmm.cmodel.alphabet = hmm.emissionDomain.toCstruct()
            else:
                # Use the same k-prefixed flag as the rest of this class.
                self.modeltype = ghmmwrapper.kContinuousHMM
            self.HMM.initHMM(self.modeltype)
            self.HMM.buildFromCModel(hmm.cmodel)
        # filename
        elif isinstance(parameter, str):
            self.load(parameter)
        else:
            raise ghmm.UnknownInputType

    def __str__(self):
        """Return the string form of the wrapped ObjectHMM."""
        return str(self.HMM)

    ## container emulation
    def __getitem__(self, key):
        """Look up a transition or state(s).

        - ``(tail, head)`` tuple: the edge between the two states.  Each
          element may be a state id or a state name; see :meth:`getUniqueId`.
        - ``int``: the vertex with that id.
        - anything else: treated as a state name.  Returns the single
          matching vertex, or a list if the name maps to zero or several ids.
        """
        if isinstance(key, tuple):
            tail, head = [self.getUniqueId(k) for k in key]
            return self.HMM.edges[(tail, head)]
        elif isinstance(key, int):
            return self.HMM.vertices[key]
        else:
            retlist = [self.HMM.vertices[vid] for vid in self.HMM.name2id[key]]
            if len(retlist) == 1:
                return retlist[0]
            else:
                return retlist

    def __setitem__(self, key, value):
        """Not implemented: item assignment is silently ignored."""
        pass

    def __delitem__(self, key):
        """Delete a transition (tuple key), a state (int id), or all states
        with the given name (any other key)."""
        if isinstance(key, tuple):
            tail, head = [self.getUniqueId(k) for k in key]
            self.HMM.DeleteEdge(tail, head)
        elif isinstance(key, int):
            self.deleteState(key)
        else:
            # Iterate over a snapshot: deleting a vertex may remove its id from
            # this very list, which would otherwise make the loop skip ids.
            for vid in list(self.HMM.name2id[key]):
                self.deleteState(vid)

    def __len__(self):
        """Return ``len()`` of the wrapped ObjectHMM."""
        return len(self.HMM)

    def __contains__(self, key):
        """Membership test mirroring :meth:`__getitem__` key forms:
        edge ``(tail, head)``, vertex id (int), or state name."""
        if isinstance(key, tuple):
            tail, head = [self.getUniqueId(k) for k in key]
            return (tail, head) in self.HMM.edges
        elif isinstance(key, int):
            return key in self.HMM.vertices
        else:
            return key in self.HMM.name2id

    def __iter__(self):
        """Not implemented.

        This returns ``None``, so ``iter(ctx)`` / ``for x in ctx`` raises
        TypeError.
        """
        pass


    def load(self, filename):
        """Replace the wrapped graph with one read from the XML file
        *filename*, and take the model type from it."""
        self.HMM = ObjectHMM(State, Transition)
        self.HMM.openXML(filename)
        self.modeltype = self.HMM.modelType


    def save(self, filename):
        """Write the wrapped graph to *filename* as XML."""
        self.HMM.writeXML(filename)


    def finalize(self):
        """Convert the edited graph into a ``ghmm`` HMM object.

        Returns a ``ghmm.ContinuousMixtureHMM`` for continuous models.
        For discrete models without transition classes and not pair HMMs,
        it returns a ``ghmm.StateLabelHMM`` (labeled states) or a
        ``ghmm.DiscreteEmissionHMM``.

        NOTE(review): any other model type (transition classes, pair HMMs,
        or neither discrete nor continuous flag set) falls through and
        returns ``None`` implicitly.
        """
        cmodel = self.HMM.finalize()

        if (self.modeltype & ghmmwrapper.kContinuousHMM):
            return ghmm.ContinuousMixtureHMM(ghmm.Float(),
                                             ghmm.ContinuousMixtureDistribution(ghmm.Float()),
                                             cmodel)

        elif ((self.modeltype & ghmmwrapper.kDiscreteHMM)
              and not (self.modeltype & ghmmwrapper.kTransitionClasses)
              and not (self.modeltype & ghmmwrapper.kPairHMM)):
            emission_domain = ghmm.Alphabet([], cmodel.alphabet)
            if (self.modeltype & ghmmwrapper.kLabeledStates):
                labelDomain = ghmm.LabelDomain([], cmodel.label_alphabet)
                return ghmm.StateLabelHMM(emission_domain,
                                          ghmm.DiscreteDistribution(emission_domain),
                                          labelDomain,
                                          cmodel)

            else:
                return ghmm.DiscreteEmissionHMM(emission_domain,
                                                ghmm.DiscreteDistribution(emission_domain),
                                                cmodel)


    def addState(self, emission=None, initial=-1):
        """Add a new state and return its vertex id.

        *emission* is assigned to the vertex only when not ``None``.
        *initial* is always stored on the vertex (default -1).
        """
        vid = self.HMM.AddVertex()
        if emission is not None:
            self.HMM.vertices[vid].emission = emission

        self.HMM.vertices[vid].initial = initial
        return vid

    def deleteState(self, vid):
        """Remove the vertex with id *vid* from the graph."""
        self.HMM.DeleteVertex(vid)

    def addTransition(self, tail, head, p=-1):
        """Add an edge from *tail* to *head* and set its weight to *p*.

        *tail* and *head* may be vertex ids or unique state names
        (see :meth:`getUniqueId`).
        """
        tail = self.getUniqueId(tail)
        head = self.getUniqueId(head)

        self.HMM.AddEdge(tail, head)
        self.HMM.edges[tail, head].SetWeight(p)

    def getUniqueId(self, name):
        """Resolve a state name to its vertex id.

        Non-string arguments are returned unchanged, assumed to already be
        ids.  Raises ``IdentifierIsAmbigious`` if several states share
        *name*.  NOTE(review): that exception is expected to come from the
        ``ObjectHMM`` star import; an unknown name presumably surfaces as a
        KeyError from ``name2id``.
        """
        if isinstance(name, str):
            tmp = self.HMM.name2id[name]
            if len(tmp) > 1:
                raise IdentifierIsAmbigious(name)
            else:
                return tmp[0]
        else:
            return name



