File: create_jittable_pipeline.py

package info (click to toggle)
pytorch-audio 0.13.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,592 kB
  • sloc: python: 41,137; cpp: 8,016; sh: 3,538; makefile: 24
file content (79 lines) | stat: -rwxr-xr-x 2,259 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/env python3
"""
Create a data preprocess pipeline that can be run with libtorchaudio
"""
import argparse
import os

import torch
import torchaudio


class Pipeline(torch.nn.Module):
    """Example audio process pipeline.

    This example loads a waveform from a file, applies effects to it (additive
    noise, then a room impulse response filter), and saves the result to a file.

    The room impulse response (RIR) is loaded once at construction time and kept
    as a registered buffer, so it becomes part of the module's state and is
    stored together with the module when it is scripted and saved.

    Args:
        rir_path (str): Path to the audio file holding the room impulse response.
    """

    def __init__(self, rir_path: str):
        super().__init__()
        rir, sample_rate = torchaudio.load(rir_path)
        # A buffer (rather than a plain attribute) makes the RIR tensor part of the
        # module's state, so it is serialized along with the scripted module.
        self.register_buffer("rir", rir)
        # Explicit ``int`` annotation tells TorchScript the attribute's type.
        self.rir_sample_rate: int = sample_rate

    def forward(self, input_path: str, output_path: str):
        """Load ``input_path``, add noise, apply the RIR, and save to ``output_path``.

        Args:
            input_path (str): Path of the audio file to process.
            output_path (str): Path the processed audio is written to.
        """
        # NOTE(review): presumably initializes libsox before the sox effect used
        # below; confirm whether it is needed on every call.
        torchaudio.sox_effects.init_sox_effects()

        # 1. Load audio. torchaudio.load returns (channels, time) by default,
        #    so dimension 1 below is the time axis.
        waveform, sample_rate = torchaudio.load(input_path)

        # 2. Add background noise: mix standard-normal noise in at weight ``alpha``
        #    and scale the original signal by ``1 - alpha``.
        alpha = 0.01
        waveform = alpha * torch.randn_like(waveform) + (1 - alpha) * waveform

        # 3. Resample the RIR filter to match the audio sample rate
        rir, _ = torchaudio.sox_effects.apply_effects_tensor(
            self.rir, self.rir_sample_rate, effects=[["rate", str(sample_rate)]]
        )
        # Normalize the filter to unit L2 norm.
        rir = rir / torch.norm(rir, p=2)
        # conv1d computes cross-correlation; reversing the kernel along time
        # makes the operation below a true convolution.
        rir = torch.flip(rir, [1])

        # 4. Apply RIR filter
        # Left-pad the time axis by (kernel length - 1) so the output has the same
        # length as the input and each output sample depends only on current and
        # past input samples.
        waveform = torch.nn.functional.pad(waveform, (rir.shape[1] - 1, 0))
        # Add a batch dim to the input and an out-channel dim to the kernel.
        # conv1d requires the waveform and RIR channel counts to match. The single
        # output channel is returned as a (1, time) tensor after dropping the batch dim.
        waveform = torch.nn.functional.conv1d(waveform[None, ...], rir[None, ...])[0]

        # Save
        torchaudio.save(output_path, waveform, sample_rate)


def _create_jit_pipeline(rir_path, output_path):
    """Script a ``Pipeline``, print its TorchScript code, and save it.

    Args:
        rir_path: Path to the room impulse response audio file given to ``Pipeline``.
        output_path: Destination file for the serialized TorchScript module.
    """
    scripted = torch.jit.script(Pipeline(rir_path))
    separator = "*" * 40
    # Banner framing the generated code listing.
    for banner_line in (separator, "* Pipeline code", separator):
        print(banner_line)
    print()
    print(scripted.code)
    print(separator)
    scripted.save(output_path)


def _get_path(*paths):
    """Return ``paths`` joined onto the directory that contains this script."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, *paths)


def _parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse falls back to ``sys.argv[1:]``, matching the
            original behavior.

    Returns:
        argparse.Namespace with ``rir_path`` and ``output_path`` attributes.
        Both default to locations relative to this script's directory.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--rir-path", default=_get_path("..", "data", "rir.wav"), help="Audio data for room impulse response."
    )
    parser.add_argument("--output-path", default=_get_path("pipeline.zip"), help="Output JIT file.")
    return parser.parse_args(argv)


def _main():
    """Command-line entry point.

    Reads ``--rir-path`` and ``--output-path``, then scripts the pipeline and
    writes it to the output path.
    """
    options = _parse_args()
    rir_file = options.rir_path
    jit_file = options.output_path
    _create_jit_pipeline(rir_file, jit_file)


if __name__ == "__main__":
    # Build the pipeline only when executed as a script, not on import.
    _main()