File: source_separation_pipeline_test.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (42 lines) | stat: -rw-r--r-- 1,955 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import sys

import pytest
import torch
import torchaudio
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS


# Extend the import search path with the `examples` directory located two levels
# above this file, so that the `source_separation` package imported on the next
# line can be resolved. The append must run before that import.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
from source_separation.utils.metrics import sdr


@pytest.mark.parametrize(
    "bundle,task,channel,expected_score",
    [
        [CONVTASNET_BASE_LIBRI2MIX, "speech_separation", 1, 8.1373],
        [HDEMUCS_HIGH_MUSDB_PLUS, "music_separation", 2, 8.7480],
        [HDEMUCS_HIGH_MUSDB, "music_separation", 2, 8.0697],
    ],
)
def test_source_separation_models(bundle, task, channel, expected_score, mixture_source, clean_sources):
    """Integration test for the pre-trained source separation pipelines.

    The mixture waveform is reshaped to `(1, channel, time)` and passed through the
    pre-trained model of ``bundle``. The model output is reshaped to
    `(1, -1, time)` so that it lines up with the clean reference waveforms, which are
    concatenated along their first dimension and given a leading batch dimension.
    The test then computes the scale-invariant signal-to-distortion ratio (Si-SDR)
    score in decibel (dB), averaged over all sources.
    The Si-SDR score should be equal to or larger than the expected score.

    Args:
        bundle: Pipeline bundle providing ``get_model()`` and ``sample_rate``.
        task: Name of the separation task. It is not referenced in this body;
            presumably it is consumed by the ``mixture_source`` / ``clean_sources``
            fixtures -- confirm against the test suite's conftest.
        channel: Number of channels the mixture is reshaped to before inference.
        expected_score: Minimum acceptable mean Si-SDR, in dB.
        mixture_source: Fixture giving the location of the mixture audio.
        clean_sources: Fixture giving the locations of the clean reference audio,
            one entry per source.
    """
    model = bundle.get_model()

    def _load_checked(path):
        # Every audio asset must already be at the sample rate the bundle expects;
        # the test does not resample.
        waveform, sample_rate = torchaudio.load(path)
        assert sample_rate == bundle.sample_rate, "The sample rate of audio must match that in the bundle."
        return waveform

    mixture_waveform = _load_checked(mixture_source)
    clean_waveforms = torch.cat([_load_checked(source) for source in clean_sources]).unsqueeze(0)

    mixture_waveform = mixture_waveform.reshape(1, channel, -1)
    # Pure inference: disable autograd so no graph is retained for the whole
    # waveform. This does not change the numerical output of the model.
    with torch.no_grad():
        estimated_sources = model(mixture_waveform)

    # Align the estimate with the references on the time dimension before scoring.
    estimated_sources = estimated_sources.reshape(1, -1, clean_waveforms.shape[-1])
    sdr_values = sdr(estimated_sources, clean_waveforms).mean()
    assert sdr_values >= expected_score, (
        f"Si-SDR of {float(sdr_values):.4f} dB is below the expected {expected_score:.4f} dB."
    )