File: lattice_prune.py

package info (click to toggle)
sphinxtrain 5.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 32,572 kB
  • sloc: ansic: 94,052; perl: 8,939; python: 6,702; cpp: 2,044; makefile: 6
file content (100 lines) | stat: -rwxr-xr-x 2,735 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3

import os
import sys
from cmusphinx import lattice
import sphinxbase

def _read_lines(path):
    """Return every line of the text file at *path*, newlines included."""
    with open(path, 'r') as f:
        return f.readlines()


def _reference_words(line):
    """Build the word sequence used to score one utterance.

    The last whitespace-separated token of *line* is dropped (presumably
    the ``(utterance_id)`` suffix of a Sphinx transcription line -- confirm
    against the data).  ``<s>`` / ``</s>`` are added when missing.  Every
    word for which ``lattice.is_filler`` is true is then removed.

    Raises:
        ValueError: if *line* contains no tokens at all.
    """
    words = line.split()
    if not words:
        raise ValueError("empty transcription line: %r" % line)
    del words[-1]
    # A line holding only the utterance id leaves no words; it is scored as
    # an empty sentence consisting of just the sentence markers.
    if not words or words[0] != '<s>':
        words.insert(0, '<s>')
    if words[-1] != '</s>':
        words.append('</s>')
    return [w for w in words if not lattice.is_filler(w)]


def _prune_lattice(lat_path, lm, lw, abeam, nbeam):
    """Load the lattice at *lat_path*, prune it, and return the pruned Dag.

    Edges are first scored with ``dag.edges_unigram_score(lm, lw)`` and
    ``dag.dt_posterior()`` is run; presumably this supplies the posteriors
    the beams below are applied to (semantics live in ``cmusphinx.lattice``).
    *abeam* is passed to the forward/backward edge pruning and *nbeam* to
    the node pruning.  Unreachable parts are removed after each step.
    """
    print("\t load lattice ...")
    dag = lattice.Dag(lat_path)
    dag.bypass_fillers()
    dag.remove_unreachable()

    # score edges before pruning
    dag.edges_unigram_score(lm, lw)
    dag.dt_posterior()

    print("\t edge pruning ...")
    dag.forward_edge_prune(abeam)
    dag.backward_edge_prune(abeam)
    dag.remove_unreachable()

    print("\t node pruning ...")
    dag.post_node_prune(nbeam)
    dag.remove_unreachable()
    return dag


def main(argv=None):
    """Prune a slice of denominator lattices and print lattice statistics.

    *argv* defaults to ``sys.argv`` and must hold the program name plus the
    10 positional arguments listed in the usage message.  For every line
    index in ``[FILEOFFSET, FILEOFFSET + FILECOUNT)`` the lattice
    ``DENLATDIR/<ctl entry>.lat.gz`` is pruned and written to
    ``PRUNED_DENLATDIR/<ctl entry>.lat.gz``.  Averages of lattice error
    rate, density, node count and arc count are printed at the end.

    Exits with status 1 on a wrong argument count or when the requested
    slice does not fit in FILELIST / TRANSFILE.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 11:
        sys.stderr.write(
            "Usage: %s ABEAM NBEAM LMWEIGHT LMFILE DENLATDIR "
            "PRUNED_DENLATDIR FILELIST TRANSFILE FILECOUNT FILEOFFSET\n"
            % (argv[0]))
        sys.exit(1)

    # echo the command line into the log
    print("%s\n" % " ".join(argv))

    abeam, nbeam, lw, lmfile, denlatdir, pruned_denlatdir, \
        ctlfile, transfile, filecount, fileoffset = argv[1:]

    abeam = float(abeam)
    nbeam = float(nbeam)
    lw = float(lw)
    start = int(fileoffset)
    end = start + int(filecount)

    ctl = _read_lines(ctlfile)
    ref = _read_lines(transfile)

    # Validate the slice before loading the LM or writing any output, so a
    # bad range fails fast instead of crashing half-way through the run
    # (and so a negative offset does not silently wrap around the list).
    if start < 0 or end > min(len(ctl), len(ref)):
        sys.stderr.write(
            "Error: lines [%d, %d) requested but FILELIST has %d lines "
            "and TRANSFILE has %d lines\n"
            % (start, end, len(ctl), len(ref)))
        sys.exit(1)

    # load language model
    lm = sphinxbase.NGramModel(lmfile)

    sentcount = 0
    wer = 0.0
    nodecount = 0
    edgecount = 0
    density = 0.0
    # prune lattices one by one
    for i in range(start, end):
        c = ctl[i].strip()
        r = _reference_words(ref[i])

        print("process sent: %s" % c)
        dag = _prune_lattice(os.path.join(denlatdir, c + ".lat.gz"),
                             lm, lw, abeam, nbeam)

        # only the error count is used; the second return value is ignored
        err, _ = dag.minimum_error(r)

        print("\t saving pruned lattice ...\n")
        dag.dag2sphinx(os.path.join(pruned_denlatdir, c + ".lat.gz"))

        n_edges = dag.n_edges()
        sentcount += 1
        nodecount += dag.n_nodes()
        edgecount += n_edges
        # len(r) >= 2 because the sentence markers are always present
        wer += float(err) / len(r)
        density += float(n_edges) / len(r)

    if sentcount == 0:
        # avoid dividing by zero when the requested slice is empty
        print("No sentences processed; no averages to report.")
    else:
        print("Average Lattice Word Error Rate: %.2f%%"
              % (wer / sentcount * 100))
        print("Average Lattice Density: %.2f" % (density / sentcount))
        print("Average Number of Node: %.2f" % (float(nodecount) / sentcount))
        print("Average Number of Arc: %.2f" % (float(edgecount) / sentcount))
    print("ALL DONE")


if __name__ == '__main__':
    main()