1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
#!/usr/bin/env python3
# Copyright (c) 2010 Carnegie Mellon University
#
# You may copy and modify this freely under the same terms as
# Sphinx-III
"""
Rescore a lattice using a language model FST (or a set of them).
"""
__author__ = "David Huggins-Daines <dhdaines@gmail.com>"
__version__ = "$Revision $"
import openfst
from cmusphinx import lattice, lat2fsg
import sys
import os
def lat_rescore(dag, lmfst, lw=9.5):
    """
    Rescore a lattice using a language model FST.

    The lattice is converted to an FST with
    ``lat2fsg.build_lattice_fsg`` (using the language model's input
    symbol table and ``1. / lw`` as its third argument), composed with
    ``lmfst``, and the single shortest path through the composition is
    read back.

    If the language model's input symbol table contains the symbol
    "φ", composition uses phi matchers with that label on the language
    model side; otherwise plain composition is used.

    :param dag: lattice to rescore, as accepted by
                ``lat2fsg.build_lattice_fsg``.
    :param lmfst: language model FST.
    :param lw: language weight; its reciprocal is handed to
               ``lat2fsg.build_lattice_fsg`` (presumably as an
               acoustic score scale -- confirm against that function).
    :returns: tuple ``(words, score)`` where ``words`` is a list
              starting with ``'<s>'`` followed by the symbols on the
              best path, and ``score`` is the negated sum of the arc
              weights along that path.
    """
    lattice_fst = lat2fsg.build_lattice_fsg(dag, lmfst.InputSymbols(),
                                            1. / lw)
    phi = lmfst.InputSymbols().Find("φ")
    if phi == -1:
        # No phi symbol in the language model: ordinary composition.
        composed = openfst.StdComposeFst(lattice_fst, lmfst)
    else:
        compose_opts = openfst.StdPhiComposeOptions()
        compose_opts.matcher1 = openfst.StdPhiMatcher(lattice_fst,
                                                      openfst.MATCH_NONE)
        compose_opts.matcher2 = openfst.StdPhiMatcher(lmfst,
                                                      openfst.MATCH_INPUT,
                                                      phi)
        composed = openfst.StdComposeFst(lattice_fst, lmfst, compose_opts)

    best_path = openfst.StdVectorFst()
    openfst.ShortestPath(composed, best_path, 1)

    # Walk the single best path from its start state, always following
    # the first arc, until a state with no outgoing arcs is reached.
    words = ['<s>']
    score = 0
    state = best_path.Start()
    while state != -1 and best_path.NumArcs(state):
        arc = best_path.GetArc(state, 0)
        # Arcs with an epsilon (0) output label contribute weight but
        # no word.
        if arc.olabel != 0:
            words.append(lmfst.InputSymbols().Find(arc.ilabel))
        score -= arc.weight.Value()
        state = arc.nextstate
    return words, score
if __name__ == '__main__':
    # Command-line driver: rescore every lattice listed in a control
    # file and print the best word sequence and score for each one.
    #
    # Usage: PROG CTL LATDIR [LMFST]
    #   CTL     file with one utterance ID per line
    #   LATDIR  directory holding <uttid>.lat.gz, <uttid>.lat or
    #           <uttid>.slf lattices
    #   LMFST   a single language model FST used for all utterances;
    #           if omitted, --lmnamectl must name a file giving one
    #           language model name per control-file line, each loaded
    #           from <lmdir>/<lmname>.arpa.fst
    from optparse import OptionParser

    parser = OptionParser(usage="%prog CTL LATDIR [LMFST]")
    parser.add_option("--lmnamectl")
    parser.add_option("--lmdir", default=".")
    parser.add_option("--lw", type="float", default=7)
    opts, args = parser.parse_args(sys.argv[1:])

    # Report a usage error rather than failing on tuple unpacking.
    if len(args) < 2:
        parser.error("CTL and LATDIR must be given")
    ctlfile, latdir = args[0:2]

    lmfst = None
    lmnamectl = None
    lmfsts = {}  # cache of per-name language model FSTs
    if len(args) > 2:
        lmfst = openfst.StdVectorFst.Read(args[2])
    elif opts.lmnamectl:
        lmnamectl = open(opts.lmnamectl)
    else:
        parser.error("either --lmnamectl or LMFST must be given")

    try:
        with open(ctlfile) as ctl:
            for line in ctl:
                uttid = line.strip()
                if lmnamectl is not None:
                    lmline = lmnamectl.readline()
                    # readline() returns '' only at end of file, which
                    # means the two control files are out of step.
                    if not lmline:
                        sys.exit("%s has fewer lines than %s"
                                 % (opts.lmnamectl, ctlfile))
                    lmname = lmline.strip()
                    if lmname not in lmfsts:
                        lmfsts[lmname] = openfst.StdVectorFst.Read(
                            os.path.join(opts.lmdir, lmname + ".arpa.fst"))
                    lmfst = lmfsts[lmname]
                # Try the lattice formats in order: gzipped, plain,
                # then HTK SLF.  A failure on the last one propagates.
                try:
                    dag = lattice.Dag(os.path.join(latdir,
                                                   uttid + ".lat.gz"))
                except IOError:
                    try:
                        dag = lattice.Dag(os.path.join(latdir,
                                                       uttid + ".lat"))
                    except IOError:
                        dag = lattice.Dag(
                            htk_file=os.path.join(latdir, uttid + ".slf"))
                words, score = lat_rescore(dag, lmfst, opts.lw)
                print(" ".join(words), "(%s %f)" % (uttid, score))
    finally:
        # The control file is closed by the with-block; close the
        # language model name file here on every exit path.
        if lmnamectl is not None:
            lmnamectl.close()
|