File: common_utils.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (110 lines) | stat: -rw-r--r-- 3,948 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# https://github.com/pytorch/fairseq/blob/265df7144c79446f5ea8d835bda6e727f54dad9d/LICENSE
"""
Data pre-processing: create tsv files for training (and validation).
"""
import logging
import re
from pathlib import Path
from typing import Dict, Tuple, Union

import torch
import torchaudio


# Module-level logger, named after this module.
_LG = logging.getLogger(__name__)


def create_tsv(
    root_dir: Union[str, Path],
    out_dir: Union[str, Path],
    dataset: str = "librispeech",
    valid_percent: float = 0.01,
    seed: int = 0,
    extension: str = "flac",
) -> None:
    """Create file lists for training and validation.

    Recursively scans ``root_dir`` for ``*.{extension}`` files whose full path
    contains ``train`` and writes ``{dataset}_train.tsv`` (and, when
    ``valid_percent > 0``, ``{dataset}_valid.tsv``) into ``out_dir``. The first
    line of each file is ``root_dir``; every following line is
    ``<path relative to root_dir>\\t<number of frames>``.

    Each matching file is assigned independently: one random number is drawn
    per file and the file goes to the validation list when the draw is not
    greater than ``valid_percent``. The validation share is therefore only
    approximately ``valid_percent``.

    Args:
        root_dir (str or Path): The directory of the dataset.
        out_dir (str or Path): The directory to store the file lists. It is
            created (including missing parents) if it does not exist.
        dataset (str, optional): The dataset to use. Only used as the prefix of
            the output file names. Options:
            [``librispeech``, ``libri-light``]. (Default: ``librispeech``)
        valid_percent (float, optional): The fraction of data for validation,
            in the range [0, 1]. (Default: 0.01)
        seed (int): The seed for randomly selecting the validation files.
        extension (str, optional): The extension of audio files. (Default: ``flac``)

    Returns:
        None

    Raises:
        AssertionError: If ``valid_percent`` is outside [0, 1].
    """
    assert 0 <= valid_percent <= 1.0

    torch.manual_seed(seed)
    root_dir = Path(root_dir)
    out_dir = Path(out_dir)

    # Create missing parents too and tolerate a concurrently created directory.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Compiled once; the pattern is matched against the full path string.
    search_pattern = re.compile(".*train.*")

    valid_f = open(out_dir / f"{dataset}_valid.tsv", "w") if valid_percent > 0 else None
    try:
        with open(out_dir / f"{dataset}_train.tsv", "w") as train_f:
            print(root_dir, file=train_f)

            if valid_f is not None:
                print(root_dir, file=valid_f)

            # NOTE: glob order is not sorted here, so the train/valid split for
            # a given seed depends on the order the file system lists entries.
            for fname in root_dir.glob(f"**/*.{extension}"):
                if not search_pattern.match(str(fname)):
                    continue
                frames = torchaudio.info(fname).num_frames
                # Always draw exactly one number per file so the random stream
                # is consumed the same way regardless of valid_percent.
                draw = torch.rand(1)
                # torch.rand samples from [0, 1), so a draw of exactly 0 is
                # possible; without the None check such a file would be
                # printed to stdout when no validation list is requested.
                goes_to_train = valid_f is None or bool(draw > valid_percent)
                dest = train_f if goes_to_train else valid_f
                print(f"{fname.relative_to(root_dir)}\t{frames}", file=dest)
    finally:
        # Close the validation list even if reading audio metadata or writing fails.
        if valid_f is not None:
            valid_f.close()
    _LG.info("Finished creating the file lists successfully")


def _get_feat_lens_paths(feat_dir: Path, split: str, rank: int, num_rank: int) -> Tuple[Path, Path]:
    r"""Get the feature and lengths paths based on feature directory,
        data split, rank, and number of ranks.
    Args:
        feat_dir (Path): The directory that stores the feature and lengths tensors.
        split (str): The split of data. Options: [``train``, ``valid``].
        rank (int): The rank in the multi-processing.
        num_rank (int): The number of ranks for multi-processing in feature extraction.

    Returns:
        (Path, Path)
        Path: The file path of the feature tensor for the current rank.
        Path: The file path of the lengths tensor for the current rank.
    """
    feat_path = feat_dir / f"{split}_{rank}_{num_rank}.pt"
    len_path = feat_dir / f"len_{split}_{rank}_{num_rank}.pt"
    return feat_path, len_path


def _get_model_path(km_dir: Path) -> Path:
    r"""Get the file path of the KMeans clustering model
    Args:
        km_dir (Path): The directory to store the KMeans clustering model.

    Returns:
        Path: The file path of the model.
    """
    return km_dir / "model.pt"


def _get_id2label() -> Dict:
    """Map each index of the ASR model's last layer dimension to its lower-cased label.

    The labels are taken, in order, from ``torchaudio.pipelines.HUBERT_ASR_LARGE``.
    """
    labels = torchaudio.pipelines.HUBERT_ASR_LARGE.get_labels()
    return dict(enumerate(label.lower() for label in labels))


def _get_label2id() -> Dict:
    """Map each label to its index in the ASR model's last dimension.

    The labels are taken, in order and with their case unchanged, from
    ``torchaudio.pipelines.HUBERT_ASR_LARGE``.
    """
    label2id = {}
    for index, label in enumerate(torchaudio.pipelines.HUBERT_ASR_LARGE.get_labels()):
        label2id[label] = index
    return label2id