File: build_pipeline_from_fairseq.py

#!/usr/bin/env python3
"""Build Speech Recognition pipeline based on fairseq's wav2vec2.0 and dump it to TorchScript file.

To use this script, you need `fairseq`.
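
Example invocation (file names below are placeholders):

    python build_pipeline_from_fairseq.py --model-file wav2vec_small_960h.pt --dict-dir ./dict --output-path ./pipelines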
"""
import argparse
import logging
import os
from typing import Tuple

import fairseq
import torch
import torchaudio
from greedy_decoder import Decoder
from torch.utils.mobile_optimizer import optimize_for_mobile
from torchaudio.models.wav2vec2.utils.import_fairseq import import_fairseq_model

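# `torch.quantization` moved to `torch.ao.quantization` in PyTorch 1.10;
# import from whichever namespace this version provides.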
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 10):
    import torch.ao.quantization as tq
else:
    import torch.quantization as tq

_LG = logging.getLogger(__name__)


def _parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__,
    )
    parser.add_argument("--model-file", required=True, help="Path to the input pretrained weight file.")
    parser.add_argument(
        "--dict-dir",
        help=(
            "Path to the directory in which `dict.ltr.txt` file is found. " "Required only when the model is finetuned."
        ),
    )
    parser.add_argument(
        "--output-path",
        required=True,
        help="Path to the directory where the TorchScript-ed pipelines are saved.",
    )
    parser.add_argument(
        "--test-file",
        help="Path to a test audio file.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help=(
            "When enabled, individual components are separately tested "
            "for the numerical compatibility and TorchScript compatibility."
        ),
    )
    parser.add_argument("--quantize", action="store_true", help="Apply quantization to model.")
    parser.add_argument("--optimize-for-mobile", action="store_true", help="Apply optmization for mobile.")
    return parser.parse_args()


class Loader(torch.nn.Module):
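    """Load an audio file and resample the waveform to the 16 kHz expected by wav2vec2 models."""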
    def forward(self, audio_path: str) -> torch.Tensor:
        waveform, sample_rate = torchaudio.load(audio_path)
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, float(sample_rate), 16000.0)
        return waveform


class Encoder(torch.nn.Module):
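    """Wrap the acoustic model so that ``forward`` returns the emission for the first item in the batch."""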
    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        result, _ = self.encoder(waveform)
        return result[0]


def _get_decoder():
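    # Output labels in the index order of fairseq's letter dictionary (`dict.ltr.txt`):
    # four special tokens, the word boundary "|", then characters.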
    labels = [
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "|",
        "E",
        "T",
        "A",
        "O",
        "N",
        "I",
        "H",
        "S",
        "R",
        "D",
        "L",
        "U",
        "M",
        "W",
        "C",
        "F",
        "G",
        "Y",
        "P",
        "B",
        "V",
        "K",
        "'",
        "X",
        "J",
        "Q",
        "Z",
    ]
    return Decoder(labels)


def _load_fairseq_model(input_file, data_dir=None):
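    # When a dictionary directory is given, override the checkpoint's "data"
    # path so fairseq can locate `dict.ltr.txt`.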
    overrides = {}
    if data_dir:
        overrides["data"] = data_dir

    models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([input_file], arg_overrides=overrides)
    return models[0]


def _get_model(model_file, dict_dir):
    original = _load_fairseq_model(model_file, dict_dir)
    model = import_fairseq_model(original.w2v_encoder)
    return model


def _main():
    args = _parse_args()
    _init_logging(args.debug)
    loader = Loader()
    model = _get_model(args.model_file, args.dict_dir).eval()
    encoder = Encoder(model)
    decoder = _get_decoder()
    _LG.info(encoder)

    if args.quantize:
        _LG.info("Quantizing the model")
        model.encoder.transformer.pos_conv_embed.__prepare_scriptable__()
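        # Dynamically quantize all `torch.nn.Linear` layers to int8.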
        encoder = tq.quantize_dynamic(encoder, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
        _LG.info(encoder)

    # Smoke-test the pipeline end-to-end when a test file is provided.
    if args.test_file:
        _LG.info("Testing with %s", args.test_file)
        waveform = loader(args.test_file)
        emission = encoder(waveform)
        transcript = decoder(emission)
        _LG.info(transcript)

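    # Script each component and save it; the encoder can additionally be optimized for mobile.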
    torch.jit.script(loader).save(os.path.join(args.output_path, "loader.zip"))
    torch.jit.script(decoder).save(os.path.join(args.output_path, "decoder.zip"))
    scripted = torch.jit.script(encoder)
    if args.optimize_for_mobile:
        scripted = optimize_for_mobile(scripted)
    scripted.save(os.path.join(args.output_path, "encoder.zip"))


def _init_logging(debug=False):
    level = logging.DEBUG if debug else logging.INFO
    format_ = "%(asctime)s: %(levelname)7s: %(funcName)10s: %(message)s" if debug else "%(message)s"
    logging.basicConfig(level=level, format=format_)


if __name__ == "__main__":
    _main()