1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
#!/usr/bin/env python3
"""Evaluate the lightning module by loading the checkpoint, the SentencePiece model, and the global_stats.json.
Example:
python eval.py --model-type tedlium3 --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt
--dataset-path ./datasets/tedlium --sp-model-path ./spm_bpe_500.model
"""
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torch
import torchaudio
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_MUSTC, MODEL_TYPE_TEDLIUM3
from librispeech.lightning import LibriSpeechRNNTModule
from mustc.lightning import MuSTCRNNTModule
from tedlium3.lightning import TEDLIUM3RNNTModule
# Module-level logger; output format and level are set up on the root logger by init_logger().
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
    """Return the word-level edit distance between two strings.

    Both inputs are lower-cased and split on whitespace before being handed
    to ``torchaudio.functional.edit_distance``, so the comparison is
    case-insensitive and operates on whole words rather than characters.

    Args:
        seq1: First string (the reference transcript at the call site in this file).
        seq2: Second string (the model output at the call site in this file).

    Returns:
        Whatever ``torchaudio.functional.edit_distance`` returns for the two word lists.
    """
    first_words = seq1.lower().split()
    second_words = seq2.lower().split()
    return torchaudio.functional.edit_distance(first_words, second_words)
def _word_error_rate(total_edit_distance, total_length):
    """Return ``total_edit_distance / total_length``, or NaN when no reference words were seen."""
    if total_length == 0:
        # WER is undefined without reference words; report NaN instead of
        # letting a ZeroDivisionError abort the whole evaluation run.
        return float("nan")
    return total_edit_distance / total_length


def run_eval_subset(model, dataloader, subset):
    """Evaluate ``model`` on one data subset and log its word error rate (WER).

    Iterates over ``dataloader`` with gradients disabled, accumulating the
    word-level edit distance between each reference transcript and the model
    output, together with the number of reference words. A running WER is
    logged every 100 elements and the final WER once the loader is exhausted.

    Args:
        model: Callable invoked as ``model(batch)``; its result is compared
            against the reference via ``compute_word_level_distance``, so it
            must be a string.
        dataloader: Iterable yielding ``(batch, transcripts)`` pairs.
        subset: Name of the subset, used only in the final log message.

    Returns:
        None. Results are reported through the module logger.
    """
    total_edit_distance = 0
    total_length = 0
    with torch.no_grad():
        for idx, (batch, transcripts) in enumerate(dataloader):
            # Only the first transcript of each item is scored.
            # NOTE(review): presumably each item holds a single utterance -- confirm.
            actual = transcripts[0]
            predicted = model(batch)
            total_edit_distance += compute_word_level_distance(actual, predicted)
            total_length += len(actual.split())
            if idx % 100 == 0:
                running_wer = _word_error_rate(total_edit_distance, total_length)
                logger.info(f"Processed elem {idx}; WER: {running_wer}")
    final_wer = _word_error_rate(total_edit_distance, total_length)
    logger.info(f"Final WER for {subset} set: {final_wer}")
def run_eval(model, model_type):
    """Run evaluation on every subset associated with ``model_type``.

    For each supported model type, the relevant dataloaders are obtained from
    ``model`` (all of them first, in order) and then each one is passed to
    ``run_eval_subset`` with its subset label:

    * LibriSpeech: ``test``
    * TED-LIUM 3: ``dev``, ``test``
    * MuST-C: ``dev``, ``tst-COMMON``, ``tst-HE``

    Args:
        model: Lightning module exposing the dataloader factory methods.
        model_type: One of the ``MODEL_TYPE_*`` constants.

    Raises:
        ValueError: If ``model_type`` is not one of the supported types.
    """
    # (model type, ((dataloader factory method name, subset label), ...))
    evaluation_plans = (
        (MODEL_TYPE_LIBRISPEECH, (("test_dataloader", "test"),)),
        (MODEL_TYPE_TEDLIUM3, (("dev_dataloader", "dev"), ("test_dataloader", "test"))),
        (
            MODEL_TYPE_MUSTC,
            (
                ("dev_dataloader", "dev"),
                ("test_common_dataloader", "tst-COMMON"),
                ("test_he_dataloader", "tst-HE"),
            ),
        ),
    )

    for supported_type, plan in evaluation_plans:
        if model_type == supported_type:
            break
    else:
        raise ValueError(f"Encountered unsupported model type {model_type}.")

    # Build every dataloader up front, then evaluate them in the same order.
    prepared = [(getattr(model, factory_name)(), label) for factory_name, label in plan]
    for loader, label in prepared:
        run_eval_subset(model, loader, label)
def get_lightning_module(args):
    """Load the Lightning module matching ``args.model_type`` from a checkpoint.

    The module class and the name of its dataset-path keyword argument depend
    on the model type; the SentencePiece model path and global stats path are
    passed under the same keyword for every type. Path arguments other than
    the checkpoint path are converted to ``str`` before being forwarded.

    Args:
        args: Parsed command-line namespace providing ``model_type``,
            ``checkpoint_path``, ``dataset_path``, ``sp_model_path`` and
            ``global_stats_path``.

    Returns:
        The result of ``<ModuleClass>.load_from_checkpoint(...)``.

    Raises:
        ValueError: If ``args.model_type`` is not one of the supported types.
    """
    # (model type, module class, keyword name used for the dataset path)
    registry = (
        (MODEL_TYPE_LIBRISPEECH, LibriSpeechRNNTModule, "librispeech_path"),
        (MODEL_TYPE_TEDLIUM3, TEDLIUM3RNNTModule, "tedlium_path"),
        (MODEL_TYPE_MUSTC, MuSTCRNNTModule, "mustc_path"),
    )
    selected = next((entry for entry in registry if args.model_type == entry[0]), None)
    if selected is None:
        raise ValueError(f"Encountered unsupported model type {args.model_type}.")

    _, module_cls, dataset_keyword = selected
    load_kwargs = {
        dataset_keyword: str(args.dataset_path),
        "sp_model_path": str(args.sp_model_path),
        "global_stats_path": str(args.global_stats_path),
    }
    return module_cls.load_from_checkpoint(args.checkpoint_path, **load_kwargs)
def parse_args():
    """Parse the evaluation script's command-line arguments.

    The module docstring is used as the help description and is rendered
    verbatim via ``RawTextHelpFormatter``.

    ``--model-type``, ``--checkpoint-path``, ``--dataset-path`` and
    ``--sp-model-path`` are required: none has a usable default, and a
    missing value would otherwise reach the model loader as ``None`` (or the
    string ``"None"``) and fail far from the real cause.

    Returns:
        argparse.Namespace with ``model_type``, ``checkpoint_path``,
        ``global_stats_path``, ``dataset_path``, ``sp_model_path``,
        ``use_cuda`` and ``debug`` attributes.
    """
    parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        "--model-type",
        type=str,
        choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC],
        required=True,
        help="Type of model to evaluate.",
    )
    parser.add_argument(
        "--checkpoint-path",
        type=pathlib.Path,
        required=True,
        help="Path to checkpoint to use for evaluation.",
    )
    parser.add_argument(
        "--global-stats-path",
        default=pathlib.Path("global_stats.json"),
        type=pathlib.Path,
        help="Path to JSON file containing feature means and stddevs.",
    )
    parser.add_argument(
        "--dataset-path",
        type=pathlib.Path,
        required=True,
        help="Path to dataset.",
    )
    parser.add_argument(
        "--sp-model-path",
        type=pathlib.Path,
        required=True,
        help="Path to SentencePiece model.",
    )
    parser.add_argument(
        "--use-cuda",
        action="store_true",  # store_true already defaults to False
        help="Run using CUDA.",
    )
    parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
    return parser.parse_args()
def init_logger(debug):
    """Configure root logging for the script.

    Args:
        debug: When truthy, log at DEBUG level and prefix each message with a
            timestamp; otherwise log at INFO level with the bare message.
    """
    if debug:
        log_level, log_format = logging.DEBUG, "%(asctime)s %(message)s"
    else:
        log_level, log_format = logging.INFO, "%(message)s"
    logging.basicConfig(format=log_format, level=log_level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
    """Command-line entry point: parse arguments, load the model and evaluate it.

    Logging is initialised from ``--debug`` before the checkpoint is loaded.
    The loaded module is moved to the ``"cuda"`` device only when
    ``--use-cuda`` was given.
    """
    options = parse_args()
    init_logger(options.debug)
    module = get_lightning_module(options)
    module = module.to(device="cuda") if options.use_cuda else module
    run_eval(module, options.model_type)
# Run the evaluation CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli_main()
|