1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597
|
"""
ASR Inference with CTC Decoder
==============================
**Author**: `Caroline Chen <carolinechen@meta.com>`__
This tutorial shows how to perform speech recognition inference using a
CTC beam search decoder with lexicon constraint and KenLM language model
support. We demonstrate this on a pretrained wav2vec 2.0 model trained
using CTC loss.
"""
######################################################################
# Overview
# --------
#
# Beam search decoding works by iteratively expanding text hypotheses (beams)
# with next possible characters, and maintaining only the hypotheses with the
# highest scores at each time step. A language model can be incorporated into
# the scoring computation, and adding a lexicon constraint restricts the
# next possible tokens for the hypotheses so that only words from the lexicon
# can be generated.
#
# The underlying implementation is ported from `Flashlight <https://arxiv.org/pdf/2201.12465.pdf>`__'s
# beam search decoder. A mathematical formula for the decoder optimization can be
# found in the `Wav2Letter paper <https://arxiv.org/pdf/1609.03193.pdf>`__, and
# a more detailed algorithm can be found in this `blog
# <https://towardsdatascience.com/boosting-your-sequence-generation-performance-with-beam-search-language-model-decoding-74ee64de435a>`__.
#
# Running ASR inference using a CTC beam search decoder with a language
# model and lexicon constraint requires the following components:
#
# - Acoustic Model: model predicting per-frame token scores (emissions)
#   from audio waveforms
# - Tokens: the possible predicted tokens from the acoustic model
# - Lexicon: mapping between possible words and their corresponding
#   token sequences
# - Language Model (LM): n-gram language model trained with the `KenLM
#   library <https://kheafield.com/code/kenlm/>`__, or custom language
#   model that inherits :py:class:`~torchaudio.models.decoder.CTCDecoderLM`
#
######################################################################
# Acoustic Model and Set Up
# -------------------------
#
# First, we import the necessary utilities and fetch the data that we are
# working with.
#
import torch
import torchaudio
# Print the installed torch and torchaudio versions, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")
######################################################################
#
import time
from typing import List
import IPython
import matplotlib.pyplot as plt
from torchaudio.models.decoder import ctc_decoder
from torchaudio.utils import download_asset
######################################################################
#
# We use the pretrained `Wav2Vec 2.0 <https://arxiv.org/abs/2006.11477>`__
# Base model that is fine-tuned on 10 minutes of the `LibriSpeech
# dataset <http://www.openslr.org/12>`__, which can be loaded using
# :data:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M`.
# For more detail on running Wav2Vec 2.0 speech
# recognition pipelines in torchaudio, please refer to `this
# tutorial <./speech_recognition_pipeline_tutorial.html>`__.
#
# Pretrained wav2vec 2.0 Base ASR bundle (CTC fine-tuned on 10 min of
# LibriSpeech). Later cells use it for the acoustic model, the labels, and
# the expected sample rate.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
acoustic_model = bundle.get_model()
######################################################################
# We will load a sample from the LibriSpeech test-other dataset.
#
# Download the sample utterance; the result is what ``torchaudio.load`` reads below.
speech_file = download_asset("tutorial-assets/ctc-decoding/1688-142285-0007.wav")
# Show the clip with IPython's audio display (last expression of the cell).
IPython.display.Audio(speech_file)
######################################################################
# The transcript corresponding to this audio file is
#
# .. code-block::
#
#    i really was very much afraid of showing him how much shocked i was at some parts of what he said
#
waveform, sample_rate = torchaudio.load(speech_file)
# Bring the audio to the bundle's sample rate when the file uses a different one.
waveform = (
    torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
    if sample_rate != bundle.sample_rate
    else waveform
)
######################################################################
# Files and Data for Decoder
# --------------------------
#
# Next, we load our token, lexicon, and language model data, which are used
# by the decoder to predict words from the acoustic model output. Pretrained
# files for the LibriSpeech dataset can be downloaded through torchaudio,
# or the user can provide their own files.
#
######################################################################
# Tokens
# ~~~~~~
#
# The tokens are the possible symbols that the acoustic model can predict,
# including the blank and silent symbols. They can either be passed in as a
# file, where each line contains the token corresponding to that line's
# index, or as a list of tokens, each mapping to a unique index.
#
# .. code-block::
#
#    # tokens.txt
#    _
#    |
#    e
#    t
#    ...
#
# Lower-case every label reported by the bundle.
tokens = list(map(str.lower, bundle.get_labels()))
print(tokens)
######################################################################
# Lexicon
# ~~~~~~~
#
# The lexicon is a mapping from words to their corresponding token
# sequences, and is used to restrict the search space of the decoder to
# only words from the lexicon. The expected format of the lexicon file is
# one line per word, with the word followed by its space-separated tokens.
#
# .. code-block::
#
#    # lexicon.txt
#    a a |
#    able a b l e |
#    about a b o u t |
#    ...
#    ...
#
######################################################################
# Language Model
# ~~~~~~~~~~~~~~
#
# A language model can be used in decoding to improve the results by
# factoring a score that represents the likelihood of the word sequence
# into the beam search computation. Below, we outline the different forms
# of language models that are supported for decoding.
#
######################################################################
# No Language Model
# ^^^^^^^^^^^^^^^^^
#
# To create a decoder instance without a language model, set ``lm=None``
# when initializing the decoder.
#
######################################################################
# KenLM
# ^^^^^
#
# This is an n-gram language model trained with the `KenLM
# library <https://kheafield.com/code/kenlm/>`__. Either the ``.arpa`` or
# the binarized ``.bin`` LM can be used, but the binary format is
# recommended for faster loading.
#
# The language model used in this tutorial is a 4-gram KenLM trained using
# `LibriSpeech <http://www.openslr.org/11>`__.
#
######################################################################
# Custom Language Model
# ^^^^^^^^^^^^^^^^^^^^^
#
# Users can define their own custom language model in Python, whether
# it be a statistical or neural network language model, using
# :py:class:`~torchaudio.models.decoder.CTCDecoderLM` and
# :py:class:`~torchaudio.models.decoder.CTCDecoderLMState`.
#
# For instance, the following code creates a basic wrapper around a PyTorch
# ``torch.nn.Module`` language model.
#
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState
class CustomLM(CTCDecoderLM):
    """Create a Python wrapper around `language_model` to feed to the decoder.

    Scores produced by ``language_model`` are cached in ``self.states``,
    keyed by :py:class:`CTCDecoderLMState`, so each state is scored once.
    """

    def __init__(self, language_model: torch.nn.Module):
        CTCDecoderLM.__init__(self)
        self.language_model = language_model
        self.sil = -1  # index for silent token in the language model
        self.states = {}  # CTCDecoderLMState -> cached language model score

        language_model.eval()

    def start(self, start_with_nothing: bool = False):
        """Return a fresh root state, caching the model's score for ``self.sil``.

        NOTE: ``start_with_nothing`` is accepted but not used by this wrapper.
        """
        state = CTCDecoderLMState()
        with torch.no_grad():
            score = self.language_model(self.sil)

        self.states[state] = score
        return state

    def score(self, state: CTCDecoderLMState, token_index: int):
        """Return the child of ``state`` for ``token_index`` and its cached score."""
        outstate = state.child(token_index)
        if outstate not in self.states:
            # Inference only (as in ``start``): without no_grad every cached
            # score would keep its autograd graph alive inside ``self.states``.
            with torch.no_grad():
                self.states[outstate] = self.language_model(token_index)
        return outstate, self.states[outstate]

    def finish(self, state: CTCDecoderLMState):
        """Score the end of a hypothesis by extending it with ``self.sil``."""
        return self.score(state, self.sil)
######################################################################
# Downloading Pretrained Files
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Pretrained files for the LibriSpeech dataset can be downloaded using
# :py:func:`~torchaudio.models.decoder.download_pretrained_files`.
#
# Note: this cell may take a couple of minutes to run, as the language
# model can be large.
#
from torchaudio.models.decoder import download_pretrained_files
# Fetch the LibriSpeech 4-gram decoder assets. The returned object exposes
# the ``lexicon``, ``tokens`` and ``lm`` files used to build the decoders below.
files = download_pretrained_files("librispeech-4-gram")
print(files)
######################################################################
# Construct Decoders
# ------------------
# In this tutorial, we construct both a beam search decoder and a greedy decoder
# for comparison.
#
######################################################################
# Beam Search Decoder
# ~~~~~~~~~~~~~~~~~~~
# The decoder can be constructed using the factory function
# :py:func:`~torchaudio.models.decoder.ctc_decoder`.
# In addition to the previously mentioned components, it also takes in various beam
# search decoding parameters and token/word parameters.
#
# This decoder can also be run without a language model by passing ``None`` as the
# ``lm`` parameter.
#
# LM weight and per-word score shared by every beam search decoder in this tutorial.
LM_WEIGHT = 3.23
WORD_SCORE = -0.26

beam_search_decoder = ctc_decoder(
    tokens=files.tokens,
    lexicon=files.lexicon,
    lm=files.lm,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
    beam_size=1500,
    nbest=3,  # keep the 3 highest-scoring hypotheses per utterance
)
######################################################################
# Greedy Decoder
# ~~~~~~~~~~~~~~
#
class GreedyCTCDecoder(torch.nn.Module):
    """Best-path (greedy) CTC decoder.

    Args:
        labels (Sequence[str]): Token string for each emission class index.
        blank (int, optional): Class index of the CTC blank token. (Default: ``0``)
    """

    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> List[str]:
        """Given a sequence emission over labels, get the best path.

        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            List[str]: The words of the resulting transcript; the ``"|"``
            token is treated as the word separator.
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        # CTC collapse: merge consecutive repeats first, then drop blanks.
        indices = torch.unique_consecutive(indices, dim=-1)
        # Mask out blanks in one tensor op and convert to plain ints once,
        # instead of iterating over and comparing 0-d tensors in Python.
        indices = indices[indices != self.blank].tolist()
        joined = "".join(self.labels[i] for i in indices)
        return joined.replace("|", " ").strip().split()
# Uses the bundle-derived ``tokens``; the blank is the class default, index 0.
greedy_decoder = GreedyCTCDecoder(tokens)
######################################################################
# Run Inference
# -------------
#
# Now that we have the data, acoustic model, and decoder, we can perform
# inference. The output of the beam search decoder is of type
# :py:class:`~torchaudio.models.decoder.CTCHypothesis`, consisting of the
# predicted token IDs, corresponding words (if a lexicon is provided), hypothesis score,
# and timesteps corresponding to the token IDs. Recall that the transcript corresponding to the
# waveform is
#
# .. code-block::
#
#    i really was very much afraid of showing him how much shocked i was at some parts of what he said
#
actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
actual_transcript = actual_transcript.split()  # reference as a list of words
# Inference only: skip autograd bookkeeping so no computation graph is built
# or kept alive alongside ``emission``.
with torch.inference_mode():
    emission, _ = acoustic_model(waveform)
######################################################################
# The greedy decoder gives the following result.
#
greedy_result = greedy_decoder(emission[0])  # decode the first (only) utterance
greedy_transcript = " ".join(greedy_result)
# WER: word-level edit distance divided by the number of reference words.
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)
print(f"Transcript: {greedy_transcript}", f"WER: {greedy_wer}", sep="\n")
######################################################################
# Using the beam search decoder:
#
beam_search_result = beam_search_decoder(emission)
# ``[0][0]``: first utterance in the batch, first returned hypothesis.
beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
beam_search_wer = (
    torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words)
    / len(actual_transcript)
)
print(f"Transcript: {beam_search_transcript}", f"WER: {beam_search_wer}", sep="\n")
######################################################################
# .. note::
#
#    The :py:attr:`~torchaudio.models.decoder.CTCHypothesis.words`
#    field of the output hypotheses will be empty if no lexicon
#    is provided to the decoder. To retrieve a transcript with lexicon-free
#    decoding, you can perform the following to retrieve the token indices,
#    convert them to original tokens, then join them together.
#
#    .. code::
#
#       tokens_str = "".join(beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens))
#       transcript = " ".join(tokens_str.split("|"))
#
# We see that the transcript with the lexicon-constrained beam search
# decoder produces a more accurate result consisting of real words, while
# the greedy decoder can predict incorrectly spelled words like “affrayd”
# and “shoktd”.
#
######################################################################
# Timestep Alignments
# -------------------
# Recall that one of the components of the resulting hypotheses is the
# timesteps corresponding to the token IDs.
#
# Tokens and their timesteps for the top hypothesis of the first utterance.
timesteps = beam_search_result[0][0].timesteps
predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)
print(predicted_tokens, len(predicted_tokens))
print(timesteps, len(timesteps))
######################################################################
# Below, we visualize the token timestep alignments relative to the original waveform.
#
def plot_alignments(waveform, emission, tokens, timesteps, sample_rate=None):
    """Plot ``waveform`` with each predicted token placed at its timestep.

    Non-``"|"`` tokens are written above the waveform. The span between a
    ``"|"`` and the next ``"|"`` is shaded as one word.

    Args:
        waveform (Tensor): Waveform samples; assumed 1-D (the caller passes
            ``waveform[0]``) since ``waveform.shape[0]`` is used as the sample count.
        emission (Tensor): Acoustic model output; ``emission.shape[1]`` is
            taken as the number of emission frames.
        tokens (List[str]): Predicted tokens, with ``"|"`` as word separator.
        timesteps (Tensor): Emission frame index of each entry in ``tokens``.
        sample_rate (int, optional): Sample rate of ``waveform``, used to label
            the x axis in seconds. Defaults to ``bundle.sample_rate``.
    """
    if sample_rate is None:
        sample_rate = bundle.sample_rate

    _, ax = plt.subplots(figsize=(32, 10))
    ax.plot(waveform)

    # Number of waveform samples covered by one emission frame.
    ratio = waveform.shape[0] / emission.shape[1]
    label_height = waveform.max() * 1.02  # place token labels just above the peak
    word_start = 0
    for i, token in enumerate(tokens):
        if i != 0 and tokens[i - 1] == "|":
            word_start = timesteps[i]
        if token != "|":
            ax.annotate(token.upper(), (timesteps[i] * ratio, label_height), size=14)
        elif i != 0:
            # A "|" closes the word that began after the previous "|".
            ax.axvspan(word_start * ratio, timesteps[i] * ratio, alpha=0.1, color="red")

    # Relabel the sample-index ticks in seconds.
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks / sample_rate)
    ax.set_xlabel("time (sec)")
    ax.set_xlim(0, waveform.shape[0])
plot_alignments(
    waveform=waveform[0],  # first channel
    emission=emission,
    tokens=predicted_tokens,
    timesteps=timesteps,
)
######################################################################
# Beam Search Decoder Parameters
# ------------------------------
#
# In this section, we go a little bit more in depth about some different
# parameters and tradeoffs. For the full list of customizable parameters,
# please refer to the
# :py:func:`documentation <torchaudio.models.decoder.ctc_decoder>`.
#
######################################################################
# Helper Function
# ~~~~~~~~~~~~~~~
#
def print_decoded(decoder, emission, param, param_value):
    """Decode ``emission`` and print the top transcript, its score and the decode time.

    Args:
        decoder: Callable whose result is indexed as ``result[0][0]`` to get a
            hypothesis with ``words`` and ``score`` attributes
            (e.g. a ``ctc_decoder`` instance).
        emission (Tensor): Acoustic model output, passed straight to ``decoder``.
        param (str): Name of the parameter being varied (line prefix).
        param_value: Value of that parameter, left-aligned in a 3-character field.
    """
    # perf_counter is the clock intended for timing short intervals;
    # monotonic may have too coarse a resolution on some platforms for the
    # 4-decimal timing printed below.
    start_time = time.perf_counter()
    result = decoder(emission)
    decode_time = time.perf_counter() - start_time

    best = result[0][0]  # first utterance, first hypothesis
    transcript = " ".join(best.words).lower().strip()
    print(f"{param} {param_value:<3}: {transcript} (score: {best.score:.2f}; {decode_time:.4f} secs)")
######################################################################
# nbest
# ~~~~~
#
# This parameter indicates the number of best hypotheses to return, which
# is a property that is not possible with the greedy decoder. For
# instance, by setting ``nbest=3`` when constructing the beam search
# decoder earlier, we can now access the hypotheses with the top 3 scores.
#
# ``beam_search_result[0]`` holds the hypotheses returned for the first
# utterance (at most ``nbest=3`` here). Iterate them directly rather than
# assuming exactly three, which would raise IndexError if fewer come back.
for hypothesis in beam_search_result[0]:
    transcript = " ".join(hypothesis.words).strip()
    score = hypothesis.score
    print(f"{transcript} (score: {score})")
######################################################################
# beam size
# ~~~~~~~~~
#
# The ``beam_size`` parameter determines the maximum number of best
# hypotheses to hold after each decoding step. Using larger beam sizes
# allows for exploring a larger range of possible hypotheses which can
# produce hypotheses with higher scores, but it is computationally more
# expensive and does not provide additional gains beyond a certain point.
#
# In the example below, we see improvement in decoding quality as we
# increase beam size from 1 to 5 to 50, but notice how using a beam size
# of 500 provides the same output as beam size 50 while increasing the
# computation time.
#
beam_sizes = [1, 5, 50, 500]
for beam_size in beam_sizes:
    # LM weight and word score stay fixed; only ``beam_size`` varies.
    beam_search_decoder = ctc_decoder(
        tokens=files.tokens,
        lexicon=files.lexicon,
        lm=files.lm,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
        beam_size=beam_size,
    )
    print_decoded(beam_search_decoder, emission, "beam size", beam_size)
######################################################################
# beam size token
# ~~~~~~~~~~~~~~~
#
# The ``beam_size_token`` parameter corresponds to the number of tokens to
# consider for expanding each hypothesis at the decoding step. Exploring a
# larger number of next possible tokens increases the range of potential
# hypotheses at the cost of computation.
#
# Largest setting = number of bundle labels (assumed to match the decoder's
# token set), i.e. considering every token at each step.
num_tokens = len(tokens)
beam_size_tokens = [1, 5, 10, num_tokens]
for beam_size_token in beam_size_tokens:
    # LM weight and word score stay fixed; only ``beam_size_token`` varies.
    beam_search_decoder = ctc_decoder(
        tokens=files.tokens,
        lexicon=files.lexicon,
        lm=files.lm,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
        beam_size_token=beam_size_token,
    )
    print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)
######################################################################
# beam threshold
# ~~~~~~~~~~~~~~
#
# The ``beam_threshold`` parameter is used to prune the stored hypotheses
# set at each decoding step, removing hypotheses whose scores are more
# than ``beam_threshold`` below the highest-scoring hypothesis. There
# is a balance between choosing smaller thresholds to prune more
# hypotheses and reduce the search space, and choosing a large enough
# threshold such that plausible hypotheses are not pruned.
#
beam_thresholds = [1, 5, 10, 25]
for beam_threshold in beam_thresholds:
    # LM weight and word score stay fixed; only ``beam_threshold`` varies.
    beam_search_decoder = ctc_decoder(
        tokens=files.tokens,
        lexicon=files.lexicon,
        lm=files.lm,
        lm_weight=LM_WEIGHT,
        word_score=WORD_SCORE,
        beam_threshold=beam_threshold,
    )
    print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)
######################################################################
# language model weight
# ~~~~~~~~~~~~~~~~~~~~~
#
# The ``lm_weight`` parameter is the weight assigned to the language
# model score, which is accumulated with the acoustic model score to
# determine the overall scores. Larger weights encourage the model to
# predict next words based on the language model, while smaller weights
# give more weight to the acoustic model score instead.
#
# No LM contribution, the default weight used above, and a heavily LM-biased weight.
lm_weights = [0, LM_WEIGHT, 15]
for lm_weight in lm_weights:
    # Word score stays fixed; only ``lm_weight`` varies.
    beam_search_decoder = ctc_decoder(
        tokens=files.tokens,
        lexicon=files.lexicon,
        lm=files.lm,
        word_score=WORD_SCORE,
        lm_weight=lm_weight,
    )
    print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)
######################################################################
# additional parameters
# ~~~~~~~~~~~~~~~~~~~~~
#
# Additional parameters that can be optimized include the following:
#
# - ``word_score``: score to add when word finishes
# - ``unk_score``: unknown word appearance score to add
# - ``sil_score``: silence appearance score to add
# - ``log_add``: whether to use log add for lexicon Trie smearing
#
|