1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
|
#!/usr/bin/env python3
import os
import sys
from cmusphinx import lattice
import sphinxbase
def _read_lines(path):
    """Return every line of the text file at *path*, newlines included."""
    # The context manager closes the handle even when reading raises.
    with open(path, 'r') as handle:
        return handle.readlines()


def _reference_words(line):
    """Build the reference word list from one transcription line.

    The last whitespace-separated token is dropped (presumably the
    utterance id -- TODO confirm against the transcription format),
    ``<s>`` and ``</s>`` are added when they are missing, and every token
    for which ``lattice.is_filler`` returns true is removed.
    """
    # Slicing, unlike ``del words[-1]``, does not raise on a blank line.
    words = line.split()[:-1]
    # The emptiness test covers a transcript that holds no words at all,
    # where indexing ``words[0]`` would raise IndexError.
    if not words or words[0] != '<s>':
        words.insert(0, '<s>')
    if words[-1] != '</s>':
        words.append('</s>')
    return [word for word in words if not lattice.is_filler(word)]


def _prune_lattice(dag, lm, lw, abeam, nbeam):
    """Score *dag* with *lm* and prune it in place.

    Edges are pruned with beam *abeam* (forward, then backward) and nodes
    with beam *nbeam*; ``remove_unreachable`` is called after each stage.
    *lw* is the language model weight handed to ``edges_unigram_score``.
    """
    dag.bypass_fillers()
    dag.remove_unreachable()
    # Scores and posteriors are computed before any pruning call
    # (presumably because the pruning passes consume them -- TODO confirm).
    dag.edges_unigram_score(lm, lw)
    dag.dt_posterior()
    # edge pruning
    print("\t edge pruning ...")
    dag.forward_edge_prune(abeam)
    dag.backward_edge_prune(abeam)
    dag.remove_unreachable()
    # node pruning
    print("\t node pruning ...")
    dag.post_node_prune(nbeam)
    dag.remove_unreachable()


def main(argv):
    """Prune a range of lattices and print summary statistics.

    *argv* is the full argument vector: the program name followed by
    ABEAM NBEAM LMWEIGHT LMFILE DENLATDIR PRUNED_DENLATDIR FILELIST
    TRANSFILE FILECOUNT FILEOFFSET.  Entries FILEOFFSET up to (but not
    including) FILEOFFSET + FILECOUNT of FILELIST are processed: each
    ``<name>.lat.gz`` is read from DENLATDIR, pruned, and written under
    the same name to PRUNED_DENLATDIR.

    Returns the process exit status: 1 after printing the usage message
    when the argument count is wrong, 0 otherwise.
    """
    if len(argv) != 11:
        sys.stderr.write(
            "Usage: %s ABEAM NBEAM LMWEIGHT LMFILE DENLATDIR "
            "PRUNED_DENLATDIR FILELIST TRANSFILE FILECOUNT FILEOFFSET\n"
            % (argv[0]))
        return 1

    # print command line (each argument is followed by a single space)
    print("%s\n" % ''.join(arg + ' ' for arg in argv))

    (abeam, nbeam, lw, lmfile, denlatdir, pruned_denlatdir,
     ctlfile, transfile, filecount, fileoffset) = argv[1:]
    abeam = float(abeam)
    nbeam = float(nbeam)
    lw = float(lw)
    start = int(fileoffset)
    end = start + int(filecount)

    # load language model
    lm = sphinxbase.NGramModel(lmfile)

    # read control file and transcription file; line i of one is paired
    # with line i of the other
    ctl = _read_lines(ctlfile)
    ref = _read_lines(transfile)

    sentcount = 0
    nodecount = 0
    edgecount = 0
    wer = 0.0
    density = 0.0

    # prune lattices one by one
    for i in range(start, end):
        c = ctl[i].strip()
        r = _reference_words(ref[i])
        print("process sent: %s" % c)

        # load lattice
        print("\t load lattice ...")
        dag = lattice.Dag(os.path.join(denlatdir, c + ".lat.gz"))
        _prune_lattice(dag, lm, lw, abeam, nbeam)

        # calculate error; the second returned value is not needed here
        err, _ = dag.minimum_error(r)

        # save pruned lattice
        print("\t saving pruned lattice ...\n")
        dag.dag2sphinx(os.path.join(pruned_denlatdir, c + ".lat.gz"))

        n_edges = dag.n_edges()
        sentcount += 1
        nodecount += dag.n_nodes()
        edgecount += n_edges
        # Both per-sentence figures are normalised by the reference length.
        wer += float(err) / len(r)
        density += float(n_edges) / len(r)

    if sentcount == 0:
        # An empty range (FILECOUNT <= 0) would otherwise divide by zero.
        print("No sentences processed")
    else:
        print("Average Lattice Word Error Rate: %.2f%%"
              % (wer / sentcount * 100))
        print("Average Lattice Density: %.2f" % (float(density) / sentcount))
        print("Average Number of Node: %.2f" % (float(nodecount) / sentcount))
        print("Average Number of Arc: %.2f" % (float(edgecount) / sentcount))
    print("ALL DONE")
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))
|