File: kaldi_io_test.py

Package info
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
File content (33 lines) | stat: -rw-r--r-- 1,470 bytes | parent folder | download | duplicates (2)
import torch
import torchaudio.kaldi_io as kio
from torchaudio_unittest import common_utils


class Test_KaldiIO(common_utils.TorchaudioTestCase):
    """Tests for the ark readers in ``torchaudio.kaldi_io``.

    Each test reads a small ark asset file and checks that the reader yields
    exactly the expected ``(key, tensor)`` pairs with the expected dtype.
    """

    # Source values for the expected tensors. The vector tests use each row of
    # data1 as one key's value; the matrix test uses each whole list as one key.
    data1 = [[1, 2, 3], [11, 12, 13], [21, 22, 23]]
    data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]]

    def _test_helper(self, file_name, expected_data, fn, expected_dtype):
        """Read ``file_name`` with ``fn`` and compare the result to ``expected_data``.

        Args:
            file_name: Asset file name, resolved via ``common_utils.get_asset_path``.
            expected_data: Sequence of values; element ``i`` (0-based) is the
                expected value for key ``"key{i + 1}"``.
            fn: Reader called with the asset path; iterated for ``(key, tensor)`` pairs.
            expected_dtype: dtype that every yielded tensor must have.
        """
        test_filepath = common_utils.get_asset_path(file_name)
        expected_output = {
            f"key{idx}": torch.tensor(val, dtype=expected_dtype)
            for idx, val in enumerate(expected_data, start=1)
        }

        seen_keys = []
        for key, vec in fn(test_filepath):
            self.assertIn(key, expected_output)
            self.assertIsInstance(vec, torch.Tensor)
            self.assertEqual(vec.dtype, expected_dtype)
            # torch.equal also requires matching shapes; element-wise torch.eq
            # broadcasts and could accept a tensor of the wrong shape.
            self.assertTrue(torch.equal(vec, expected_output[key]))
            seen_keys.append(key)

        # Ensure every expected key was produced exactly once; otherwise a reader
        # that yields nothing (or skips / repeats entries) would pass silently.
        self.assertEqual(sorted(seen_keys), sorted(expected_output))

    def test_read_vec_int_ark(self):
        self._test_helper("vec_int.ark", self.data1, kio.read_vec_int_ark, torch.int32)

    def test_read_vec_flt_ark(self):
        self._test_helper("vec_flt.ark", self.data1, kio.read_vec_flt_ark, torch.float32)

    def test_read_mat_ark(self):
        self._test_helper("mat.ark", [self.data1, self.data2], kio.read_mat_ark, torch.float32)