File: playback_test.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (65 lines) | stat: -rw-r--r-- 2,227 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from unittest.mock import patch

import torch
from parameterized import parameterized
from torchaudio.io import play_audio, StreamWriter
from torchaudio_unittest.common_utils import get_sinusoid, skipIfNoAudioDevice, skipIfNoMacOS, TorchaudioTestCase


@skipIfNoAudioDevice
@skipIfNoMacOS
class PlaybackInterfaceTest(TorchaudioTestCase):
    """Tests for ``torchaudio.io.play_audio``.

    Every test patches ``StreamWriter.write_audio_chunk`` with a mock, so the
    tests observe whether data would be written to the output stream without
    actually playing anything.
    """

    # Sample rate shared by every generated test signal and passed to play_audio.
    _SAMPLE_RATE = 8000

    def _sinusoid(self, dtype_name):
        """Build a 1-second, single-channel 440 Hz test tone.

        Args:
            dtype_name: Name of a ``torch`` dtype attribute, e.g. ``"int16"``.

        Returns:
            The tensor produced by ``get_sinusoid`` with ``channels_first=False``
            on the CPU.
        """
        return get_sinusoid(
            frequency=440,
            sample_rate=self._SAMPLE_RATE,
            duration=1,  # seconds
            n_channels=1,
            dtype=getattr(torch, dtype_name),
            device="cpu",
            channels_first=False,
        )

    @parameterized.expand([(name,) for name in ("uint8", "int16", "int32", "int64", "float32", "float64")])
    @patch.object(StreamWriter, "write_audio_chunk")
    def test_playaudio(self, dtype, writeaudio_mock):
        """play_audio should hand samples to StreamWriter.write_audio_chunk.

        The writer method is mocked, so this only verifies that at least one
        write was attempted for each supported dtype.
        """
        play_audio(self._sinusoid(dtype), sample_rate=self._SAMPLE_RATE)
        writeaudio_mock.assert_called()

    @parameterized.expand(
        [
            # Wrong rank: play_audio expects exactly 2 dimensions.
            ("int16", 1, "audiotoolbox"),
            ("int16", 3, "audiotoolbox"),
            # Unsupported tensor dtype.
            ("complex64", 2, "audiotoolbox"),
            # Unknown output device name.
            ("int16", 2, "audiotool"),
        ]
    )
    @patch.object(StreamWriter, "write_audio_chunk")
    def test_playaudio_invalid_options(self, dtype, ndim, device, writeaudio_mock):
        """play_audio should reject each bad option combination with ValueError."""
        flat = self._sinusoid(dtype).squeeze()
        # Append (ndim - 1) trailing singleton axes so the tensor has rank `ndim`.
        waveform = flat.reshape(tuple(flat.shape) + (1,) * (ndim - 1))

        with self.assertRaises(ValueError):
            play_audio(waveform, sample_rate=self._SAMPLE_RATE, device=device)