File: conftest.py

package info (click to toggle)
pytorch-audio 0.13.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,592 kB
  • sloc: python: 41,137; cpp: 8,016; sh: 3,538; makefile: 24
file content (119 lines) | stat: -rw-r--r-- 3,393 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import shutil

import pytest
import torch
import torchaudio


class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder turning frame-wise logits into a string.

    Args:
        labels: Lookup from label index to its text token. Anything that
            supports ``labels[int]`` works (list, tuple, str, or a ``dict``
            keyed by ``int``).
        blank (int, optional): Index of the blank label, which is dropped
            from the output. (Default: ``0``)
    """

    def __init__(self, labels, blank: int = 0):
        super().__init__()
        self.blank = blank
        self.labels = labels

    def forward(self, logits: torch.Tensor) -> str:
        """Given a sequence logits over labels, get the best path string

        Args:
            logits (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        best_path = torch.argmax(logits, dim=-1)  # [num_seq,]
        # Collapse runs of the same index *before* removing blanks, so a blank
        # between two identical labels keeps them as separate tokens.
        best_path = torch.unique_consecutive(best_path, dim=-1)
        # Convert to plain Python ints once: avoids building a 0-dim tensor per
        # step and lets ``labels`` be any int-keyed lookup (tensors hash by
        # identity, so they cannot be used as dict keys).
        indices = best_path.tolist()
        return "".join(self.labels[i] for i in indices if i != self.blank)


@pytest.fixture
def ctc_decoder():
    """Provide the ``GreedyCTCDecoder`` class itself (not an instance)."""
    decoder_cls = GreedyCTCDecoder
    return decoder_cls


_FILES = {
    "en": "Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.flac",
    "de": "20090505-0900-PLENARY-16-de_20090505-21_56_00_8.flac",
    "en2": "20120613-0900-PLENARY-8-en_20120613-13_46_50_3.flac",
    "es": "20130207-0900-PLENARY-7-es_20130207-13_02_05_5.flac",
    "fr": "20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac",
    "it": "20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac",
}
_MIXTURE_FILES = {
    "speech_separation": "mixture_3729-6852-0037_8463-287645-0000.wav",
    "music_separation": "al_james_mixture_shorter.wav",
}
_CLEAN_FILES = {
    "speech_separation": [
        "s1_3729-6852-0037_8463-287645-0000.wav",
        "s2_3729-6852-0037_8463-287645-0000.wav",
    ],
    "music_separation": [
        "al_james_drums_shorter.wav",
        "al_james_bass_shorter.wav",
        "al_james_other_shorter.wav",
        "al_james_vocals_shorter.wav",
    ],
}


@pytest.fixture
def sample_speech(lang):
    """Download the speech sample registered for ``lang`` and return its path.

    Args:
        lang: Key into ``_FILES``.

    Returns:
        Whatever ``torchaudio.utils.download_asset`` returns for the asset.

    Raises:
        NotImplementedError: If ``lang`` has no entry in ``_FILES``.
    """
    asset_name = _FILES.get(lang)
    if asset_name is None:
        raise NotImplementedError(f"Unexpected lang: {lang}")
    return torchaudio.utils.download_asset(f"test-assets/{asset_name}")


@pytest.fixture
def mixture_source(task):
    """Download the mixture file registered for ``task`` and return its path.

    Args:
        task: Key into ``_MIXTURE_FILES``.

    Returns:
        Whatever ``torchaudio.utils.download_asset`` returns for the asset.

    Raises:
        NotImplementedError: If ``task`` has no entry in ``_MIXTURE_FILES``.
    """
    asset_name = _MIXTURE_FILES.get(task)
    if asset_name is None:
        raise NotImplementedError(f"Unexpected task: {task}")
    return torchaudio.utils.download_asset(f"test-assets/{asset_name}")


@pytest.fixture
def clean_sources(task):
    """Download every clean source file registered for ``task``.

    Args:
        task: Key into ``_CLEAN_FILES``.

    Returns:
        list: One ``torchaudio.utils.download_asset`` result per file, in the
        order the files are listed in ``_CLEAN_FILES``.

    Raises:
        NotImplementedError: If ``task`` has no entry in ``_CLEAN_FILES``.
    """
    asset_names = _CLEAN_FILES.get(task)
    if asset_names is None:
        raise NotImplementedError(f"Unexpected task: {task}")
    return [
        torchaudio.utils.download_asset(f"test-assets/{name}")
        for name in asset_names
    ]


def pytest_addoption(parser):
    """Register the ``--use-tmp-hub-dir`` boolean command-line flag.

    Args:
        parser: Object exposing ``addoption``; the flag is registered on it.
    """
    description = (
        "When provided, tests will use temporary directory as Torch Hub directory. "
        "Downloaded models will be deleted after each test."
    )
    parser.addoption("--use-tmp-hub-dir", action="store_true", help=description)


@pytest.fixture(autouse=True)
def temp_hub_dir(tmp_path, pytestconfig):
    """Optionally point Torch Hub at a throwaway directory for the test.

    Does nothing unless the ``use_tmp_hub_dir`` option is set. When it is,
    the hub directory is switched to ``<tmp_path>/hub`` before the test and,
    afterwards, the previous hub directory is restored and the temporary one
    is removed.
    """
    if not pytestconfig.getoption("use_tmp_hub_dir"):
        yield
        return

    previous_dir = torch.hub.get_dir()
    hub_dir = os.path.join(tmp_path, "hub")
    torch.hub.set_dir(hub_dir)
    yield
    # Teardown: restore the original location first, then discard downloads.
    torch.hub.set_dir(previous_dir)
    shutil.rmtree(hub_dir, ignore_errors=True)


@pytest.fixture()
def emissions():
    """Download the emissions test asset and return it loaded via ``torch.load``."""
    # NOTE(review): torch.load unpickles the downloaded file; this presumes the
    # asset host is trusted — confirm before reusing outside the test suite.
    asset = "test-assets/emissions-8555-28447-0012.pt"
    return torch.load(torchaudio.utils.download_asset(asset))