File: smoke_test.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (90 lines) | stat: -rw-r--r-- 3,080 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import itertools

from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import get_wav_data, skipIfNoSox, TempDirMixin, TorchaudioTestCase

from .common import name_func


@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
    """Smoke tests for ``sox_io_backend`` across several audio formats.

    Every test pushes generated audio through ``save``, ``info`` and ``load`` and checks only
    that the sample rate and channel count read back match what was written.

    The suite is meant to run without any additional tools (such as the sox command); the
    trade-off is that, without such tools, the correctness of each function is not verified.
    """

    def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype="float32"):
        """Save generated audio to ``test.<ext>`` and read it back via ``info`` and ``load``.

        Args:
            ext: Extension used to build the temporary file name.
            sample_rate: Sample rate handed to ``save``; one second of audio is generated,
                so this is also the number of frames requested from ``get_wav_data``.
            num_channels: Number of channels requested from ``get_wav_data``.
            compression: Forwarded unchanged to ``sox_io_backend.save``.
            dtype: Forwarded to ``get_wav_data`` as its first argument.

        Raises:
            AssertionError: If the sample rate or channel count reported by ``info`` or
                ``load`` differs from the requested one.
        """
        duration_sec = 1
        frame_count = sample_rate * duration_sec
        target = self.get_temp_path(f"test.{ext}")
        waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=frame_count)

        # Step 1: write the file.
        sox_io_backend.save(target, waveform, sample_rate, compression=compression)

        # Step 2: query the metadata.
        metadata = sox_io_backend.info(target)
        assert info_matches(metadata.sample_rate, sample_rate) if False else metadata.sample_rate == sample_rate
        assert metadata.num_channels == num_channels

        # Step 3: read the audio back without normalization.
        decoded, decoded_rate = sox_io_backend.load(target, normalize=False)
        assert decoded_rate == sample_rate
        assert decoded.shape[0] == num_channels

    @parameterized.expand(
        list(itertools.product(["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2])),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """Smoke-test the wav format for each dtype / sample rate / channel count."""
        self.run_smoke_test(ext="wav", sample_rate=sample_rate, num_channels=num_channels, dtype=dtype)

    @parameterized.expand(
        list(itertools.product([8000, 16000], [1, 2], [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320]))
    )
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """Smoke-test the mp3 format, passing ``bit_rate`` as the ``compression`` argument."""
        self.run_smoke_test(ext="mp3", sample_rate=sample_rate, num_channels=num_channels, compression=bit_rate)

    @parameterized.expand(list(itertools.product([8000, 16000], [1, 2], [-1, 0, 1, 2, 3, 3.6, 5, 10])))
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """Smoke-test the vorbis format, passing ``quality_level`` as the ``compression`` argument."""
        self.run_smoke_test(
            ext="vorbis",
            sample_rate=sample_rate,
            num_channels=num_channels,
            compression=quality_level,
        )

    @parameterized.expand(
        list(itertools.product([8000, 16000], [1, 2], range(9))),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """Smoke-test the flac format, passing ``compression_level`` as the ``compression`` argument."""
        self.run_smoke_test(
            ext="flac",
            sample_rate=sample_rate,
            num_channels=num_channels,
            compression=compression_level,
        )