File: lat_rescore_fst.py

package info (click to toggle)
sphinxtrain 5.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 32,572 kB
  • sloc: ansic: 94,052; perl: 8,939; python: 6,702; cpp: 2,044; makefile: 6
file content (80 lines) | stat: -rw-r--r-- 2,705 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3

# Copyright (c) 2010  Carnegie Mellon University
#
# You may copy and modify this freely under the same terms as
# Sphinx-III
"""
Rescore a lattice using a language model FST (or a set of them).
"""

__author__ = "David Huggins-Daines <dhdaines@gmail.com>"
__version__ = "$Revision $"

import openfst
from cmusphinx import lattice, lat2fsg
import sys
import os


def lat_rescore(dag, lmfst, lw=9.5):
    """
    Rescore a lattice using a language model FST.

    The lattice is converted to an FSG over the LM's input symbol table by
    lat2fsg.build_lattice_fsg (with scale 1/lw), composed with the LM FST,
    and the single best path of the composition is read back.

    If the LM's input symbol table contains "&phi;", composition is done
    with phi matchers on that label; otherwise a plain composition is used.

    Returns a tuple (words, score): words is a list starting with '<s>'
    followed by the symbols found on the best path, and score is the
    negated sum of that path's arc weights (0 if the path has no arcs).
    """
    symbols = lmfst.InputSymbols()
    lattice_fsg = lat2fsg.build_lattice_fsg(dag, symbols, 1. / lw)

    phi_id = symbols.Find("&phi;")
    if phi_id == -1:
        # No phi symbol in the LM: ordinary composition.
        composed = openfst.StdComposeFst(lattice_fsg, lmfst)
    else:
        # Only the LM side (matcher2) follows phi arcs, matching on its
        # input labels; the lattice side uses MATCH_NONE.
        phi_opts = openfst.StdPhiComposeOptions()
        phi_opts.matcher1 = openfst.StdPhiMatcher(lattice_fsg,
                                                  openfst.MATCH_NONE)
        phi_opts.matcher2 = openfst.StdPhiMatcher(lmfst,
                                                  openfst.MATCH_INPUT,
                                                  phi_id)
        composed = openfst.StdComposeFst(lattice_fsg, lmfst, phi_opts)

    best = openfst.StdVectorFst()
    openfst.ShortestPath(composed, best, 1)

    hypothesis = ['<s>']
    total = 0
    state = best.Start()
    # Follow the first arc out of each state; stop when the state id is -1
    # or the state has no outgoing arcs.
    while state != -1 and best.NumArcs(state):
        arc = best.GetArc(state, 0)
        # NOTE(review): the epsilon test is on olabel but the word is looked
        # up by ilabel -- presumably they agree on non-epsilon arcs of this
        # composition; confirm.
        if arc.olabel != 0:
            hypothesis.append(symbols.Find(arc.ilabel))
        total -= arc.weight.Value()
        state = arc.nextstate
    # NOTE(review): the final weight of the last state is not added to the
    # score -- confirm this is intended.
    return hypothesis, total


def _load_dag(latdir, uttid):
    """
    Load the lattice for utterance ``uttid`` from directory ``latdir``.

    Tries UTTID.lat.gz and then UTTID.lat with lattice.Dag, and falls back
    to an HTK SLF lattice UTTID.slf.  An IOError raised by that last
    attempt propagates to the caller.
    """
    for ext in (".lat.gz", ".lat"):
        try:
            return lattice.Dag(os.path.join(latdir, uttid + ext))
        except IOError:
            # Missing or unreadable in this format: try the next one.
            continue
    return lattice.Dag(htk_file=os.path.join(latdir, uttid + ".slf"))


def main(argv=None):
    """
    Command-line entry point: rescore every lattice listed in a control file.

    Usage: %prog CTL LATDIR [LMFST]

    Each stripped line of CTL is used as an utterance ID, and its lattice is
    loaded from LATDIR.  The LM FST is one of the following:
      * LMFST, used for every utterance; or
      * with --lmnamectl, read from LMDIR/NAME.arpa.fst, where NAME is the
        line of the --lmnamectl file that corresponds to the utterance.
        Loaded FSTs are cached by name.
    For each utterance, "WORDS (UTTID SCORE)" is printed.

    argv defaults to sys.argv[1:] (optparse's default when None).
    Exits with a usage error when CTL/LATDIR are missing or no LM is given.
    Exits with a message when the --lmnamectl file runs out of names
    before CTL does.
    """
    from optparse import OptionParser
    parser = OptionParser(usage="%prog CTL LATDIR [LMFST]")
    parser.add_option("--lmnamectl")
    parser.add_option("--lmdir", default=".")
    parser.add_option("--lw", type="float", default=7)
    opts, args = parser.parse_args(argv)
    if len(args) < 2:
        parser.error("CTL and LATDIR must be given")
    ctlfile, latdir = args[0:2]

    lmnames = None  # per-utterance LM name file (only with --lmnamectl)
    lmfsts = {}     # LM FSTs already loaded, keyed by LM name
    if len(args) > 2:
        lmfst = openfst.StdVectorFst.Read(args[2])
    elif opts.lmnamectl:
        lmnames = open(opts.lmnamectl)
    else:
        parser.error("either --lmnamectl or LMFST must be given")

    try:
        with open(ctlfile) as ctl:
            for line in ctl:
                uttid = line.strip()
                if lmnames is not None:
                    # The LM name file is read in lockstep with CTL:
                    # one LM name per utterance.
                    lmname = lmnames.readline().strip()
                    if not lmname:
                        # Without this check an exhausted name file would
                        # silently load LMDIR/.arpa.fst.
                        sys.exit("%s: no LM name for utterance %s"
                                 % (opts.lmnamectl, uttid))
                    if lmname not in lmfsts:
                        lmfsts[lmname] = openfst.StdVectorFst.Read(
                            os.path.join(opts.lmdir, lmname + ".arpa.fst"))
                    lmfst = lmfsts[lmname]
                dag = _load_dag(latdir, uttid)
                words, score = lat_rescore(dag, lmfst, opts.lw)
                print(" ".join(words), "(%s %f)" % (uttid, score))
    finally:
        if lmnames is not None:
            lmnames.close()


if __name__ == '__main__':
    main()