File: librosa_compatibility_test.py

package info (click to toggle)
pytorch-audio 0.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 5,512 kB
  • sloc: python: 15,606; cpp: 1,352; sh: 257; makefile: 21
file content (346 lines) | stat: -rw-r--r-- 15,083 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""Test suites for numerical compatibility with librosa"""
import os
import unittest
from distutils.version import StrictVersion

import torch
import torchaudio
import torchaudio.functional as F
from torchaudio._internal.module_utils import is_module_available

LIBROSA_AVAILABLE = is_module_available('librosa')

if LIBROSA_AVAILABLE:
    import numpy as np
    import librosa
    import scipy

import pytest

from torchaudio_unittest import common_utils


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestFunctional(common_utils.TorchaudioTestCase):
    """Numerical-compatibility checks of the `functional` module against librosa."""

    def test_griffinlim(self):
        """Griffin-Lim output of torchaudio agrees with ``librosa.griffinlim``."""
        # NOTE: This test is flaky without a fixed random seed
        # See https://github.com/pytorch/audio/issues/382
        torch.random.manual_seed(42)
        waveform = torch.rand((1, 1000))

        fft_size = 400
        win_size = 400
        hop_size = 100
        normalized = False
        momentum = 0.99
        num_iters = 8
        out_length = 1000
        random_init = False
        hann = torch.hann_window(win_size)

        # librosa takes the initialisation mode as a string (or None)
        # where torchaudio takes a boolean flag.
        if random_init:
            librosa_init = 'random'
        else:
            librosa_init = None

        # Power spectrogram -> magnitude spectrogram.
        power_spec = F.spectrogram(
            waveform, 0, hann, fft_size, hop_size, win_size, 2, normalized)
        magnitude = power_spec.sqrt()

        result = F.griffinlim(
            magnitude, hann, fft_size, hop_size, win_size, 1, normalized,
            num_iters, momentum, out_length, random_init)

        reference = librosa.griffinlim(
            magnitude.squeeze(0).numpy(),
            n_iter=num_iters,
            hop_length=hop_size,
            momentum=momentum,
            init=librosa_init,
            length=out_length)
        reference = torch.from_numpy(reference).unsqueeze(0)

        self.assertEqual(result, reference, atol=5e-5, rtol=1e-5)

    def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0, norm=None):
        """Compare ``F.create_fb_matrix`` with ``librosa.filters.mel``, one mel band at a time."""
        expected = librosa.filters.mel(
            sr=sample_rate, n_fft=n_fft, n_mels=n_mels,
            fmax=fmax, fmin=fmin, htk=True, norm=norm)

        num_freqs = n_fft // 2 + 1
        actual = F.create_fb_matrix(
            sample_rate=sample_rate, n_mels=n_mels,
            f_max=fmax, f_min=fmin, n_freqs=num_freqs, norm=norm)

        # torchaudio indexes the band on the last axis, librosa on the first.
        for band in range(n_mels):
            self.assertEqual(
                actual[:, band], torch.tensor(expected[band]), atol=1e-4, rtol=1e-5)

    def test_create_fb(self):
        """Filter banks match librosa for several configurations.

        The "slaney"-normalised cases are only run when the installed
        librosa is version 0.7.2 or newer.
        """
        configurations = [
            {'n_mels': 128, 'sample_rate': 44100},
            {'n_mels': 128, 'fmin': 2000.0, 'fmax': 5000.0},
            {'n_mels': 56, 'fmin': 100.0, 'fmax': 9000.0},
            {'n_mels': 56, 'fmin': 800.0, 'fmax': 900.0},
            {'n_mels': 56, 'fmin': 1900.0, 'fmax': 900.0},
            {'n_mels': 10, 'fmin': 1900.0, 'fmax': 900.0},
        ]

        # Defaults first, then every configuration without normalisation.
        self._test_create_fb()
        for config in configurations:
            self._test_create_fb(**config)

        if StrictVersion(librosa.__version__) < StrictVersion("0.7.2"):
            return

        for config in configurations:
            self._test_create_fb(norm="slaney", **config)

    def test_amplitude_to_DB(self):
        """``F.amplitude_to_DB`` matches librosa's power and amplitude dB conversions."""
        spectrum = torch.rand((6, 201))

        amin = 1e-10
        db_multiplier = 0.0
        top_db = 80.0

        def check(multiplier, librosa_to_db):
            actual = F.amplitude_to_DB(spectrum, multiplier, amin, db_multiplier, top_db)
            expected = torch.from_numpy(librosa_to_db(spectrum.numpy()))
            self.assertEqual(actual, expected, atol=5e-5, rtol=1e-5)

        # Power to DB
        check(10.0, librosa.core.power_to_db)

        # Amplitude to DB
        check(20.0, librosa.core.amplitude_to_db)


@pytest.mark.parametrize('complex_specgrams', [
    torch.randn(2, 1025, 400, 2)
])
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('hop_length', [256])
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
def test_phase_vocoder(complex_specgrams, rate, hop_length):
    """Check ``F.phase_vocoder`` against ``librosa.phase_vocoder``.

    Args:
        complex_specgrams: tensor whose last dimension holds the real part
            (index 0) and the imaginary part (index 1).
        rate: stretch factor passed to both implementations.
        hop_length: hop length used to build the phase advance and passed
            to librosa.

    The test first checks the shape of the stretched output, then compares
    the values of the first item against librosa.
    """
    # Due to cumulative sum, numerical error in using torch.float32 will
    # result in bottom right values of the stretched spectrogram to not
    # match with librosa.

    complex_specgrams = complex_specgrams.type(torch.float64)
    phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3], dtype=torch.float64)[..., None]

    complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

    # == Test shape
    # Only the second-to-last dimension is expected to change, by 1 / rate.
    expected_size = list(complex_specgrams.size())
    expected_size[-2] = int(np.ceil(expected_size[-2] / rate))

    assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
    assert complex_specgrams_stretch.size() == torch.Size(expected_size)

    # == Test values
    # Select the first item along every leading dimension and keep the last
    # three dimensions whole. The index must be a tuple: indexing with a
    # list that contains slices relies on deprecated legacy behaviour.
    index = (0,) * (complex_specgrams.dim() - 3) + (slice(None),) * 3
    mono_complex_specgram = complex_specgrams[index].numpy()
    mono_complex_specgram = mono_complex_specgram[..., 0] + \
        mono_complex_specgram[..., 1] * 1j
    expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
                                                     rate=rate,
                                                     hop_length=hop_length)

    complex_stretch = complex_specgrams_stretch[index].numpy()
    complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]

    assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5)


def _load_audio_asset(*asset_paths, **kwargs):
    """Load a test asset with ``torchaudio.load``.

    ``asset_paths`` are forwarded to ``common_utils.get_asset_path`` to
    locate the file; ``kwargs`` are forwarded to ``torchaudio.load``.
    Returns the ``(sound, sample_rate)`` pair.
    """
    waveform, rate = torchaudio.load(common_utils.get_asset_path(*asset_paths), **kwargs)
    return waveform, rate


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestTransforms(common_utils.TorchaudioTestCase):
    """Test suite for functions in `transforms` module."""
    def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
        """Compare torchaudio transforms with librosa on the ``sinewave.wav`` asset.

        Checks, in order: Spectrogram, MelSpectrogram, AmplitudeToDB
        (power and magnitude) and MFCC.

        Args:
            n_fft: FFT size given to both libraries.
            hop_length: hop length given to both libraries.
            power: exponent used for the plain spectrogram comparison.
            n_mels: number of mel bins given to both libraries.
            n_mfcc: number of MFCC coefficients to compare.
            sample_rate: see the review note in the body; the value is
                currently replaced by the rate read from the asset file.
        """
        common_utils.set_audio_backend('default')
        path = common_utils.get_asset_path('sinewave.wav')
        # NOTE(review): this rebinds `sample_rate`, so the value passed by the
        # caller is never used. Kept as-is to avoid changing what the existing
        # tests exercise; confirm whether the argument should take precedence.
        sound, sample_rate = common_utils.load_wav(path)
        sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

        # test core spectrogram
        spect_transform = torchaudio.transforms.Spectrogram(
            n_fft=n_fft, hop_length=hop_length, power=power)
        out_librosa, _ = librosa.core.spectrum._spectrogram(
            y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)

        out_torch = spect_transform(sound).squeeze().cpu()
        self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)

        # test mel spectrogram
        melspect_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, window_fn=torch.hann_window,
            hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
        librosa_mel = librosa.feature.melspectrogram(
            y=sound_librosa, sr=sample_rate, n_fft=n_fft,
            hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
        librosa_mel_tensor = torch.from_numpy(librosa_mel)
        torch_mel = melspect_transform(sound).squeeze().cpu()
        self.assertEqual(
            torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)

        # test s2db
        power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
        power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
        power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
        self.assertEqual(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3, rtol=1e-5)

        mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
        mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
        mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
        self.assertEqual(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3, rtol=1e-5)

        power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
        db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
        db_librosa_tensor = torch.from_numpy(db_librosa)
        self.assertEqual(
            power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)

        # test MFCC
        # `n_mels` is forwarded explicitly so the transform uses the same
        # number of mel bins as the librosa reference (`db_librosa`) instead
        # of falling back to its own default.
        melkwargs = {'hop_length': hop_length, 'n_fft': n_fft, 'n_mels': n_mels}
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)

        # librosa.feature.mfcc doesn't pass kwargs properly since some of the
        # kwargs for melspectrogram and mfcc are the same. We just follow the
        # function body in
        # https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
        # to mirror this function call with correct args:
        #
        # librosa_mfcc = librosa.feature.mfcc(
        #     y=sound_librosa, sr=sample_rate, n_mfcc = n_mfcc,
        #     hop_length=hop_length, n_fft=n_fft, htk=True, norm=None, n_mels=n_mels)

        librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
        librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
        torch_mfcc = mfcc_transform(sound).squeeze().cpu()

        self.assertEqual(
            torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)

    def test_basics1(self):
        """Compatibility with n_fft=400, hop_length=200, power=2.0, n_mfcc=40."""
        kwargs = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }
        self.assert_compatibilities(**kwargs)

    def test_basics2(self):
        """Compatibility with n_fft=600, hop_length=100, power=2.0, n_mfcc=20."""
        kwargs = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }
        self.assert_compatibilities(**kwargs)

    # NOTE: Test passes offline, but fails on TravisCI (and CircleCI), see #372.
    @unittest.skipIf('CI' in os.environ, 'Test is known to fail on CI')
    def test_basics3(self):
        """Compatibility with n_fft=200, hop_length=50, power=2.0, n_mfcc=50."""
        kwargs = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }
        self.assert_compatibilities(**kwargs)

    def test_basics4(self):
        """Compatibility with n_fft=400, hop_length=200, power=3.0, n_mfcc=40."""
        kwargs = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 3.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }
        self.assert_compatibilities(**kwargs)

    def test_MelScale(self):
        """MelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        hop_length = n_fft // 4
        sample_rate = 44100
        sound = common_utils.get_whitenoise(sample_rate=sample_rate, duration=60)
        sound = sound.mean(dim=0, keepdim=True)
        spec_ta = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        spec_lr = spec_ta.cpu().numpy().squeeze()
        # Perform MelScale with torchaudio and librosa
        melspec_ta = torchaudio.transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
        melspec_lr = librosa.feature.melspectrogram(
            S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
            win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
        # Note: Using relaxed rtol instead of atol
        self.assertEqual(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), atol=1e-8, rtol=1e-3)

    def test_InverseMelScale(self):
        """InverseMelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        n_stft = n_fft // 2 + 1
        hop_length = n_fft // 4

        # Prepare mel spectrogram input. We use torchaudio to compute one.
        path = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        sound, sample_rate = common_utils.load_wav(path)
        # Use a short excerpt (2**14 samples) and mix down to one channel.
        sound = sound[:, 2**10:2**10 + 2**14]
        sound = sound.mean(dim=0, keepdim=True)
        spec_orig = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        melspec_ta = torchaudio.transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
        melspec_lr = melspec_ta.cpu().numpy().squeeze()
        # Perform InverseMelScale with torch audio and librosa
        spec_ta = torchaudio.transforms.InverseMelScale(
            n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
        spec_lr = librosa.feature.inverse.mel_to_stft(
            melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
        spec_lr = torch.from_numpy(spec_lr[None, ...])

        # Align dimensions
        # librosa does not return power spectrogram while torchaudio returns power spectrogram
        spec_orig = spec_orig.sqrt()
        spec_ta = spec_ta.sqrt()

        threshold = 2.0
        # This threshold was chosen empirically, based on the following observation
        #
        # torch.dist(spec_lr, spec_ta, p=float('inf'))
        # >>> tensor(1.9666)
        #
        # The spectrograms reconstructed by librosa and torchaudio are not comparable elementwise.
        # This is because they use different approximation algorithms and resulting values can live
        # in different magnitude. (although most of them are very close)
        # See
        # https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
        # https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
        # distance over frequencies.
        self.assertEqual(spec_ta, spec_lr, atol=threshold, rtol=1e-5)

        threshold = 1700.0
        # This threshold was chosen empirically, based on the following observations
        #
        # torch.dist(spec_orig, spec_ta, p=1)
        # >>> tensor(1644.3516)
        # torch.dist(spec_orig, spec_lr, p=1)
        # >>> tensor(1420.7103)
        # torch.dist(spec_lr, spec_ta, p=1)
        # >>> tensor(943.2759)
        #
        # Use a TestCase assertion rather than a bare `assert`, which is
        # stripped when Python runs with -O.
        self.assertLess(torch.dist(spec_orig, spec_ta, p=1).item(), threshold)