File: squim_pipeline_test.py

import pytest
import torchaudio
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
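# NOTE (assumption): sample_speech, mixture_source, and clean_sources are
# pytest fixtures, presumably defined in the suite's conftest.py, that yield
# paths to the test audio files; they are not defined in this module.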


@pytest.mark.parametrize(
    "lang,expected",
    [
        ("en", [0.9978380799293518, 4.23893404006958, 24.217193603515625]),
    ],
)
def test_squim_objective_pretrained_weights(lang, expected, sample_speech):
    """Test that the metric scores estimated by SquimObjective Bundle is identical to the expected result."""
    bundle = SQUIM_OBJECTIVE

    # Get SquimObjective model
    model = bundle.get_model()
    # Load the sample speech waveform from the fixture path
    waveform, sample_rate = torchaudio.load(sample_speech)
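    # The returned sample_rate is assumed to already match the 16 kHz rate
    # the pretrained SQUIM models expect; no resampling is performed here.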
    scores = model(waveform)
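    # SquimObjective returns three reference-free estimates, in order:
    # STOI (intelligibility), PESQ (perceptual quality), and SI-SDR (in dB).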
    for i in range(3):
        assert abs(scores[i].item() - expected[i]) < 1e-5


@pytest.mark.parametrize(
    "task,expected",
    [
        ("speech_separation", [3.9257140159606934, 3.9391300678253174]),
    ],
)
def test_squim_subjective_pretrained_weights(task, expected, mixture_source, clean_sources):
    """Test that the metric scores estimated by SquimSubjective Bundle is identical to the expected result."""
    bundle = SQUIM_SUBJECTIVE

    # Get SquimSubjective model
    model = bundle.get_model()
    # Load input mixture audio
    waveform, sample_rate = torchaudio.load(mixture_source)
    for i, source in enumerate(clean_sources):
        # Load clean reference
        clean_waveform, sample_rate = torchaudio.load(source)
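        # SquimSubjective estimates a MOS for the degraded mixture given a
        # clean reference; the model supports non-matching references, so any
        # clean speech can serve as the reference signal.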
        score = model(waveform, clean_waveform)
        assert abs(score.item() - expected[i]) < 1e-5