#!/usr/bin/env python3

# Copyright (c) 2010  Carnegie Mellon University
#
# You may copy and modify this freely under the same terms as
# Sphinx-III
"""
FST utility functions
"""

__author__ = "David Huggins-Daines <dhdaines@gmail.com>"
__version__ = "$Revision $"

import sys
import os
import tempfile
import openfst
import sphinxbase
import math


class AutoFst(openfst.StdVectorFst):
    """
    FST class which automatically adds states, input and output symbols as required.

    This is meant to behave somewhat like the Dot language: arcs and
    final states may be given by symbolic name instead of by integer
    ID, and anything which does not exist yet is created on demand.
    """
    def __init__(self, isyms=None, osyms=None, ssyms=None):
        """
        Create an FST with a single (start) state.

        Any symbol table which is not supplied is created here.  The
        fresh input and output tables get "&epsilon;" as their first
        symbol and the fresh state table gets "__START__" as its first
        symbol.
        """
        openfst.StdVectorFst.__init__(self)
        if isyms is None:
            isyms = openfst.SymbolTable("inputs")
            isyms.AddSymbol("&epsilon;")
        if osyms is None:
            osyms = openfst.SymbolTable("outputs")
            osyms.AddSymbol("&epsilon;")
        if ssyms is None:
            ssyms = openfst.SymbolTable("states")
            ssyms.AddSymbol("__START__")
        self.ssyms = ssyms
        self.SetInputSymbols(isyms)
        self.SetOutputSymbols(osyms)
        self.SetStart(self.AddState())

    def _ensure_state(self, state):
        """
        Resolve a state name or ID to an ID, creating states as needed.

        Non-integer states are looked up in (or added to) the state
        symbol table.  States are then added until the resulting ID
        exists in the FST.
        """
        if not isinstance(state, int):
            state = self.ssyms.AddSymbol(state)
        while state >= self.NumStates():
            self.AddState()
        return state

    def AddArc(self, src, isym, osym, weight, dest):
        """
        Add an arc, creating any symbols and states it refers to.

        Each of src, isym, osym and dest may be either an integer ID
        or a name; names are added to the corresponding symbol table.
        """
        if not isinstance(isym, int):
            isym = self.isyms.AddSymbol(isym)
        if not isinstance(osym, int):
            osym = self.osyms.AddSymbol(osym)
        src = self._ensure_state(src)
        dest = self._ensure_state(dest)
        openfst.StdVectorFst.AddArc(self, src, isym, osym, weight, dest)

    def Write(self, *args):
        """
        Write the FST, re-attaching the current symbol tables first.
        """
        openfst.StdVectorFst.SetInputSymbols(self, self.isyms)
        openfst.StdVectorFst.SetOutputSymbols(self, self.osyms)
        openfst.StdVectorFst.Write(self, *args)

    def SetFinal(self, state, weight=0):
        """
        Mark a state (given by ID or by name) as final.

        The state is created first if it does not exist yet, so that
        a final state may be declared before any arc reaches it.
        """
        state = self._ensure_state(state)
        openfst.StdVectorFst.SetFinal(self, state, weight)

    def SetInputSymbols(self, isyms):
        """
        Set the input symbol table, remembering it for AddArc and Write.
        """
        self.isyms = isyms
        openfst.StdVectorFst.SetInputSymbols(self, self.isyms)

    def SetOutputSymbols(self, osyms):
        """
        Set the output symbol table, remembering it for AddArc and Write.
        """
        self.osyms = osyms
        openfst.StdVectorFst.SetOutputSymbols(self, self.osyms)


def add_mgram_states(fst, symtab, lm, m, sidtab, bo_label=0):
    """
    Add states and arcs for all M-grams in the language model, where M<N.

    fst is modified in place.  symtab is used to look up word labels
    (Find() returning -1 marks an out-of-vocabulary word).  lm must
    provide mgrams(m), yielding entries with `words`, `log_prob` and
    `log_bowt` attributes.  m is the zero-based index of the last word
    of each entry (so m == 0 selects the 1-grams).  sidtab maps tuples
    of words to state IDs and is updated in place with every state
    created here.  bo_label is the label placed on backoff arcs.

    State 0 is treated as the backoff state and must already exist.
    """
    for mg in lm.mgrams(m):
        wsym = symtab.Find(mg.words[m])
        if wsym == -1:
            continue  # skip mgrams ending in OOV
        if m > 0 and mg.words[0] == '</s>':
            continue  # skip >1-grams starting with </s>
        if m == 0:
            src = 0  # 1-grams start in backoff state
        elif tuple(mg.words[0:m]) not in sidtab:
            continue  # no state for the history (e.g. it has an OOV)
        else:
            src = sidtab[tuple(mg.words[0:m])]
        if mg.words[m] == '</s>':
            # only one final state is allowed, so every M-gram ending in
            # </s> shares the state registered under ('</s>',)
            final = True
            newstate = False
            if ('</s>', ) in sidtab:
                dest = sidtab[('</s>', )]
            else:
                dest = fst.AddState()
                fst.SetFinal(dest, 0)
                sidtab[('</s>', )] = dest
                # print "Final state", dest
                # print "Entered state ID mapping (</s>,) =>", dest
        else:
            final = False
            newstate = True
            dest = fst.AddState()

        if mg.words[m] == '<s>':
            # <s> is a non-event: no arc is added for it
            if m == 0:
                # The destination state will be the initial state
                fst.SetStart(dest)
                # print "Initial state", dest
        else:
            # Arc weights are negated log probabilities
            fst.AddArc(src, openfst.StdArc(wsym, wsym, -mg.log_prob, dest))
            # print "Added %d-gram arc %d => %d %s/%.4f" % (m+1, src, dest,
            # mg.words[m], -mg.log_prob)

        if newstate:
            # Add a new state to the mapping if needed
            sidtab[tuple(mg.words)] = dest
            # print "Entered state ID mapping", tuple(mg.words), "=>", dest

        if not final:
            # Create a backoff arc to the suffix M-1-gram
            # Note that if mg.log_bowt == 0 it's particularly important to do this!
            if m == 0:
                bo_state = 0  # backoff state
            elif tuple(mg.words[1:]) in sidtab:
                bo_state = sidtab[tuple(mg.words[1:])]
            else:
                continue  # Not a 1-gram, no suffix M-gram
            fst.AddArc(
                dest, openfst.StdArc(bo_label, bo_label, -mg.log_bowt,
                                     bo_state))
            # print "Adding backoff arc %d => %d %.4f" % (dest, bo_state, -mg.log_bowt)


def add_ngram_arcs(fst, symtab, lm, n, sidtab):
    """
    Add states and arcs for all N-grams in the language model, where
    N=N (the order of the model, that is).

    No new states are created: each N-gram becomes an arc from the
    state of its (N-1)-word history to the state of its longest
    suffix which has a state in sidtab.  N-grams are skipped when the
    predicted word or any history word is out of vocabulary, when the
    predicted word is <s>, when the history contains </s>, or when no
    state exists for the history.

    Raises RuntimeError if no suffix of an N-gram has a state.
    """
    for ng in lm.mgrams(n - 1):
        word = ng.words[n - 1]
        wsym = symtab.Find(word)
        if wsym == -1:  # OOV
            continue
        if word == '<s>':  # non-event
            continue
        history = tuple(ng.words[:n - 1])
        if '</s>' in history:
            continue
        # Skip the whole N-gram (not just one history word) on an OOV
        if any(symtab.Find(w) == -1 for w in history):
            continue
        # As in add_mgram_states, a history without a state cannot be
        # the source of an arc
        src = sidtab.get(history)
        if src is None:
            continue
        # Find longest suffix N-gram that exists
        for spos in range(1, n):
            suffix = tuple(ng.words[spos:])
            if suffix in sidtab:
                break
        else:
            raise RuntimeError("Unable to find suffix N-gram for %r" %
                               (tuple(ng.words), ))
        dest = sidtab[suffix]
        # Arc weights are negated log probabilities
        fst.AddArc(src, openfst.StdArc(wsym, wsym, -ng.log_prob, dest))


def build_lmfst(lm, use_phi=False):
    """
    Build an FST recognizer from an N-gram backoff language model.

    Backoff arcs are labelled with "&phi;" when use_phi is true and
    with "&epsilon;" otherwise.  The returned FST has been connected
    and arc-sorted on its input labels.
    """
    # Word symbol table: epsilon first, then (optionally) phi, then
    # every unigram in the model.
    words = openfst.SymbolTable("words")
    eps_label = words.AddSymbol("&epsilon;")
    backoff_label = words.AddSymbol("&phi;") if use_phi else eps_label
    for unigram in lm.mgrams(0):
        words.AddSymbol(unigram.words[0])

    machine = openfst.StdVectorFst()
    machine.SetInputSymbols(words)
    machine.SetOutputSymbols(words)

    # The algorithm goes like this:
    #
    # Create a backoff state
    # For M in 1 to N-1:
    #  For each M-gram w(1,M):
    #   Create a state q(1,M)
    #   Create an arc from state q(1,M-1) to q(1,M) with weight P(w(1,M))
    #   Create an arc from state q(1,M) to q(2,M) with weight bowt(w(1,M-1))
    # For each N-gram w(1,N):
    #   Create an arc from state q(1,N-1) to q(2,N) with weight P(w(1,N))

    # M-gram (as a tuple of words) => state ID
    states = {}
    machine.AddState()  # the backoff state, expected to get ID zero
    order = lm.get_size()
    for m in range(order - 1):
        add_mgram_states(machine, words, lm, m, states, backoff_label)
    add_ngram_arcs(machine, words, lm, order, states)

    # Drop unreachable states and sort arcs by input label
    openfst.Connect(machine)
    openfst.ArcSortInput(machine)
    return machine


class SphinxProbdef(object):
    """
    Probability definition file used for Sphinx class language models.

    The `classes` attribute maps each class name to a dictionary of
    word => probability.
    """
    def __init__(self, infile=None):
        """
        Create a probability definition, optionally reading it from
        infile (any iterable of text lines, such as an open file).
        """
        self.classes = {}
        if infile is not None:
            self.read(infile)

    def read(self, infile):
        """
        Read probability definition from a file.

        Blank lines and lines starting with '#' or ';' are ignored.
        A class starts with "LMCLASS <name>" and ends with
        "END <name>"; each line in between is a word optionally
        followed by its probability (default 1.0).  Lines outside of
        any class are ignored.
        """
        current = None  # name of the class being read, if any
        for line in infile:
            line = line.strip()
            if not line or line.startswith(('#', ';')):
                continue
            parts = line.split()
            if current is None:
                if line.startswith('LMCLASS'):
                    _, current = parts
                    self.add_class(current)
                continue
            if len(parts) == 2 and parts[0] == "END" and parts[1] == current:
                current = None
                continue
            prob = float(parts[1]) if len(parts) > 1 else 1.0
            self.add_class_word(current, parts[0], prob)

    def add_class(self, name):
        """
        Add a class to this probability definition.

        Any existing class of the same name is replaced by an empty one.
        """
        self.classes[name] = {}

    def add_class_word(self, name, word, prob):
        """
        Add a word to a class in this probability definition.
        """
        self.classes[name][word] = prob

    def write(self, outfile):
        """
        Write out probability definition to a file.

        The output uses the same LMCLASS ... END format accepted by
        read().
        """
        for name, words in self.classes.items():
            outfile.write("LMCLASS %s\n" % name)
            # Iterate over (word, probability) pairs, not bare keys
            for word, prob in words.items():
                outfile.write("%s %g\n" % (word, prob))
            outfile.write("END %s\n" % name)
            outfile.write("\n")

    def normalize(self):
        """
        Normalize probabilities.

        Each class is scaled so that its probabilities sum to one.
        Classes whose probabilities sum to zero are left untouched.
        """
        for words in self.classes.values():
            total = sum(words.values())
            if total != 0:
                for w in words:
                    words[w] /= total


def build_classfst(probdef, isyms=None):
    """
    Build an FST from the classes in a Sphinx probability definition
    file.  This transducer maps words to classes, and can either be
    composed with the input, or pre-composed with the language model.
    In the latter case you can project the resulting transducer to its
    input to obtain an equivalent non-class-based model.

    probdef may be a SphinxProbdef or anything accepted by its
    constructor.  If isyms is given it is used as (and extended into)
    the symbol table of the result; otherwise a new table is created.
    """
    if not isinstance(probdef, SphinxProbdef):
        probdef = SphinxProbdef(probdef)

    machine = openfst.StdVectorFst()
    if isyms:
        table = isyms
    else:
        table = openfst.SymbolTable("words")
        table.AddSymbol("&epsilon;")

    # A single state which is both initial and final; every arc loops
    # back onto it.
    hub = machine.AddState()
    machine.SetStart(hub)
    machine.SetFinal(hub, 0)

    # Identity arcs for every symbol already known, except epsilon
    for _name, label in table:
        if label != openfst.epsilon:
            machine.AddArc(hub, label, label, 0, hub)

    # word => class arcs weighted by negative log probability
    for cname, members in probdef.classes.items():
        class_label = table.AddSymbol(cname)
        for member, p in members.items():
            member_label = table.AddSymbol(member)
            machine.AddArc(hub, member_label, class_label, -math.log(p), hub)

    machine.SetOutputSymbols(table)
    machine.SetInputSymbols(table)
    return machine


def build_class_lmfst(lm, probdef, use_phi=False):
    """
    Build an FST from a class-based language model.  By default this
    returns the lazy composition of the class definition transducer
    and the language model.  To obtain the full language model, create
    a VectorFst from it and project it to its input.
    """
    lm_machine = build_lmfst(lm, use_phi)
    # The class transducer shares the language model's symbol table
    class_machine = build_classfst(probdef, lm_machine.InputSymbols())
    for machine in (lm_machine, class_machine):
        openfst.ArcSortInput(machine)
    return openfst.StdComposeFst(class_machine, lm_machine)


def build_dictfst(lmfst):
    """
    Build a character-to-word FST based on the symbol table of lmfst.

    Each word in the input symbol table of lmfst (other than label 0)
    gets its own path from the start state to the single final state:
    an epsilon:word arc followed by one letter:epsilon arc per
    character.  The end-of-sentence marker '</s>' is treated as one
    single input symbol rather than being spelled out.
    """
    def spelling(word):
        # Use a single symbol for end-of-sentence
        return [word] if word == '</s>' else word

    letters = openfst.SymbolTable("letters")
    letters.AddSymbol("&epsilon;")
    words = lmfst.InputSymbols()

    machine = openfst.StdVectorFst()
    start = machine.AddState()
    machine.SetStart(start)
    final = machine.AddState()
    machine.SetFinal(final, 0)

    # First pass: register every letter in the input symbol table
    for word, label in words:
        if label == 0:
            continue
        for unit in spelling(word):
            letters.AddSymbol(unit)

    # Second pass: one chain of states per word
    for word, label in words:
        if label == 0:
            continue
        label = words.Find(word)
        # Add an epsilon:word arc to the first state of this word
        current = machine.AddState()
        machine.AddArc(start, openfst.StdArc(0, label, 0, current))
        for unit in spelling(word):
            following = machine.AddState()
            machine.AddArc(
                current, openfst.StdArc(letters.Find(unit), 0, 0, following))
            current = following
        # And an epsilon arc to the final state
        machine.AddArc(current, openfst.StdArc(0, 0, 0, final))

    machine.SetInputSymbols(letters)
    machine.SetOutputSymbols(words)
    return machine


def fst2pdf(fst, outfile, acceptor=False):
    """
    Draw an FST as a PDF using fstdraw and dot.

    The FST is written to a temporary file, which is rendered to
    outfile by the shell pipeline "fstdraw ... | dot -Tpdf".  If
    acceptor is true, "--acceptor" is passed to fstdraw.

    Returns the value of os.system() for that pipeline.  The
    temporary directory is removed even if writing or drawing fails.
    """
    import shlex

    flag = "--acceptor" if acceptor else ""
    with tempfile.TemporaryDirectory() as tempdir:
        fstfile = os.path.join(tempdir, "output.fst")
        fst.Write(fstfile)
        # Quote the paths so that spaces and quote characters in file
        # names cannot break (or inject into) the shell command.
        return os.system("fstdraw %s %s | dot -Tpdf > %s" %
                         (flag, shlex.quote(fstfile),
                          shlex.quote(str(outfile))))


def sent2fst(txt, fstclass=openfst.StdVectorFst, isyms=None, omitstart=True):
    """
    Convert a list of words, or a string of whitespace-separated
    tokens, to a sentence FST.

    The result is a linear chain with one arc per word.  If isyms is
    given, it is used as the symbol table and words which it does not
    contain are silently dropped; otherwise a new table is built from
    the words seen.  If omitstart is true, '<s>' tokens are dropped.

    The last state of the chain is final.  For an empty sentence (or
    one in which every word was dropped) that is the start state.
    """
    fst = fstclass()
    start = fst.AddState()
    fst.SetStart(start)
    if isyms:
        symtab = isyms
    else:
        symtab = openfst.SymbolTable("words")
        symtab.AddSymbol("&epsilon;")
    if isinstance(txt, str):
        txt = txt.split()
    prev = start
    for word in txt:
        if omitstart and word == '<s>':
            continue
        if isyms:
            sym = isyms.Find(word)
            if sym == -1:
                # Unknown word: skip it without allocating a state
                continue
        else:
            sym = symtab.AddSymbol(word)
        # Only allocate a state once we know an arc will reach it, so
        # that skipped words leave no orphan states behind.
        nxt = fst.AddState()
        fst.AddArc(prev, sym, sym, 0, nxt)
        prev = nxt
    # prev is always the end of the chain, even if no arc was added
    fst.SetFinal(prev, 0)
    fst.SetInputSymbols(symtab)
    fst.SetOutputSymbols(symtab)
    return fst


def str2fst(txt, fstclass=openfst.StdVectorFst):
    """
    Convert a text string to an FST.

    The result is a linear chain with one arc per character, using a
    freshly built "chars" symbol table for both inputs and outputs.
    The last state of the chain is final; for an empty string that is
    the start state.
    """
    fst = fstclass()
    start = fst.AddState()
    fst.SetStart(start)
    symtab = openfst.SymbolTable("chars")
    symtab.AddSymbol("&epsilon;")
    prev = start
    for c in txt:
        nxt = fst.AddState()
        sym = symtab.AddSymbol(c)
        fst.AddArc(prev, sym, sym, 0, nxt)
        prev = nxt
    # prev is the end of the chain, and is defined even for empty input
    fst.SetFinal(prev, 0)
    fst.SetInputSymbols(symtab)
    fst.SetOutputSymbols(symtab)
    return fst


def strset2fst(strs, fstclass=openfst.StdVectorFst):
    """
    Build a dictionary lookup FST for a set of strings.

    Every string is first added as its own chain of character arcs
    followed by an arc emitting the whole string; the result is then
    determinized and has its epsilons removed.  Input symbols are
    characters, output symbols are the strings themselves.
    """
    raw = fstclass()
    chars = openfst.SymbolTable("chars")
    words = openfst.SymbolTable("words")
    chars.AddSymbol("&epsilon;")
    words.AddSymbol("&epsilon;")
    root = raw.AddState()
    raw.SetStart(root)
    for entry in strs:
        here = root
        # One char:epsilon arc per character
        for ch in entry:
            there = raw.AddState()
            raw.AddArc(here, chars.AddSymbol(ch), 0, 0, there)
            here = there
        # Then a final epsilon:word arc into a final state
        last = raw.AddState()
        raw.AddArc(here, 0, words.AddSymbol(entry), 0, last)
        raw.SetFinal(last, 0)
    compact = fstclass()
    openfst.Determinize(raw, compact)
    openfst.RmEpsilon(compact)
    compact.SetInputSymbols(chars)
    compact.SetOutputSymbols(words)
    return compact


def lmfst_eval(lmfst, sent):
    """
    Score a sentence with a language model FST.

    The sentence is converted to an FST over the input symbols of
    lmfst and composed with it, using phi matchers when the symbol
    table contains "&phi;".  The weights along the single shortest
    path of the composition are negated and summed, and that total is
    returned.
    """
    vocab = lmfst.InputSymbols()
    sentfst = sent2fst(sent, openfst.StdVectorFst, vocab)
    phi = lmfst.InputSymbols().Find("&phi;")
    if phi == -1:
        composed = openfst.StdComposeFst(sentfst, lmfst)
    else:
        opts = openfst.StdPhiComposeOptions()
        opts.matcher1 = openfst.StdPhiMatcher(sentfst, openfst.MATCH_NONE)
        opts.matcher2 = openfst.StdPhiMatcher(lmfst, openfst.MATCH_INPUT, phi)
        composed = openfst.StdComposeFst(sentfst, lmfst, opts)
    best = openfst.StdVectorFst()
    openfst.ShortestPath(composed, best, 1)
    # Walk the path from the start state, following the first arc out
    # of each state until a state without arcs is reached.
    total = 0
    state = best.Start()
    while state != -1 and best.NumArcs(state):
        arc = best.GetArc(state, 0)
        total -= arc.weight.Value()
        state = arc.nextstate
    return total


def lm_eval(lm, sent):
    """
    Score a sentence with a language model.

    Tokens starting with '++' are removed first.  For every remaining
    token other than '<s>', lm.prob() is called with that token
    followed by all of its predecessors in reverse order, and the
    results are summed.
    """
    words = [tok for tok in sent.split() if not tok.startswith('++')]
    return sum(
        lm.prob(words[pos::-1])
        for pos, tok in enumerate(words)
        if tok != '<s>')


if __name__ == '__main__':
    # Usage: <this script> LM_FILE FST_FILE
    # Reads an N-gram language model and writes the equivalent FST.
    lm_path, fst_path = sys.argv[1:]
    model = sphinxbase.NGramModel(lm_path)
    build_lmfst(model).Write(fst_path)
