File: roundtrip_test.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (56 lines) | stat: -rw-r--r-- 2,009 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import itertools

from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio_unittest.common_utils import get_wav_data, PytorchTestCase, skipIfNoExec, skipIfNoSox, TempDirMixin

from .common import get_enc_params, name_func


@skipIfNoExec("sox")
@skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
    """save/load round trip should not degrade data for lossless formats"""

    # Number of consecutive save/load cycles per test. Each cycle re-saves the
    # tensor produced by the previous load, so a loss introduced on any pass
    # is caught by the comparison against the original on that pass.
    NUM_ROUND_TRIPS = 10

    def _assert_round_trip(self, original, sample_rate, ext, save_kwargs=None, load_kwargs=None):
        """Repeatedly save and reload ``original`` and assert nothing changes.

        Args:
            original: Tensor to round-trip; every reload is compared against it.
            sample_rate: Sample rate passed to ``save``; the rate reported by
                ``load`` must match it on every pass.
            ext: File extension (e.g. ``"wav"``, ``"flac"``) used for the
                temporary files, which selects the container format.
            save_kwargs: Extra keyword arguments forwarded to
                ``sox_io_backend.save`` (``None`` means none).
            load_kwargs: Extra keyword arguments forwarded to
                ``sox_io_backend.load`` (``None`` means none).
        """
        # ``None`` defaults instead of ``{}`` avoid shared mutable defaults.
        save_kwargs = {} if save_kwargs is None else save_kwargs
        load_kwargs = {} if load_kwargs is None else load_kwargs
        data = original
        for i in range(self.NUM_ROUND_TRIPS):
            # A fresh file per pass, so each load reads what this pass wrote.
            path = self.get_temp_path(f"{i}.{ext}")
            sox_io_backend.save(path, data, sample_rate, **save_kwargs)
            data, sr = sox_io_backend.load(path, **load_kwargs)
            # Use TestCase assertions rather than bare ``assert``, which is
            # stripped when Python runs with ``-O``.
            self.assertEqual(sr, sample_rate)
            self.assertEqual(original, data)

    @parameterized.expand(
        list(
            itertools.product(
                ["float32", "int32", "int16", "uint8"],
                [8000, 16000],
                [1, 2],
            )
        ),
        name_func=name_func,
    )
    def test_wav(self, dtype, sample_rate, num_channels):
        """save/load round trip should not degrade data for wav formats"""
        # Generated and reloaded with normalize=False, so the comparison is on
        # the un-normalized sample values.
        original = get_wav_data(dtype, num_channels, normalize=False)
        # Encoding and bit depth matching ``dtype`` for the WAV writer.
        enc, bps = get_enc_params(dtype)
        self._assert_round_trip(
            original,
            sample_rate,
            "wav",
            save_kwargs={"encoding": enc, "bits_per_sample": bps},
            load_kwargs={"normalize": False},
        )

    @parameterized.expand(
        list(
            itertools.product(
                [8000, 16000],
                [1, 2],
                # FLAC compression levels 0 through 8.
                list(range(9)),
            )
        ),
        name_func=name_func,
    )
    def test_flac(self, sample_rate, num_channels, compression_level):
        """save/load round trip should not degrade data for flac formats"""
        original = get_wav_data("float32", num_channels)
        self._assert_round_trip(
            original,
            sample_rate,
            "flac",
            save_kwargs={"compression": compression_level},
        )