File: yesno_test.py

package info (click to toggle)
pytorch-audio 0.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 5,512 kB
  • sloc: python: 15,606; cpp: 1,352; sh: 257; makefile: 21
file content (49 lines) | stat: -rw-r--r-- 1,509 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os

from torchaudio.datasets import yesno

from torchaudio_unittest.common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    get_whitenoise,
    save_wav,
    normalize_wav,
)


class TestYesNo(TempDirMixin, TorchaudioTestCase):
    """Check ``yesno.YESNO`` against a small synthetic copy of the dataset.

    ``setUpClass`` writes one white-noise WAV file per entry of ``labels`` into
    ``<base temp dir>/waves_yesno``. Each file is named after its label
    sequence. The normalized waveforms are kept in ``data`` so the test can
    compare them with what the dataset loads.
    """
    backend = 'default'

    # Sample rate used both to write the fixture files and to check the loader.
    sample_rate = 8000

    root_dir = None
    # Expected waveforms, one per entry of ``labels``. Assigned a fresh list in
    # setUpClass so repeated class setup (or a subclass) never appends to a
    # list shared through the class attribute.
    data = None
    # One label sequence per fixture file. The test assumes the dataset yields
    # items in the same order as this list.
    labels = [
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1],
        [0, 1, 0, 1, 0, 1, 1, 0],
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ]

    @classmethod
    def setUpClass(cls):
        """Write the fixture WAV files and record their expected waveforms."""
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        base_dir = os.path.join(cls.root_dir, 'waves_yesno')
        os.makedirs(base_dir, exist_ok=True)
        cls.data = []
        for seed, label in enumerate(cls.labels):
            # File names encode the labels, e.g. [0, 1, ...] -> "0_1_....wav".
            filename = f'{"_".join(str(bit) for bit in label)}.wav'
            path = os.path.join(base_dir, filename)
            # A distinct seed per file keeps the fixtures different from each other.
            data = get_whitenoise(
                sample_rate=cls.sample_rate, duration=6, n_channels=1, dtype='int16', seed=seed)
            save_wav(path, data, cls.sample_rate)
            cls.data.append(normalize_wav(data))

    def test_yesno(self):
        """Every dataset item matches its fixture's waveform, sample rate and labels."""
        dataset = yesno.YESNO(self.root_dir)
        n_items = 0
        for i, (waveform, sr, label) in enumerate(dataset):
            self.assertEqual(self.data[i], waveform, atol=5e-5, rtol=1e-8)
            self.assertEqual(sr, self.sample_rate)
            self.assertEqual(label, self.labels[i])
            n_items += 1
        # Guards against the dataset yielding fewer items than were written.
        self.assertEqual(n_items, len(self.data))