File: cmuarctic_test.py

package info (click to toggle)
pytorch-audio 0.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 5,512 kB
  • sloc: python: 15,606; cpp: 1,352; sh: 257; makefile: 21
file content (67 lines) | stat: -rw-r--r-- 2,288 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os

from torchaudio.datasets import cmuarctic

from torchaudio_unittest.common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    get_whitenoise,
    save_wav,
    normalize_wav,
)


class TestCMUARCTIC(TempDirMixin, TorchaudioTestCase):
    """Check that ``cmuarctic.CMUARCTIC`` reads back a synthetic dataset tree.

    ``setUpClass`` writes ten white-noise wav files plus a matching
    ``txt.done.data`` transcript into a temporary directory.
    ``test_cmuarctic`` then verifies that the dataset yields exactly those
    utterances, in the order they were written.
    """

    backend = "default"

    # Set by setUpClass; declared here so the attribute always exists.
    root_dir = None
    # Expected (waveform, sample_rate, utterance, utterance_id) tuples, in the
    # order they are written to the transcript. setUpClass rebinds this to a
    # fresh list, so this class-level default is never mutated.
    samples = []

    @classmethod
    def setUpClass(cls):
        """Create the fake dataset on disk and record the expected samples."""
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        # Rebind to a fresh list. Appending to the inherited class-level list
        # would leak samples across subclasses or repeated setUpClass calls.
        cls.samples = []
        sample_rate = 16000
        utterance = "This is a test utterance."

        # Layout: ARCTIC/cmu_us_aew_arctic/{etc/txt.done.data, wav/*.wav}.
        # Presumably this is the directory CMUARCTIC(root) reads with its
        # default arguments -- confirm against the dataset implementation.
        base_dir = os.path.join(cls.root_dir, "ARCTIC", "cmu_us_aew_arctic")
        txt_dir = os.path.join(base_dir, "etc")
        audio_dir = os.path.join(base_dir, "wav")
        os.makedirs(txt_dir, exist_ok=True)
        os.makedirs(audio_dir, exist_ok=True)
        txt_file = os.path.join(txt_dir, "txt.done.data")

        # Each file gets its own seed, so every waveform is distinct but
        # reproducible.
        seed = 42
        with open(txt_file, "w", encoding="utf-8") as txt:
            for c in ["a", "b"]:
                for i in range(5):
                    utterance_id = f"arctic_{c}{i:04d}"
                    path = os.path.join(audio_dir, f"{utterance_id}.wav")
                    data = get_whitenoise(
                        sample_rate=sample_rate,
                        duration=3,
                        n_channels=1,
                        dtype="int16",
                        seed=seed,
                    )
                    save_wav(path, data, sample_rate)
                    # The dataset is expected to report the ID without the
                    # "arctic_" prefix (e.g. "a0000").
                    sample = (
                        normalize_wav(data),
                        sample_rate,
                        utterance,
                        utterance_id.split("_")[1],
                    )
                    cls.samples.append(sample)
                    txt.write(f'( {utterance_id} "{utterance}" )\n')
                    seed += 1

    def test_cmuarctic(self):
        """Every generated utterance is yielded once, in order, with matching data."""
        dataset = cmuarctic.CMUARCTIC(self.root_dir)
        n_ite = 0
        for i, (waveform, sample_rate, utterance, utterance_id) in enumerate(dataset):
            # Fail with a clear message (not an IndexError) on extra items.
            self.assertLess(i, len(self.samples), "dataset yielded extra items")
            (
                expected_waveform,
                expected_rate,
                expected_utterance,
                expected_id,
            ) = self.samples[i]
            self.assertEqual(sample_rate, expected_rate)
            self.assertEqual(utterance, expected_utterance)
            self.assertEqual(utterance_id, expected_id)
            # Tolerance absorbs the int16 -> float round trip through the wav file.
            self.assertEqual(expected_waveform, waveform, atol=5e-5, rtol=1e-8)
            n_ite += 1
        self.assertEqual(n_ite, len(self.samples))