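"""Run CTC inference with a fine-tuned HuBERT model on the LibriSpeech dataset
and report the word error rate (WER), optionally decoding with a KenLM 4-gram
language model via torchaudio's CTC beam-search decoder."""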
import argparse
import logging
from typing import Dict, List

import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.models.decoder import ctc_decoder, CTCDecoder, download_pretrained_files

from utils import _get_id2label

logger = logging.getLogger(__name__)


def _load_checkpoint(checkpoint_path: str) -> torch.nn.Module:
    """Load a fine-tuned HuBERTPretrainModel checkpoint into a ``hubert_base`` model."""
    model = torchaudio.models.hubert_base(aux_num_out=29)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    # Strip the training wrapper's "model.wav2vec2." prefix so the keys match
    # torchaudio's hubert_base module, and keep the CTC head ("aux") weights as-is.
    new_state_dict = {}
    for k in state_dict:
        if "model.wav2vec2" in k:
            new_state_dict[k.replace("model.wav2vec2.", "")] = state_dict[k]
        elif "aux" in k:
            new_state_dict[k] = state_dict[k]
    model.load_state_dict(new_state_dict)
    return model
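
# Note: the checkpoint is assumed to come from a Lightning-style HuBERT
# fine-tuning recipe whose wrapper stores the encoder as ``model.wav2vec2``
# and the character-level CTC head as ``aux`` (29 outputs); the key remapping
# in _load_checkpoint relies on that layout.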


def _viterbi_decode(emission: torch.Tensor, id2token: Dict, blank_idx: int = 0) -> List[str]:
    """Run greedy (best-path) decoding on CTC emissions.

    Args:
        emission (torch.Tensor): Output of the CTC layer. Tensor with dimensions `(..., time, num_tokens)`.
        id2token (Dict): Dictionary mapping indices of the emission's last dimension
            to the corresponding tokens.
        blank_idx (int, optional): Index of the CTC blank token. (Default: 0)

    Returns:
        List[str]: The decoded words, in lower case.
    """
    # Pick the most likely token per frame, collapse consecutive repeats, drop blanks.
    hypothesis = emission.argmax(-1).unique_consecutive()
    hypothesis = hypothesis[hypothesis != blank_idx]
    # "|" is the word-boundary token in the character vocabulary.
    hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ").strip()
    return hypothesis.split()
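
# A minimal sketch of the greedy collapse above on a toy two-token vocabulary
# (the emission values and id2token mapping are hypothetical, for illustration):
#
#   emission = torch.tensor([[0.1, 0.9], [0.1, 0.9], [0.9, 0.1]])  # argmax per frame: 1, 1, 0
#   _viterbi_decode(emission, {1: "a"})  # repeats collapse, blank (0) dropped -> ["a"]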


def _ctc_decode(emission: torch.Tensor, decoder: CTCDecoder) -> List[str]:
    """Run CTC beam-search decoding with a KenLM language model.

    Args:
        emission (torch.Tensor): Output of the CTC layer. Tensor with dimensions `(..., time, num_tokens)`.
        decoder (CTCDecoder): The initialized CTCDecoder.

    Returns:
        List[str]: The decoded words, in lower case.
    """
    hypothesis = decoder(emission)
    # Keep the words of the best hypothesis and drop silence tokens.
    hypothesis = hypothesis[0][0].words
    hypothesis = [word for word in hypothesis if word != " "]
    return hypothesis
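
# A minimal sketch of building and using the decoder on its own, assuming the
# pretrained LibriSpeech 4-gram files (as downloaded in run_inference below):
#
#   files = download_pretrained_files("librispeech-4-gram")
#   decoder = ctc_decoder(lexicon=files.lexicon, tokens=files.tokens, lm=files.lm)
#   words = _ctc_decode(emission.cpu(), decoder)  # emission: (1, time, num_tokens), on CPU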


def run_inference(args):
    if args.use_gpu:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Load the fine-tuned HuBERTPretrainModel from the checkpoint.
    model = _load_checkpoint(args.checkpoint)
    model.eval().to(device)

    if args.use_lm:
        # Download the lexicon, token list, and KenLM 4-gram files for the decoder.
        files = download_pretrained_files("librispeech-4-gram")
        decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            nbest=args.nbest,
            beam_size=args.beam_size,
            beam_size_token=args.beam_size_token,
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            word_score=args.word_score,
            unk_score=args.unk_score,
            sil_score=args.sil_score,
            log_add=False,
        )
    else:
        id2token = _get_id2label()

    dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url=args.split)
    total_edit_distance = 0
    total_length = 0
    for idx, sample in enumerate(dataset):
        waveform, _, transcript, _, _, _ = sample
        transcript = transcript.strip().lower().replace("\n", "")
        with torch.inference_mode():
            emission, _ = model(waveform.to(device))
            emission = F.log_softmax(emission, dim=-1)
        if args.use_lm:
            hypothesis = _ctc_decode(emission.cpu(), decoder)
        else:
            hypothesis = _viterbi_decode(emission, id2token)
        # WER = word-level edit distance accumulated over the reference length.
        total_edit_distance += torchaudio.functional.edit_distance(hypothesis, transcript.split())
        total_length += len(transcript.split())
        if idx % 100 == 0:
            logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
    logger.info(f"Final WER: {total_edit_distance / total_length}")


def _parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--librispeech-path",
        type=str,
        help="Folder where the LibriSpeech dataset is stored.",
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["dev-clean", "dev-other", "test-clean", "test-other"],
        help="LibriSpeech dataset split. (Default: 'test-clean')",
        default="test-clean",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="The checkpoint path of the fine-tuned HuBERTPretrainModel.",
    )
    parser.add_argument("--use-lm", action="store_true", help="Whether to use a language model for decoding.")
    parser.add_argument("--nbest", type=int, default=1, help="Number of best hypotheses to return. (Default: 1)")
    parser.add_argument(
        "--beam-size",
        type=int,
        default=1500,
        help="Beam size for determining the number of hypotheses to store. (Default: 1500)",
    )
    parser.add_argument(
        "--beam-size-token",
        type=int,
        default=29,
        help="Number of tokens to consider at each beam search step. (Default: 29)",
    )
    parser.add_argument(
        "--beam-threshold", type=int, default=100, help="Beam threshold for pruning hypotheses. (Default: 100)"
    )
    parser.add_argument(
        "--lm-weight",
        type=float,
        default=2.46,
        help="Language model weight in decoding. (Default: 2.46)",
    )
    parser.add_argument(
        "--word-score",
        type=float,
        default=-0.59,
        help="Word insertion score in decoding. (Default: -0.59)",
    )
    parser.add_argument(
        "--unk-score", type=float, default=float("-inf"), help="Unknown word insertion score. (Default: -inf)"
    )
    parser.add_argument("--sil-score", type=float, default=0, help="Silence insertion score. (Default: 0)")
    parser.add_argument("--use-gpu", action="store_true", help="Whether to use GPU for decoding.")
    parser.add_argument("--debug", action="store_true", help="Whether to use debug level for logging.")
    return parser.parse_args()


def _init_logger(debug):
    fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def _main():
    args = _parse_args()
    _init_logger(args.debug)
    run_inference(args)


if __name__ == "__main__":
    _main()
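
# Example invocation (the script name and paths are placeholders; adjust to
# your environment):
#
#   python inference.py \
#       --librispeech-path /path/to/LibriSpeech/root \
#       --split test-clean \
#       --checkpoint /path/to/finetuned_hubert.ckpt \
#       --use-lm \
#       --use-gpu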