File: kaldi_compatibility_test_impl.py

package info (click to toggle)
pytorch-audio 0.13.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,592 kB
  • sloc: python: 41,137; cpp: 8,016; sh: 3,538; makefile: 24
file content (56 lines) | stat: -rw-r--r-- 2,314 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torchaudio.functional as F
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
    get_sinusoid,
    load_params,
    save_wav,
    skipIfNoExec,
    TempDirMixin,
    TestBaseMixin,
)
from torchaudio_unittest.common_utils.kaldi_utils import convert_args, run_kaldi


class Kaldi(TempDirMixin, TestBaseMixin):
    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        expected = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, expected, rtol=rtol, atol=atol)

    @skipIfNoExec("apply-cmvn-sliding")
    def test_sliding_window_cmn(self):
        """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
        kwargs = {
            "cmn_window": 600,
            "min_cmn_window": 100,
            "center": False,
            "norm_vars": False,
        }

        tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
        result = F.sliding_window_cmn(tensor, **kwargs)
        command = ["apply-cmvn-sliding"] + convert_args(**kwargs) + ["ark:-", "ark:-"]
        kaldi_result = run_kaldi(command, "ark", tensor)
        self.assert_equal(result, expected=kaldi_result)


class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        expected = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, expected, rtol=rtol, atol=atol)

    @parameterized.expand(load_params("kaldi_test_pitch_args.jsonl"))
    @skipIfNoExec("compute-kaldi-pitch-feats")
    def test_pitch_feats(self, kwargs):
        """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
        sample_rate = kwargs["sample_rate"]
        waveform = get_sinusoid(dtype="float32", sample_rate=sample_rate)
        result = F.compute_kaldi_pitch(waveform[0], **kwargs)

        waveform = get_sinusoid(dtype="int16", sample_rate=sample_rate)
        wave_file = self.get_temp_path("test.wav")
        save_wav(wave_file, waveform, sample_rate)

        command = ["compute-kaldi-pitch-feats"] + convert_args(**kwargs) + ["scp:-", "ark:-"]
        kaldi_result = run_kaldi(command, "scp", wave_file)
        self.assert_equal(result, expected=kaldi_result)