File: smoke_test.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (56 lines) | stat: -rw-r--r-- 2,107 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import io

from torchaudio._backend.utils import get_info_func, get_load_func, get_save_func
from torchaudio_unittest.common_utils import get_wav_data, PytorchTestCase, skipIfNoFFmpeg, TempDirMixin


@skipIfNoFFmpeg
class SmokeTest(TempDirMixin, PytorchTestCase):
    """Round-trip a generated signal through a file path with the backend I/O functions.

    The signal is written with ``get_save_func()``, inspected with
    ``get_info_func()`` and read back with ``get_load_func()``. Only metadata
    (sample rate, channel count) and the channel dimension of the loaded tensor
    are checked; sample values are not compared.
    """

    def run_smoke_test(self, ext, sample_rate, num_channels, *, dtype="float32"):
        """Save, inspect and reload a one-second signal stored in a temp file.

        Args:
            ext: File extension; the file is written to ``test.<ext>`` in the
                test's temporary directory.
            sample_rate: Sample rate used to size, save and verify the signal.
            num_channels: Number of channels of the generated signal.
            dtype: Sample dtype name forwarded to ``get_wav_data``.
        """
        duration = 1
        num_frames = sample_rate * duration
        path = self.get_temp_path(f"test.{ext}")
        original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)

        get_save_func()(path, original, sample_rate)

        # unittest assertions (instead of bare ``assert``) are not stripped
        # under ``python -O`` and report both values when they fail.
        info = get_info_func()(path)
        self.assertEqual(info.sample_rate, sample_rate)
        self.assertEqual(info.num_channels, num_channels)

        loaded, sr = get_load_func()(path, normalize=False)
        self.assertEqual(sr, sample_rate)
        # The first dimension of the loaded tensor is expected to be channels.
        self.assertEqual(loaded.shape[0], num_channels)

    def test_wav(self):
        """Smoke-test a 16 kHz stereo float32 WAV file."""
        dtype = "float32"
        sample_rate = 16000
        num_channels = 2
        self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)


@skipIfNoFFmpeg
class SmokeTestFileObj(TempDirMixin, PytorchTestCase):
    """Round-trip a generated signal through an in-memory ``io.BytesIO`` object.

    Mirrors the path-based smoke test. Because a ``BytesIO`` carries no file
    extension, every save/info/load call receives an explicit ``format`` and a
    fixed ``buffer_size``. Only metadata and the channel dimension of the loaded
    tensor are checked; sample values are not compared.
    """

    def run_smoke_test(self, ext, sample_rate, num_channels, *, dtype="float32"):
        """Save, inspect and reload a one-second signal held in a ``BytesIO``.

        Args:
            ext: Container format name passed as ``format=`` to each call.
            sample_rate: Sample rate used to size, save and verify the signal.
            num_channels: Number of channels of the generated signal.
            dtype: Sample dtype name forwarded to ``get_wav_data``.
        """
        buffer_size = 8192
        duration = 1
        num_frames = sample_rate * duration
        fileobj = io.BytesIO()
        original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)

        get_save_func()(fileobj, original, sample_rate, format=ext, buffer_size=buffer_size)

        # Rewind: writing advanced the stream position past the saved data.
        fileobj.seek(0)
        info = get_info_func()(fileobj, format=ext, buffer_size=buffer_size)
        # unittest assertions (instead of bare ``assert``) are not stripped
        # under ``python -O`` and report both values when they fail.
        self.assertEqual(info.sample_rate, sample_rate)
        self.assertEqual(info.num_channels, num_channels)

        # Rewind again so loading starts from the beginning of the stream.
        fileobj.seek(0)
        loaded, sr = get_load_func()(fileobj, normalize=False, format=ext, buffer_size=buffer_size)
        self.assertEqual(sr, sample_rate)
        # The first dimension of the loaded tensor is expected to be channels.
        self.assertEqual(loaded.shape[0], num_channels)

    def test_wav(self):
        """Smoke-test a 16 kHz stereo float32 WAV stream held in memory."""
        dtype = "float32"
        sample_rate = 16000
        num_channels = 2
        self.run_smoke_test("wav", sample_rate, num_channels, dtype=dtype)