1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
|
import os
from pathlib import Path
from torchaudio.datasets import gtzan
from torchaudio_unittest.common_utils import get_whitenoise, normalize_wav, save_wav, TempDirMixin, TorchaudioTestCase
def get_mock_dataset(root_dir):
    """Build a mock GTZAN directory tree and return the samples it should yield.

    For each genre in ``gtzan.gtzan_genres`` a ``<root_dir>/genres/<genre>``
    directory is created and filled with 100 WAV files named
    ``<genre>.<index:05d>.wav``. Each file holds data from ``get_whitenoise``
    generated with a distinct seed, so no two files share the same request.

    Args:
        root_dir: directory to the mocked dataset

    Returns:
        A 4-tuple ``(all_samples, training, validation, testing)``. Every
        element is a list of ``(normalize_wav(data), sample_rate, genre)``
        tuples in creation order. The last three lists contain only the
        samples whose file stem is listed in ``gtzan.filtered_train``,
        ``gtzan.filtered_valid`` and ``gtzan.filtered_test`` respectively.
    """
    sample_rate = 22050
    files_per_genre = 100
    all_samples, train_split, valid_split, test_split = [], [], [], []

    for genre_index, genre in enumerate(gtzan.gtzan_genres):
        genre_dir = os.path.join(root_dir, "genres", genre)
        os.makedirs(genre_dir, exist_ok=True)

        for file_index in range(files_per_genre):
            # Seeds run 0, 1, 2, ... across all genres, one per generated file.
            seed = genre_index * files_per_genre + file_index

            stem = f"{genre}.{file_index:05d}"
            wav_path = os.path.join(genre_dir, f"{stem}.wav")
            noise = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1, dtype="int16", seed=seed)
            save_wav(wav_path, noise, sample_rate)

            entry = (normalize_wav(noise), sample_rate, genre)
            all_samples.append(entry)

            # A stem is routed to every split list that names it.
            for members, bucket in (
                (gtzan.filtered_test, test_split),
                (gtzan.filtered_train, train_split),
                (gtzan.filtered_valid, valid_split),
            ):
                if stem in members:
                    bucket.append(entry)

    return (all_samples, train_split, valid_split, test_split)
class TestGTZAN(TempDirMixin, TorchaudioTestCase):
    """Check ``gtzan.GTZAN`` against a mock dataset written to a temp directory.

    ``setUpClass`` generates the mock files once and stores the expected
    samples for the full dataset and for each subset. Every test then builds
    a ``GTZAN`` dataset over that directory and compares what it yields,
    item by item, with the stored expectation. The ``*_str`` tests pass the
    root directory as a string, the ``*_path`` tests as a ``pathlib.Path``.
    """

    backend = "default"

    # Placeholders; the real values are assigned once per class in setUpClass.
    root_dir = None
    samples = []
    training = []
    validation = []
    testing = []

    @classmethod
    def setUpClass(cls):
        """Create the mock dataset and cache the expected samples on the class."""
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples, cls.training, cls.validation, cls.testing = get_mock_dataset(cls.root_dir)

    def _assert_dataset_matches(self, dataset, expected):
        """Assert that ``dataset`` yields exactly the items of ``expected``, in order.

        Args:
            dataset: iterable yielding ``(waveform, sample_rate, label)`` tuples.
            expected: list of ``(waveform, sample_rate, label)`` tuples to compare with.
        """
        n_ite = 0
        for i, (waveform, sample_rate, label) in enumerate(dataset):
            # Fail with a readable message instead of an IndexError below.
            assert i < len(expected), f"dataset yielded more than the expected {len(expected)} items"
            expected_waveform, expected_rate, expected_label = expected[i]
            # Waveforms are compared with a tolerance rather than exactly.
            self.assertEqual(waveform, expected_waveform, atol=5e-5, rtol=1e-8)
            assert sample_rate == expected_rate, f"item {i}: sample rate {sample_rate} != {expected_rate}"
            assert label == expected_label, f"item {i}: label {label!r} != {expected_label!r}"
            n_ite += 1
        assert n_ite == len(expected), f"dataset yielded {n_ite} items, expected {len(expected)}"

    def test_no_subset(self):
        """Without ``subset`` the dataset yields every mocked sample."""
        dataset = gtzan.GTZAN(self.root_dir)
        self._assert_dataset_matches(dataset, self.samples)

    def _test_training(self, dataset):
        """Compare ``dataset`` with the expected training samples."""
        self._assert_dataset_matches(dataset, self.training)

    def _test_validation(self, dataset):
        """Compare ``dataset`` with the expected validation samples."""
        self._assert_dataset_matches(dataset, self.validation)

    def _test_testing(self, dataset):
        """Compare ``dataset`` with the expected testing samples."""
        self._assert_dataset_matches(dataset, self.testing)

    def test_training_str(self):
        """Training subset, root directory given as ``str``."""
        train_dataset = gtzan.GTZAN(self.root_dir, subset="training")
        self._test_training(train_dataset)

    def test_validation_str(self):
        """Validation subset, root directory given as ``str``."""
        val_dataset = gtzan.GTZAN(self.root_dir, subset="validation")
        self._test_validation(val_dataset)

    def test_testing_str(self):
        """Testing subset, root directory given as ``str``."""
        test_dataset = gtzan.GTZAN(self.root_dir, subset="testing")
        self._test_testing(test_dataset)

    def test_training_path(self):
        """Training subset, root directory given as ``pathlib.Path``."""
        root_dir = Path(self.root_dir)
        train_dataset = gtzan.GTZAN(root_dir, subset="training")
        self._test_training(train_dataset)

    def test_validation_path(self):
        """Validation subset, root directory given as ``pathlib.Path``."""
        root_dir = Path(self.root_dir)
        val_dataset = gtzan.GTZAN(root_dir, subset="validation")
        self._test_validation(val_dataset)

    def test_testing_path(self):
        """Testing subset, root directory given as ``pathlib.Path``."""
        root_dir = Path(self.root_dir)
        test_dataset = gtzan.GTZAN(root_dir, subset="testing")
        self._test_testing(test_dataset)
|