File: evaluate.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (188 lines) | stat: -rw-r--r-- 6,462 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import argparse
import logging
from typing import Dict, List

import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.models.decoder import ctc_decoder, CTCDecoder, download_pretrained_files
from utils import _get_id2label

# Module-level logger; its output format and level come from the root logging
# configuration that _init_logger() sets up via logging.basicConfig().
logger = logging.getLogger(__name__)


def _load_checkpoint(checkpoint: str) -> torch.nn.Module:
    """Build a HuBERT Base model and fill it with fine-tuned weights from a checkpoint file.

    The checkpoint is expected to hold a ``"state_dict"`` entry. Its keys are remapped
    before loading:

    * keys containing ``"model.wav2vec2"`` have the ``"model.wav2vec2."`` prefix removed;
    * otherwise, keys containing ``"aux"`` are kept unchanged;
    * every other key is discarded.

    Args:
        checkpoint (str): Path to the checkpoint file.

    Returns:
        torch.nn.Module: ``hubert_base`` model with a 29-way auxiliary output layer,
        loaded via ``load_state_dict`` with its default (strict) settings.
    """
    model = torchaudio.models.hubert_base(aux_num_out=29)
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the torch
    # version's weights_only default); only load checkpoints from trusted sources.
    saved = torch.load(checkpoint, map_location="cpu")
    remapped = {}
    for name, tensor in saved["state_dict"].items():
        if "model.wav2vec2" in name:
            remapped[name.replace("model.wav2vec2.", "")] = tensor
        elif "aux" in name:
            remapped[name] = tensor
    model.load_state_dict(remapped)
    return model


def _viterbi_decode(emission: torch.Tensor, id2token: Dict, blank_idx: int = 0) -> List[str]:
    """Greedy (best-path) CTC decoding of a single utterance.

    Takes the most likely token in every frame, collapses runs of identical tokens,
    removes ``blank_idx``, maps the remaining ids through ``id2token`` and treats
    ``"|"`` as a word separator.

    Args:
        emission (torch.Tensor): CTC output with the token dimension last, e.g.
            ``(time, num_tokens)`` or ``(1, time, num_tokens)``. ``unique_consecutive``
            is applied without ``dim``, so any leading dimensions are flattened together;
            pass one utterance at a time.
        id2token (Dictionary): Maps each index of ``emission``'s last dimension to its
            token string.
        blank_idx (int, optional): Index of the CTC blank token. (Default: ``0``)

    Returns:
        (List of str): The decoded words, split on whitespace.
    """
    frame_ids = torch.argmax(emission, dim=-1)
    collapsed = torch.unique_consecutive(frame_ids)
    non_blank = collapsed[collapsed != blank_idx]
    text = "".join([id2token[int(token_id)] for token_id in non_blank])
    # split() with no argument already discards leading/trailing whitespace.
    return text.replace("|", " ").split()


def _ctc_decode(emission, decoder: CTCDecoder) -> List[str]:
    """Beam-search decode an emission with a lexicon/language-model CTC decoder.

    Args:
        emission (torch.Tensor): CTC output handed directly to ``decoder``; per the
            torchaudio ``CTCDecoder`` documentation this is a CPU float tensor of shape
            ``(batch, frame, num_tokens)``.
        decoder (CTCDecoder): The initialized CTCDecoder.

    Returns:
        (List of str): Words of the best hypothesis for the first batch item only,
        with any entries equal to a single space removed.
    """
    # decoder(...) yields one list of ranked hypotheses per batch item.
    best = decoder(emission)[0][0]
    return [w for w in best.words if w != " "]


def _word_error_rate(num_errors, num_words):
    """Return ``num_errors / num_words``, or NaN when no reference words have been seen."""
    return num_errors / num_words if num_words else float("nan")


def run_inference(args):
    """Evaluate a fine-tuned HuBERT checkpoint on a LibriSpeech split and log the WER.

    Args:
        args (argparse.Namespace): Parsed command-line options. Always reads ``use_gpu``,
            ``checkpoint``, ``use_lm``, ``librispeech_path`` and ``split``. The beam-search
            options (``nbest``, ``beam_size``, ``beam_size_token``, ``beam_threshold``,
            ``lm_weight``, ``word_score``, ``unk_score``, ``sil_score``) are read only
            when ``use_lm`` is set.

    Each utterance is decoded either with a lexicon + 4-gram language-model beam-search
    decoder (``use_lm``) or with greedy best-path decoding. The corpus-level word error
    rate is the total word edit distance divided by the total reference word count. It is
    logged every 100 utterances and once more at the end; it is NaN while no reference
    words have been seen.
    """
    device = torch.device("cuda" if args.use_gpu else "cpu")

    # Load the fine-tuned HuBERT model from checkpoint.
    model = _load_checkpoint(args.checkpoint)
    model.eval().to(device)

    decoder = None
    id2token = None
    if args.use_lm:
        # Fetch the pretrained LibriSpeech lexicon, token set and 4-gram language model.
        files = download_pretrained_files("librispeech-4-gram")
        decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            nbest=args.nbest,
            beam_size=args.beam_size,
            beam_size_token=args.beam_size_token,
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            word_score=args.word_score,
            unk_score=args.unk_score,
            sil_score=args.sil_score,
            log_add=False,
        )
    else:
        id2token = _get_id2label()

    dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url=args.split)

    total_edit_distance = 0
    total_length = 0
    for idx, sample in enumerate(dataset):
        waveform, _, transcript, _, _, _ = sample
        # split() with no argument drops surrounding whitespace and treats any internal
        # whitespace (newlines included) as a word boundary, instead of gluing the words
        # on either side of a newline together.
        reference = transcript.lower().split()

        with torch.inference_mode():
            emission, _ = model(waveform.to(device))
            # The LM decoder needs a CPU tensor, and the greedy decoder calls int() on
            # every token id, which would force one device sync per token on a GPU tensor.
            # Transfer once here.
            emission = F.log_softmax(emission, dim=-1).cpu()

        if decoder is not None:
            hypothesis = _ctc_decode(emission, decoder)
        else:
            hypothesis = _viterbi_decode(emission, id2token)

        total_edit_distance += torchaudio.functional.edit_distance(hypothesis, reference)
        total_length += len(reference)

        if idx % 100 == 0:
            logger.info(
                "Processed elem %s; WER: %s",
                idx,
                _word_error_rate(total_edit_distance, total_length),
            )
    logger.info("Final WER: %s", _word_error_rate(total_edit_distance, total_length))


def _parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options. ``--librispeech-path`` and
        ``--checkpoint`` are required because evaluation always needs both. The
        beam-search / language-model options only matter together with ``--use-lm``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--librispeech-path",
        type=str,
        required=True,
        help="Folder where LibriSpeech dataset is stored.",
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["dev-clean", "dev-other", "test-clean", "test-other"],
        help="LibriSpeech dataset split. (Default: 'test-clean')",
        default="test-clean",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="The checkpoint path of fine-tuned HuBERTPretrainModel.",
    )
    parser.add_argument("--use-lm", action="store_true", help="Whether to use language model for decoding.")
    parser.add_argument(
        "--nbest", type=int, default=1, help="Number of best hypotheses to return. (Default: 1)"
    )
    parser.add_argument(
        "--beam-size",
        type=int,
        default=1500,
        help="Beam size for determining number of hypotheses to store. (Default: 1500)",
    )
    parser.add_argument(
        "--beam-size-token",
        type=int,
        default=29,
        help="Number of tokens to consider at each beam search step. (Default: 29)",
    )
    # The decoder's beam threshold is a float score margin, so accept fractional values.
    parser.add_argument(
        "--beam-threshold", type=float, default=100, help="Beam threshold for pruning hypotheses. (Default: 100)"
    )
    parser.add_argument(
        "--lm-weight",
        type=float,
        default=2.46,
        help="Language model weight in decoding. (Default: 2.46)",
    )
    parser.add_argument(
        "--word-score",
        type=float,
        default=-0.59,
        help="Word insertion score in decoding. (Default: -0.59)",
    )
    parser.add_argument(
        "--unk-score", type=float, default=float("-inf"), help="Unknown word insertion score. (Default: -inf)"
    )
    parser.add_argument("--sil-score", type=float, default=0, help="Silence insertion score. (Default: 0)")
    parser.add_argument("--use-gpu", action="store_true", help="Whether to use GPU for decoding.")
    parser.add_argument("--debug", action="store_true", help="Whether to use debug level for logging.")
    return parser.parse_args()


def _init_logger(debug):
    """Configure root logging for the script.

    With ``debug`` set, messages are prefixed with a timestamp and the level is DEBUG;
    otherwise only the bare message is printed at INFO level.
    """
    if debug:
        fmt, level = "%(asctime)s %(message)s", logging.DEBUG
    else:
        fmt, level = "%(message)s", logging.INFO
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def _main():
    """Script entry point: parse CLI flags, set up logging, then run the evaluation."""
    cli_args = _parse_args()
    _init_logger(cli_args.debug)
    run_inference(cli_args)


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    _main()