File: convert_voxpopuli_models.py

Package: pytorch-audio 2.6.0-1
#!/usr/bin/env python3
"""Convert the fairseq models available in voxpopuli repo https://github.com/facebookresearch/voxpopuli

The available checkpoints should open with fairseq.
But the following error cannot be resolved with almost any version of fairseq.
https://github.com/facebookresearch/voxpopuli/issues/29

So this script manually parse the checkpoint file and reconstruct the model.

Examples

```
python convert_voxpopuli_models.py \
  --input-file wav2vec2_base_10k_ft_fr.pt \
  --output-file wav2vec2_voxpopuli_base_10k_asr_fr.pt
```
"""


def _parse_args():
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--input-file", required=True, help="Input checkpoint file.")
    parser.add_argument("--output-file", required=False, help="Output model file.")
    return parser.parse_args()


def _removeprefix(s, prefix):
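    # Backport of str.removeprefix, which is only available in Python 3.9+.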
    if s.startswith(prefix):
        return s[len(prefix) :]
    return s


def _load(input_file):
    import torch
    from omegaconf import OmegaConf

    # weights_only=False: the checkpoint contains pickled config objects, not
    # just tensors; map to CPU so GPU-trained checkpoints load anywhere.
    data = torch.load(input_file, map_location="cpu", weights_only=False)
    cfg = OmegaConf.to_container(data["cfg"])
    # Keep only the "model" section of the config; drop the same top-level
    # keys from the nested `w2v_args` config when it is present.
    for key in list(cfg.keys()):
        if key != "model":
            del cfg[key]
            if "w2v_args" in cfg["model"]:
                cfg["model"]["w2v_args"].pop(key, None)
    state_dict = {_removeprefix(k, "w2v_encoder."): v for k, v in data["model"].items()}
    return cfg, state_dict
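
# For reference, a checkpoint is expected to look roughly like the sketch
# below. This is an assumption based on fairseq fine-tuning checkpoints; the
# exact keys vary by fairseq version, and the nested "w2v_args" entry is only
# present in fine-tuned models:
#
#   {
#       "cfg": {
#           "model": {"w2v_args": {"model": {...}, ...}, "encoder_embed_dim": ..., ...},
#           "task": {...},
#           ...
#       },
#       "model": {"w2v_encoder.w2v_model...": Tensor, "w2v_encoder.proj.bias": Tensor, ...},
#   }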


def _parse_model_param(cfg, state_dict):
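    # Map fairseq config keys (left) to torchaudio.models.wav2vec2_model
    # keyword arguments (right). Some fairseq keys appear under different
    # names depending on the version (e.g. "layerdrop" vs "encoder_layerdrop").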
    key_mapping = {
        "extractor_mode": "extractor_mode",
        "conv_feature_layers": "extractor_conv_layer_config",
        "conv_bias": "extractor_conv_bias",
        "encoder_embed_dim": "encoder_embed_dim",
        "dropout_input": "encoder_projection_dropout",
        "conv_pos": "encoder_pos_conv_kernel",
        "conv_pos_groups": "encoder_pos_conv_groups",
        "encoder_layers": "encoder_num_layers",
        "encoder_attention_heads": "encoder_num_heads",
        "attention_dropout": "encoder_attention_dropout",
        "encoder_ffn_embed_dim": "encoder_ff_interm_features",
        "activation_dropout": "encoder_ff_interm_dropout",
        "dropout": "encoder_dropout",
        "layer_norm_first": "encoder_layer_norm_first",
        "layerdrop": "encoder_layer_drop",
        "encoder_layerdrop": "encoder_layer_drop",
    }
    params = {}
    src_dicts = [cfg["model"]]
    if "w2v_args" in cfg["model"]:
        src_dicts.append(cfg["model"]["w2v_args"]["model"])

    for src, tgt in key_mapping.items():
        for model_cfg in src_dicts:
            if src in model_cfg:
                params[tgt] = model_cfg[src]
                break
    if params["extractor_mode"] == "default":
        params["extractor_mode"] = "group_norm"
    # fairseq stores the conv layer config as a Python expression string, so it
    # must be evaluated to obtain the list of (dim, kernel, stride) tuples.
    if isinstance(params["extractor_conv_layer_config"], str):
        params["extractor_conv_layer_config"] = eval(params["extractor_conv_layer_config"])  # noqa
    assert len(params) == 15
    params["aux_num_out"] = state_dict["proj.bias"].numel() if "proj.bias" in state_dict else None
    return params
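
# As a rough illustration only (these values are assumptions based on the
# wav2vec2 "base" architecture, not read from an actual checkpoint), the
# parsed parameters printed by this script look something like:
#
#   {
#       "extractor_mode": "group_norm",
#       "extractor_conv_layer_config": [(512, 10, 5), (512, 3, 2), ...],
#       "extractor_conv_bias": False,
#       "encoder_embed_dim": 768,
#       "encoder_num_layers": 12,
#       "encoder_num_heads": 12,
#       "encoder_ff_interm_features": 3072,
#       ...,
#       "aux_num_out": ...,  # vocabulary size of the ASR head, or None
#   }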


def _main(args):
    import json

    import torch
    import torchaudio
    from torchaudio.models.wav2vec2.utils.import_fairseq import _convert_state_dict as _convert

    cfg, state_dict = _load(args.input_file)
    params = _parse_model_param(cfg, state_dict)
    print(json.dumps(params, indent=4))
    model = torchaudio.models.wav2vec2_model(**params)
    model.load_state_dict(_convert(state_dict))
    torch.save(model.state_dict(), args.output_file)


if __name__ == "__main__":
    _main(_parse_args())
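

# A minimal sketch of how the converted file can be used afterwards. It
# assumes the `params` dict printed during conversion is reused to rebuild
# the model:
#
#   import torch
#   import torchaudio
#
#   model = torchaudio.models.wav2vec2_model(**params)
#   model.load_state_dict(torch.load("wav2vec2_voxpopuli_base_10k_asr_fr.pt"))
#   model.eval()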