File: rnnt_pipeline_test.py

package info (click to toggle)
pytorch-audio 0.13.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,592 kB
  • sloc: python: 41,137; cpp: 8,016; sh: 3,538; makefile: 24
file content (20 lines) | stat: -rw-r--r-- 691 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import pytest
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH


@pytest.mark.parametrize(
    "bundle,lang,expected",
    [
        (EMFORMER_RNNT_BASE_LIBRISPEECH, "en", "i have that curiosity beside me at this moment"),
    ],
)
def test_rnnt(bundle, sample_speech, expected):
    """Check that the bundle's RNN-T pipeline transcribes the sample to ``expected``.

    The pipeline is assembled from the three components the bundle provides
    (feature extractor, decoder, token processor) and run end to end on the
    audio file referenced by ``sample_speech``.

    NOTE(review): ``lang`` is parametrized but is not a parameter of this
    function; presumably it is consumed by the ``sample_speech`` fixture
    defined outside this file -- confirm against the conftest.
    """
    # Build every pipeline stage up front, before touching the audio file.
    extract_features = bundle.get_feature_extractor()
    beam_search = bundle.get_decoder().eval()
    tokens_to_text = bundle.get_token_processor()

    # Only the first item returned by torchaudio.load is used; the second is ignored.
    audio = torchaudio.load(sample_speech)[0]
    feats, feat_length = extract_features(audio.squeeze())

    # Third positional argument is 10 -- presumably the beam width; confirm
    # against the decoder's signature.
    results = beam_search(feats, feat_length, 10)

    # The first entry of the first result is what the token processor consumes
    # (presumably the token sequence of the top hypothesis -- TODO confirm).
    top_result = results[0]
    transcript = tokens_to_text(top_result[0])
    assert transcript == expected