File: inference.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (92 lines) | stat: -rw-r--r-- 2,890 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse

import torch
import torchaudio
from processing import NormalizeDB
from torchaudio.datasets import LJSPEECH
from torchaudio.models import wavernn
from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS
from torchaudio.transforms import MelSpectrogram
from wavernn_inference_wrapper import WaveRNNInferenceWrapper


def parse_args(argv=None):
    """Parse command-line options for the WaveRNN inference demo.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) reads the real command line (``sys.argv[1:]``), which
            matches the original behavior; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace with ``output_wav_path``, ``jit``,
        ``no_batch_inference``, ``no_mulaw``, ``checkpoint_name``,
        ``batch_timesteps`` and ``batch_overlap``.
    """
    parser = argparse.ArgumentParser(
        description="Reconstruct an LJSpeech utterance with a pretrained WaveRNN vocoder."
    )
    parser.add_argument(
        "--output-wav-path",
        default="./output.wav",
        type=str,
        metavar="PATH",
        help="The path to output the reconstructed wav file.",
    )
    parser.add_argument(
        "--jit", default=False, action="store_true", help="If used, the model and inference function are jitted."
    )
    parser.add_argument("--no-batch-inference", default=False, action="store_true", help="Don't use batch inference.")
    parser.add_argument(
        "--no-mulaw", default=False, action="store_true", help="Don't use mulaw decoder to decode the signal."
    )
    parser.add_argument(
        "--checkpoint-name",
        default="wavernn_10k_epochs_8bits_ljspeech",
        # Restrict to the checkpoints torchaudio knows how to build/download.
        choices=list(_MODEL_CONFIG_AND_URLS.keys()),
        help="Select the WaveRNN checkpoint.",
    )
    parser.add_argument(
        "--batch-timesteps",
        default=100,
        type=int,
        help="The time steps for each batch. Only used when batch inference is used",
    )
    parser.add_argument(
        "--batch-overlap",
        default=5,
        type=int,
        help="The overlapping time steps between batches. Only used when batch inference is used",
    )
    return parser.parse_args(argv)


def main(args):
    """Vocode the first LJSpeech utterance with WaveRNN and write it to disk.

    The waveform is converted to a normalized mel spectrogram, passed through
    the selected pretrained WaveRNN (optionally TorchScript-compiled), and the
    reconstructed audio is saved as a wav file.

    Args:
        args: Namespace as returned by ``parse_args``. Reads
            ``checkpoint_name``, ``jit``, ``no_mulaw``,
            ``no_batch_inference``, ``batch_timesteps``, ``batch_overlap``
            and ``output_wav_path``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Downloads LJSpeech into the working directory if needed and takes the
    # first sample; only the waveform and its sample rate are used.
    waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0]

    # Feature settings for the mel front-end. NOTE(review): these presumably
    # have to match what the chosen checkpoint was trained on (e.g. 80 mels,
    # hop 275) -- confirm against the checkpoint config if changing them.
    mel_kwargs = {
        "sample_rate": sample_rate,
        "n_fft": 2048,
        "f_min": 40.0,
        "n_mels": 80,
        "win_length": 1100,
        "hop_length": 275,
        "mel_scale": "slaney",
        "norm": "slaney",
        "power": 1,
    }
    transforms = torch.nn.Sequential(
        MelSpectrogram(**mel_kwargs),
        NormalizeDB(min_level_db=-100, normalization=True),
    )
    mel_specgram = transforms(waveform)

    wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
    wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)

    if args.jit:
        wavernn_inference_model = torch.jit.script(wavernn_inference_model)

    with torch.no_grad():
        output = wavernn_inference_model(
            mel_specgram.to(device),
            mulaw=(not args.no_mulaw),
            batched=(not args.no_batch_inference),
            timesteps=args.batch_timesteps,
            overlap=args.batch_overlap,
        )

    # The model may have run on CUDA; move the result to host memory before
    # handing it to the file writer. `.cpu()` returns the tensor unchanged if
    # it already lives on the CPU.
    torchaudio.save(args.output_wav_path, output.cpu(), sample_rate=sample_rate)


if __name__ == "__main__":
    # Script entry point: read the CLI options and run the vocoder demo.
    main(parse_args())