1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
|
import os
import sys
import pytest
import torch
import torchaudio
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
from source_separation.utils.metrics import sdr
@pytest.mark.parametrize(
    "bundle,task,channel,expected_score",
    [
        [CONVTASNET_BASE_LIBRI2MIX, "speech_separation", 1, 8.1373],
        [HDEMUCS_HIGH_MUSDB_PLUS, "music_separation", 2, 8.7480],
        [HDEMUCS_HIGH_MUSDB, "music_separation", 2, 8.0697],
    ],
)
def test_source_separation_models(bundle, task, channel, expected_score, mixture_source, clean_sources):
    """Integration test for the pre-trained source separation pipelines.

    The mixture audio is loaded, reshaped to ``(1, channel, time)`` and passed
    through the model returned by ``bundle.get_model()``. The model output is
    flattened to ``(1, -1, time)`` so that it lines up with the reference
    sources, which are concatenated along their first dimension and given a
    leading batch dimension of 1.

    The test then computes the scale-invariant signal-to-distortion ratio
    (Si-SDR) in decibel (dB), averaged over all returned values, and requires
    it to be equal to or larger than ``expected_score``.

    Args:
        bundle: Pipeline bundle exposing ``get_model()`` and ``sample_rate``.
        task: Label of the parametrized case. It is not read in this body;
            presumably it is consumed by the fixtures -- TODO confirm.
        channel: Number of channels the mixture is reshaped to before it is
            fed to the model.
        expected_score: Minimum acceptable mean Si-SDR.
        mixture_source: Path-like handle of the mixture audio, passed to
            ``torchaudio.load`` (presumably a fixture).
        clean_sources: Iterable of path-like handles of the reference
            sources, each passed to ``torchaudio.load`` (presumably a fixture).
    """
    model = bundle.get_model()

    mixture_waveform, sample_rate = torchaudio.load(mixture_source)
    assert sample_rate == bundle.sample_rate, "The sample rate of audio must match that in the bundle."

    clean_waveforms = []
    for source in clean_sources:
        # Use a dedicated name so the mixture's sample rate is not overwritten.
        clean_waveform, clean_sample_rate = torchaudio.load(source)
        assert clean_sample_rate == bundle.sample_rate, "The sample rate of audio must match that in the bundle."
        clean_waveforms.append(clean_waveform)

    mixture_waveform = mixture_waveform.reshape(1, channel, -1)

    # Pure inference: disabling autograd avoids recording a graph and keeping
    # intermediate activations alive, without changing the forward values.
    with torch.no_grad():
        estimated_sources = model(mixture_waveform)

    clean_waveforms = torch.cat(clean_waveforms).unsqueeze(0)
    # Collapse every non-time dimension of the estimate into one axis so that
    # it can be compared against the concatenated references.
    estimated_sources = estimated_sources.reshape(1, -1, clean_waveforms.shape[-1])

    sdr_values = sdr(estimated_sources, clean_waveforms).mean()
    assert sdr_values >= expected_score
|