File: librimix_test.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (99 lines) | stat: -rw-r--r-- 3,642 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os

from parameterized import parameterized

from torchaudio.datasets import LibriMix
from torchaudio_unittest.common_utils import get_whitenoise, save_wav, TempDirMixin, TorchaudioTestCase

# Sample rate used both when generating the mock waveforms and when
# constructing the LibriMix dataset under test.
_SAMPLE_RATE = 8000
# Maps each task name to the name of the folder that holds its mixture files.
# Note that "enh_both" and "sep_noisy" share the "mix_both" folder.
_TASKS_TO_MIXTURE = {
    "sep_clean": "mix_clean",
    "enh_single": "mix_single",
    "enh_both": "mix_both",
    "sep_noisy": "mix_both",
}


def _save_wav(filepath: str, seed: int):
    """Write a seeded single-channel white-noise clip to ``filepath`` and return it.

    The clip is produced by ``get_whitenoise`` at ``_SAMPLE_RATE`` with
    ``duration=0.01``; the same ``seed`` is forwarded to the generator so the
    waveform saved on disk is also handed back to the caller for comparison.
    """
    noise_options = {
        "sample_rate": _SAMPLE_RATE,
        "duration": 0.01,
        "n_channels": 1,
        "seed": seed,
    }
    waveform = get_whitenoise(**noise_options)
    save_wav(filepath, waveform, _SAMPLE_RATE)
    return waveform


def get_mock_dataset(root_dir: str, num_speaker: int):
    """Build a mocked LibriMix directory tree and return the expected samples.

    Files are written under
    ``<root_dir>/Libri<num_speaker>Mix/wav8k/min/train-360``.

    Args:
        root_dir: directory in which the mocked dataset is created.
        num_speaker: number of speakers; selects the ``Libri<N>Mix`` folder and
            the number of ``s<k>`` source folders.

    Returns:
        A list with one dict per utterance (ten utterances). Each dict maps a
        task name from ``_TASKS_TO_MIXTURE`` to a tuple
        ``(_SAMPLE_RATE, mixture, sources)``.
    """
    subset_dir = os.path.join(root_dir, f"Libri{num_speaker}Mix", "wav8k", "min", "train-360")
    os.makedirs(subset_dir, exist_ok=True)

    expected = []
    next_seed = 0
    for utt_index in range(10):
        wav_name = f"{utt_index}.wav"
        per_task = {}
        for task_name, mix_folder in _TASKS_TO_MIXTURE.items():
            # The mixture folder depends on the task; tasks that share a
            # folder simply rewrite the same file.
            mix_dir = os.path.join(subset_dir, mix_folder)
            os.makedirs(mix_dir, exist_ok=True)
            mix_wav = _save_wav(os.path.join(mix_dir, wav_name), next_seed)

            if task_name == "enh_both":
                # For this task the single expected source is the
                # "sep_clean" mixture.
                src_wavs = [per_task["sep_clean"][1]]
            else:
                src_wavs = []
                for speaker_number in range(1, num_speaker + 1):
                    speaker_dir = os.path.join(subset_dir, f"s{speaker_number}")
                    os.makedirs(speaker_dir, exist_ok=True)
                    speaker_path = os.path.join(speaker_dir, wav_name)
                    if os.path.exists(speaker_path):
                        # Source files for this utterance were already
                        # written; reuse the sources recorded for "sep_clean".
                        src_wavs = per_task["sep_clean"][2]
                        break
                    src_wavs.append(_save_wav(speaker_path, next_seed))
                    # The seed only advances when a new source file is written.
                    next_seed += 1

            per_task[task_name] = (_SAMPLE_RATE, mix_wav, src_wavs)
        expected.append(per_task)

    return expected


class TestLibriMix(TempDirMixin, TorchaudioTestCase):
    """Check that ``LibriMix`` returns the mocked samples for each task and speaker count."""

    # Populated once in ``setUpClass``; the defaults are placeholders only.
    root_dir = None
    samples_2spk = []
    samples_3spk = []

    @classmethod
    def setUpClass(cls):
        """Create the mocked 2- and 3-speaker datasets in the shared temp directory."""
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples_2spk = get_mock_dataset(cls.root_dir, 2)
        cls.samples_3spk = get_mock_dataset(cls.root_dir, 3)

    def _test_librimix(self, dataset, samples, task):
        """Compare every item yielded by ``dataset`` with the expected ``samples``.

        Args:
            dataset: iterable yielding ``(sample_rate, mixture, sources)`` tuples.
            samples: list of per-utterance dicts as returned by ``get_mock_dataset``.
            task: task name used to select the expected entry of each dict.
        """
        num_samples = 0
        for i, (sample_rate, mixture, sources) in enumerate(dataset):
            expected_rate, expected_mixture, expected_sources = samples[i][task]
            # Use unittest assertions rather than bare ``assert`` so the checks
            # are not stripped under ``python -O`` and failures report values.
            self.assertEqual(sample_rate, expected_rate)
            self.assertEqual(mixture, expected_mixture)
            self.assertEqual(len(sources), len(expected_sources))
            for source, expected_source in zip(sources, expected_sources):
                self.assertEqual(source, expected_source)
            num_samples += 1

        # The dataset must yield exactly as many items as were mocked.
        self.assertEqual(num_samples, len(samples))

    # Every parameter set is written as a 1-tuple (note the trailing comma).
    @parameterized.expand([("sep_clean",), ("enh_single",), ("enh_both",), ("sep_noisy",)])
    def test_librimix_2speaker(self, task):
        dataset = LibriMix(self.root_dir, num_speakers=2, sample_rate=_SAMPLE_RATE, task=task)
        self._test_librimix(dataset, self.samples_2spk, task)

    # Every parameter set is written as a 1-tuple (note the trailing comma).
    @parameterized.expand([("sep_clean",), ("enh_single",), ("enh_both",), ("sep_noisy",)])
    def test_librimix_3speaker(self, task):
        dataset = LibriMix(self.root_dir, num_speakers=3, sample_rate=_SAMPLE_RATE, task=task)
        self._test_librimix(dataset, self.samples_3spk, task)