File: kaldi_utils.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (38 lines) | stat: -rw-r--r-- 1,303 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import subprocess

import torch


def convert_args(**kwargs):
    """Translate keyword arguments into Kaldi-style command-line flags.

    Each ``name=value`` pair becomes ``--name=value``, with underscores in the
    name turned into dashes. ``sample_rate`` is renamed to ``sample_frequency``
    first. Values equal to ``True``/``False`` are lower-cased (``true``/``false``);
    everything else is passed through ``str``.

    Returns:
        list of str: one flag per keyword argument, in the order given.
    """

    def _to_flag(name, val):
        # torchaudio's name for this option differs from the Kaldi flag name.
        if name == "sample_rate":
            name = "sample_frequency"
        text = str(val)
        if val in (True, False):
            text = text.lower()
        return f"--{name.replace('_', '-')}={text}"

    return [_to_flag(name, val) for name, val in kwargs.items()]


def run_kaldi(command, input_type, input_value):
    """Run a Kaldi command on a single utterance and return its output matrix.

    The input is fed to the command's stdin under the utterance key ``"foo"``.
    The matrix the command writes to stdout under that same key is returned.

    Args:
        command (list of str): The command with arguments. It is expected to
            read its input table from stdin and write an ark to stdout
            (e.g. ``... ark:- ark:-`` or ``... scp:- ark:-``).
        input_type (str): 'ark' or 'scp'
        input_value (Tensor for 'ark', string for 'scp'): The input to pass.
            Must be a path to an audio file for 'scp'.

    Returns:
        Tensor: the matrix the command produced for the utterance key.

    Raises:
        NotImplementedError: if ``input_type`` is neither 'ark' nor 'scp'.
        subprocess.CalledProcessError: if the command exits with a non-zero status.
        KeyError: if the command succeeded but produced no entry for the key.
    """
    # Validate before spawning anything. Otherwise a bad type would leave a
    # child process blocked on a stdin that is never closed.
    if input_type not in ("ark", "scp"):
        raise NotImplementedError("Unexpected type")

    import kaldi_io

    key = "foo"
    # Popen's context manager closes the pipes and waits for the child on exit,
    # so the process is always reaped, even if writing or reading raises.
    with subprocess.Popen(
        command, stdin=subprocess.PIPE, stdout=subprocess.PIPE
    ) as process:
        if input_type == "ark":
            kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
        else:
            # A single newline-terminated "<key> <path>" scp entry.
            process.stdin.write(f"{key} {input_value}\n".encode("utf8"))
        # Send EOF so the command finishes processing its only entry.
        process.stdin.close()
        matrices = dict(kaldi_io.read_mat_ark(process.stdout))
    # The context manager has called wait(), so returncode is set here.
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, command)
    # copy() presumably hands torch a writable array (the one read back may be
    # read-only), which suppresses torch's non-writable-array warning.
    return torch.from_numpy(matrices[key].copy())