# Copyright (c) 2006 Carnegie Mellon University
#
# You may copy and modify this freely under the same terms as
# Sphinx-III
"""Read/write Sphinx-III Gaussian mixture weight files.
This module reads and writes the Gaussian mixture weight files used by
SphinxTrain, Sphinx-III, and PocketSphinx.
"""
__author__ = "David Huggins-Daines <dhdaines@gmail.com>"
__version__ = "$Revision$"
from .s3file import S3File, S3File_write
import os
def open(filename, mode="rb"):
    """Open a Sphinx-III mixture weight file for reading or writing.

    Args:
        filename: path of the mixture weight file.
        mode: "r" or "rb" returns an S3MixwFile, which parses the file
            immediately on construction; "w" or "wb" returns an
            S3MixwFile_write.

    Raises:
        ValueError: if ``mode`` is not one of "r", "rb", "w", "wb".
    """
    if mode in ("r", "rb"):
        return S3MixwFile(filename)
    if mode in ("w", "wb"):
        return S3MixwFile_write(filename)
    # ValueError subclasses Exception, so callers that catch Exception
    # keep working; naming the bad value makes the mistake obvious.
    raise ValueError("mode must be 'r', 'rb', 'w', or 'wb', not %r" % (mode,))
class S3MixwFile(S3File):
    """Read Sphinx-III format mixture weight files.

    The file is parsed eagerly: the constructor validates the header
    version and stores the result of ``read3d()`` in ``self._params``.
    Presumably that is a 3-D array of weights; TODO confirm against
    S3File.read3d.
    """

    def __init__(self, filename, mode="rb"):
        """Open ``filename`` (``mode`` is passed through to S3File) and load it."""
        super().__init__(filename=filename, mode=mode)
        self._params = self._load()

    def readgauheader(self):
        """Check that the header's "version" attribute is exactly "1.0".

        Raises:
            ValueError: if the version is anything other than "1.0".
                ValueError subclasses Exception, so existing handlers
                still catch it.
        """
        version = self.fileattr["version"]
        if version != "1.0":
            # %-formatting (rather than "+") also copes with a non-str value.
            raise ValueError(
                "Version mismatch: must be 1.0 but is %s" % (version,))

    def _load(self):
        """Validate the header, seek to the data section and read it."""
        self.readgauheader()
        # data_start is supplied by S3File; presumably the offset just past
        # the header (TODO confirm).
        self.fh.seek(self.data_start, 0)
        return self.read3d()
class S3MixwFile_write(S3File_write):
    """Writer for Sphinx-III format mixture weight files.

    Everything except ``writeall`` is inherited from S3File_write.
    """

    def writeall(self, stuff):
        """Write ``stuff`` as the file's data via the inherited ``write3d``.

        ``stuff`` is presumably a 3-D weight array, as the name write3d
        suggests (TODO confirm). Always returns None.
        """
        emit = self.write3d
        emit(stuff)
def accumdirs(accumdirs):
    """Read and accumulate mixture weight counts from several directories.

    For each directory ``d``, the file ``os.path.join(d, "mixw_counts")`` is
    loaded as an S3MixwFile. Loading is best-effort: any OSError raised while
    opening or reading that file causes the directory to be skipped.

    Args:
        accumdirs: iterable of directory paths.

    Returns:
        The S3MixwFile loaded from the first readable directory, with the
        ``_params`` of every later readable file added into its ``_params``
        in place, or None if no directory could be read. The addition is
        element-wise if ``_params`` is a numpy array, as read3d presumably
        returns (TODO confirm).
    """
    mixw = None
    for d in accumdirs:
        path = os.path.join(d, "mixw_counts")
        try:
            submixw = S3MixwFile(path, "rb")
        except OSError:
            # Missing or unreadable counts file: this directory contributes
            # nothing.
            continue
        if mixw is None:
            mixw = submixw
        else:
            mixw._params += submixw._params
    return mixw