File: wsj0mix_test.py

package info (click to toggle)
pytorch-audio 0.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 5,512 kB
  • sloc: python: 15,606; cpp: 1,352; sh: 257; makefile: 21
file content (111 lines) | stat: -rw-r--r-- 3,706 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os

from torchaudio_unittest.common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    get_whitenoise,
    save_wav,
    normalize_wav,
)

from source_separation.utils.dataset import wsj0mix


# File names of the mock utterances. `_mock_dataset` writes one WAV file with
# each of these names into every sub-directory ("mix", "s1", "s2", ...) of the
# mock dataset root.
# NOTE(review): the names look like `<utt1>_<gain1>_<utt2>_<gain2>.wav` taken
# from the real WSJ0-mix corpus -- confirm against the dataset documentation.
_FILENAMES = [
    "012c0207_1.9952_01cc0202_-1.9952.wav",
    "01co0302_1.63_014c020q_-1.63.wav",
    "01do0316_0.24011_205a0104_-0.24011.wav",
    "01lc020x_1.1301_027o030r_-1.1301.wav",
    "01mc0202_0.34056_205o0106_-0.34056.wav",
    "01nc020t_0.53821_018o030w_-0.53821.wav",
    "01po030f_2.2136_40ko031a_-2.2136.wav",
    "01ra010o_2.4098_403a010f_-2.4098.wav",
    "01xo030b_0.22377_016o031a_-0.22377.wav",
    "02ac020x_0.68566_01ec020b_-0.68566.wav",
    "20co010m_0.82801_019c0212_-0.82801.wav",
    "20da010u_1.2483_017c0211_-1.2483.wav",
    "20oo010d_1.0631_01ic020s_-1.0631.wav",
    "20sc0107_2.0222_20fo010h_-2.0222.wav",
    "20tc010f_0.051456_404a0110_-0.051456.wav",
    "407c0214_1.1712_02ca0113_-1.1712.wav",
    "40ao030w_2.4697_20vc010a_-2.4697.wav",
    "40pa0101_1.1087_40ea0107_-1.1087.wav",
]


def _mock_dataset(root_dir, num_speaker, sample_rate=8000):
    """Write a mock WSJ0-mix style directory tree and return the expected samples.

    For every name in ``_FILENAMES`` a one-second, single-channel ``int16``
    white-noise WAV file is written to ``<root_dir>/mix`` and to each of
    ``<root_dir>/s1`` ... ``<root_dir>/s<num_speaker>``. Every file is
    generated from a distinct seed so that no two waveforms are identical.

    Args:
        root_dir: Directory under which the ``mix`` and ``s<N>``
            sub-directories are created (existing directories are reused).
        num_speaker: Number of source directories (``s1`` .. ``s<num_speaker>``).
        sample_rate: Sample rate used both to generate and to save every
            waveform. Defaults to 8000.

    Returns:
        A list with one ``(sample_rate, mix, src)`` tuple per entry of
        ``_FILENAMES`` (in that order), where ``mix`` is the output of
        ``normalize_wav`` for the file in ``mix`` and ``src`` is the list of
        ``normalize_wav`` outputs for ``s1`` .. ``s<num_speaker>``.
    """
    dirnames = ["mix"] + [f"s{i+1}" for i in range(num_speaker)]
    for dirname in dirnames:
        os.makedirs(os.path.join(root_dir, dirname), exist_ok=True)

    seed = 0
    expected = []
    for filename in _FILENAMES:
        mix = None
        src = []
        for dirname in dirnames:
            # Use the same sample rate for generation and for saving so the
            # file header can never disagree with the generated signal.
            waveform = get_whitenoise(
                sample_rate=sample_rate,
                duration=1,
                n_channels=1,
                dtype="int16",
                seed=seed,
            )
            # A fresh seed per file keeps mixture and sources distinguishable.
            seed += 1

            path = os.path.join(root_dir, dirname, filename)
            save_wav(path, waveform, sample_rate)
            # The expectation is stored in normalized form; the on-disk file
            # keeps the raw int16 data.
            waveform = normalize_wav(waveform)

            if dirname == "mix":
                mix = waveform
            else:
                src.append(waveform)
        expected.append((sample_rate, mix, src))
    return expected


class TestWSJ0Mix2(TempDirMixin, TorchaudioTestCase):
    """Check that ``WSJ0Mix`` reads back a mocked two-speaker dataset."""

    # NOTE(review): presumably consumed by TorchaudioTestCase to select the
    # audio backend -- confirm in common_utils.
    backend = "default"
    # Root directory of the mock dataset; populated in setUpClass.
    root_dir = None
    # List of (sample_rate, mix, src) tuples returned by _mock_dataset.
    expected = None

    @classmethod
    def setUpClass(cls):
        """Create the mock dataset once for all tests of this class."""
        # Cooperate with any other setUpClass in the MRO.
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 2)

    def test_wsj0mix(self):
        """Every sample's mixture and sources match the mocked waveforms."""
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=2, sample_rate=8000)

        n_ite = 0
        # The first element of each sample is ignored, as is the first
        # element of the expected tuple (the sample rate).
        for i, (_, sample_mix, sample_src) in enumerate(dataset):
            # Report a clear failure instead of an IndexError when the
            # dataset yields more samples than were mocked.
            self.assertLess(
                i, len(self.expected), "dataset yielded more samples than were mocked"
            )
            _, expected_mix, expected_src = self.expected[i]
            self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
            # Verify the number of sources too, so a missing or extra source
            # cannot go unnoticed.
            self.assertEqual(len(sample_src), len(expected_src))
            for actual, reference in zip(sample_src, expected_src):
                self.assertEqual(actual, reference, atol=5e-5, rtol=1e-8)
            n_ite += 1
        # Use a TestCase assertion: a bare `assert` is stripped under `-O`.
        self.assertEqual(n_ite, len(self.expected))


class TestWSJ0Mix3(TempDirMixin, TorchaudioTestCase):
    """Check that ``WSJ0Mix`` reads back a mocked three-speaker dataset."""

    # NOTE(review): presumably consumed by TorchaudioTestCase to select the
    # audio backend -- confirm in common_utils.
    backend = "default"
    # Root directory of the mock dataset; populated in setUpClass.
    root_dir = None
    # List of (sample_rate, mix, src) tuples returned by _mock_dataset.
    expected = None

    @classmethod
    def setUpClass(cls):
        """Create the mock dataset once for all tests of this class."""
        # Cooperate with any other setUpClass in the MRO.
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.expected = _mock_dataset(cls.root_dir, 3)

    def test_wsj0mix(self):
        """Every sample's mixture and sources match the mocked waveforms."""
        dataset = wsj0mix.WSJ0Mix(self.root_dir, num_speakers=3, sample_rate=8000)

        n_ite = 0
        # The first element of each sample is ignored, as is the first
        # element of the expected tuple (the sample rate).
        for i, (_, sample_mix, sample_src) in enumerate(dataset):
            # Report a clear failure instead of an IndexError when the
            # dataset yields more samples than were mocked.
            self.assertLess(
                i, len(self.expected), "dataset yielded more samples than were mocked"
            )
            _, expected_mix, expected_src = self.expected[i]
            self.assertEqual(sample_mix, expected_mix, atol=5e-5, rtol=1e-8)
            # Verify the number of sources too, so a missing or extra source
            # cannot go unnoticed.
            self.assertEqual(len(sample_src), len(expected_src))
            for actual, reference in zip(sample_src, expected_src):
                self.assertEqual(actual, reference, atol=5e-5, rtol=1e-8)
            n_ite += 1
        # Use a TestCase assertion: a bare `assert` is stripped under `-O`.
        self.assertEqual(n_ite, len(self.expected))