File: cuda_ctc_decoder_test.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (49 lines) | stat: -rw-r--r-- 1,412 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from torchaudio_unittest.common_utils import (
    get_asset_path,
    skipIfNoCuCtcDecoder,
    skipIfNoCuda,
    TempDirMixin,
    TorchaudioTestCase,
)

# Size of the last dimension of the random emissions built by
# ``CUCTCDecoderTest._get_emissions``.
# NOTE(review): presumably equals the number of entries in the
# ``decoder/tokens.txt`` asset — confirm against the asset file.
NUM_TOKENS = 7


@skipIfNoCuda
@skipIfNoCuCtcDecoder
class CUCTCDecoderTest(TempDirMixin, TorchaudioTestCase):
    """Construction and output-size checks for ``cuda_ctc_decoder``.

    The whole suite is gated by the ``skipIfNoCuda`` and
    ``skipIfNoCuCtcDecoder`` class decorators.
    """

    def _get_decoder(self, tokens=None, **kwargs):
        """Build a decoder through ``cuda_ctc_decoder`` with ``beam_size=5``.

        Args:
            tokens: Token specification forwarded to ``cuda_ctc_decoder``.
                When ``None``, the ``decoder/tokens.txt`` asset path is used.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            The object returned by ``cuda_ctc_decoder``.
        """
        # Imported inside the helper rather than at module level.
        from torchaudio.models.decoder import cuda_ctc_decoder

        token_source = get_asset_path("decoder/tokens.txt") if tokens is None else tokens
        return cuda_ctc_decoder(tokens=token_source, beam_size=5, **kwargs)

    def _get_emissions(self):
        """Return random log-probabilities of shape ``(4, 15, NUM_TOKENS)`` on CUDA."""
        batch_size, num_frames = 4, 15
        # Sample on the host first, then move to the GPU before normalising.
        raw_scores = torch.rand(batch_size, num_frames, NUM_TOKENS).cuda()
        return torch.nn.functional.log_softmax(raw_scores, dim=-1)

    def test_construct_basic_decoder_path(self):
        """A decoder can be built from the token file asset path."""
        self._get_decoder(tokens=get_asset_path("decoder/tokens.txt"))

    def test_construct_basic_decoder_tokens(self):
        """A decoder can be built from an in-memory list of tokens."""
        vocabulary = [
            "-",
            "|",
            "f",
            "o",
            "b",
            "a",
            "r",
        ]
        self._get_decoder(tokens=vocabulary)

    def test_shape(self):
        """Decoding yields one result entry per batch item."""
        emissions = self._get_emissions()
        # One length per batch item; values are tied to the (4, 15, ...) shape
        # produced by ``_get_emissions``.
        lengths = torch.tensor([15, 14, 13, 12], dtype=torch.int32).cuda()
        hypotheses = self._get_decoder()(emissions, lengths)
        self.assertEqual(len(hypotheses), emissions.shape[0])