File: gtzan_test.py

package info (click to toggle)
pytorch-audio 0.13.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,592 kB
  • sloc: python: 41,137; cpp: 8,016; sh: 3,538; makefile: 24
file content (120 lines) | stat: -rw-r--r-- 4,347 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from pathlib import Path

from torchaudio.datasets import gtzan
from torchaudio_unittest.common_utils import get_whitenoise, normalize_wav, save_wav, TempDirMixin, TorchaudioTestCase


def get_mock_dataset(root_dir):
    """
    root_dir: directory to the mocked dataset
    """
    mocked_samples = []
    mocked_training = []
    mocked_validation = []
    mocked_testing = []
    sample_rate = 22050

    seed = 0
    for genre in gtzan.gtzan_genres:
        base_dir = os.path.join(root_dir, "genres", genre)
        os.makedirs(base_dir, exist_ok=True)
        for i in range(100):
            filename = f"{genre}.{i:05d}"
            path = os.path.join(base_dir, f"{filename}.wav")
            data = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype="int16", seed=seed)
            save_wav(path, data, sample_rate)
            sample = (normalize_wav(data), sample_rate, genre)
            mocked_samples.append(sample)
            if filename in gtzan.filtered_test:
                mocked_testing.append(sample)
            if filename in gtzan.filtered_train:
                mocked_training.append(sample)
            if filename in gtzan.filtered_valid:
                mocked_validation.append(sample)
            seed += 1
    return (mocked_samples, mocked_training, mocked_validation, mocked_testing)


class TestGTZAN(TempDirMixin, TorchaudioTestCase):
    backend = "default"

    root_dir = None
    samples = []
    training = []
    validation = []
    testing = []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = cls.get_base_temp_dir()
        mocked_data = get_mock_dataset(cls.root_dir)
        cls.samples = mocked_data[0]
        cls.training = mocked_data[1]
        cls.validation = mocked_data[2]
        cls.testing = mocked_data[3]

    def test_no_subset(self):
        dataset = gtzan.GTZAN(self.root_dir)

        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            self.assertEqual(waveform, self.samples[i][0], atol=5e-5, rtol=1e-8)
            assert sample_rate == self.samples[i][1]
            assert label == self.samples[i][2]
            n_ite += 1
        assert n_ite == len(self.samples)

    def _test_training(self, dataset):
        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            self.assertEqual(waveform, self.training[i][0], atol=5e-5, rtol=1e-8)
            assert sample_rate == self.training[i][1]
            assert label == self.training[i][2]
            n_ite += 1
        assert n_ite == len(self.training)

    def _test_validation(self, dataset):
        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            self.assertEqual(waveform, self.validation[i][0], atol=5e-5, rtol=1e-8)
            assert sample_rate == self.validation[i][1]
            assert label == self.validation[i][2]
            n_ite += 1
        assert n_ite == len(self.validation)

    def _test_testing(self, dataset):
        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            self.assertEqual(waveform, self.testing[i][0], atol=5e-5, rtol=1e-8)
            assert sample_rate == self.testing[i][1]
            assert label == self.testing[i][2]
            n_ite += 1
        assert n_ite == len(self.testing)

    def test_training_str(self):
        train_dataset = gtzan.GTZAN(self.root_dir, subset="training")
        self._test_training(train_dataset)

    def test_validation_str(self):
        val_dataset = gtzan.GTZAN(self.root_dir, subset="validation")
        self._test_validation(val_dataset)

    def test_testing_str(self):
        test_dataset = gtzan.GTZAN(self.root_dir, subset="testing")
        self._test_testing(test_dataset)

    def test_training_path(self):
        root_dir = Path(self.root_dir)
        train_dataset = gtzan.GTZAN(root_dir, subset="training")
        self._test_training(train_dataset)

    def test_validation_path(self):
        root_dir = Path(self.root_dir)
        val_dataset = gtzan.GTZAN(root_dir, subset="validation")
        self._test_validation(val_dataset)

    def test_testing_path(self):
        root_dir = Path(self.root_dir)
        test_dataset = gtzan.GTZAN(root_dir, subset="testing")
        self._test_testing(test_dataset)