1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
|
import pytest
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
@pytest.mark.parametrize(
    "bundle,lang,expected",
    [
        (EMFORMER_RNNT_BASE_LIBRISPEECH, "en", "i have that curiosity beside me at this moment"),
        (EMFORMER_RNNT_BASE_MUSTC, "en", "I had that curiosity beside me at this moment."),
        (EMFORMER_RNNT_BASE_TEDLIUM3, "en", "i had that curiosity beside me at this moment"),
    ],
)
def test_rnnt(bundle, sample_speech, expected):
    """Run one RNN-T bundle end to end and compare its transcript to ``expected``.

    The bundle supplies the feature extractor, the decoder and the token
    processor; the audio comes from the ``sample_speech`` fixture.

    NOTE(review): ``lang`` is parametrized but is not an argument of this
    function -- presumably the ``sample_speech`` fixture consumes it to pick
    the audio file. Confirm against the fixture definition (not in this file).
    """
    # Third positional argument passed to the decoder call; presumably the
    # beam width of the search -- confirm against the decoder's signature.
    beam_arg = 10

    # Pull the three pipeline stages out of the bundle.
    extract_features = bundle.get_feature_extractor()
    rnnt_decoder = bundle.get_decoder()
    rnnt_decoder = rnnt_decoder.eval()
    tokens_to_text = bundle.get_token_processor()

    # The second value returned by ``torchaudio.load`` is not needed here.
    audio, _ = torchaudio.load(sample_speech)
    mono_audio = audio.squeeze()

    feats, feats_length = extract_features(mono_audio)
    candidates = rnnt_decoder(feats, feats_length, beam_arg)

    # First field of the first returned hypothesis; presumably the token ids
    # of the top-ranked candidate -- TODO confirm hypothesis layout.
    top_tokens = candidates[0][0]
    transcript = tokens_to_text(top_tokens)

    assert transcript == expected
|