# Tests for the torchaudio.io playback interface (play_audio / StreamWriter).
from unittest.mock import patch
import torch
from parameterized import parameterized
from torchaudio.io import play_audio, StreamWriter
from torchaudio_unittest.common_utils import get_sinusoid, skipIfNoAudioDevice, skipIfNoMacOS, TorchaudioTestCase
@skipIfNoAudioDevice
@skipIfNoMacOS
class PlaybackInterfaceTest(TorchaudioTestCase):
    """Tests for ``play_audio`` with ``StreamWriter.write_audio_chunk`` patched.

    Every test replaces ``StreamWriter.write_audio_chunk`` with a mock, so
    the calls can be inspected without sending samples to a real output.

    NOTE(review): the ``skipIfNoAudioDevice`` / ``skipIfNoMacOS`` decorators
    presumably skip the whole case on hosts without an audio device or not
    running macOS -- confirm against ``common_utils``.
    """

    @parameterized.expand(
        [
            ("uint8",),
            ("int16",),
            ("int32",),
            ("int64",),
            ("float32",),
            ("float64",),
        ]
    )
    @patch.object(StreamWriter, "write_audio_chunk")
    def test_playaudio(self, dtype, writeaudio_mock):
        """Check that ``play_audio`` writes to the output stream.

        A one-second, single-channel 440 Hz sinusoid of the requested dtype
        is handed to ``play_audio``; the patched ``write_audio_chunk`` must
        have been called at least once. No actual audio is played.
        """
        rate = 8000
        tone = get_sinusoid(
            frequency=440,
            sample_rate=rate,
            duration=1,  # seconds
            n_channels=1,
            dtype=getattr(torch, dtype),
            device="cpu",
            channels_first=False,
        )

        play_audio(tone, sample_rate=rate)

        writeaudio_mock.assert_called()

    @parameterized.expand(
        [
            # Wrong rank: the waveform must have exactly 2 dimensions.
            ("int16", 1, "audiotoolbox"),
            ("int16", 3, "audiotoolbox"),
            # Unsupported tensor dtype.
            ("complex64", 2, "audiotoolbox"),
            # Unknown output device name.
            ("int16", 2, "audiotool"),
        ]
    )
    @patch.object(StreamWriter, "write_audio_chunk")
    def test_playaudio_invalid_options(self, dtype, ndim, device, writeaudio_mock):
        """Check that ``play_audio`` raises ``ValueError`` for invalid options.

        The sinusoid is squeezed and then given ``ndim - 1`` trailing
        singleton axes, so each case controls the rank of the tensor passed
        to ``play_audio`` independently of the dtype and device under test.
        """
        rate = 8000
        flat = get_sinusoid(
            frequency=440,
            sample_rate=rate,
            duration=1,  # seconds
            n_channels=1,
            dtype=getattr(torch, dtype),
            device="cpu",
            channels_first=False,
        ).squeeze()

        # Indexing with ``None`` after an Ellipsis appends one size-1 axis per
        # ``None`` -- the same result as calling ``unsqueeze(-1)`` that many
        # times in a row.
        trailing_axes = (None,) * (ndim - 1)
        shaped = flat[(Ellipsis,) + trailing_axes]

        with self.assertRaises(ValueError):
            play_audio(shaped, sample_rate=rate, device=device)
|