File: subtractive_synthesis_tutorial.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (276 lines) | stat: -rw-r--r-- 8,304 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# -*- coding: utf-8 -*-
"""
Subtractive synthesis
=====================

**Author**: `Moto Hira <moto@meta.com>`__

This tutorial is the continuation of
`Filter Design Tutorial <./filter_design_tutorial.html>`__.

This tutorial shows how to perform subtractive synthesis with TorchAudio's DSP functions.

Subtractive synthesis creates timbre by applying filters to a source waveform.

.. warning::
   This tutorial requires prototype DSP features, which are
   available in nightly builds.

   Please refer to https://pytorch.org/get-started/locally
   for instructions for installing a nightly build.
"""

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

######################################################################
# Overview
# --------
#
#

try:
    from torchaudio.prototype.functional import filter_waveform, frequency_impulse_response, sinc_impulse_response
except ImportError:
    # Catch ``ImportError`` rather than only its subclass
    # ``ModuleNotFoundError``: when the prototype module exists but does not
    # provide these functions, Python raises a plain ``ImportError``, and the
    # installation hint below should be shown in that case as well.
    print(
        "Failed to import prototype DSP features. "
        "Please install torchaudio nightly builds. "
        "Please refer to https://pytorch.org/get-started/locally "
        "for instructions to install a nightly build."
    )
    # Re-raise so the tutorial stops here instead of failing later on an
    # undefined name.
    raise

import matplotlib.pyplot as plt
from IPython.display import Audio


######################################################################
# Filtered Noise
# --------------
#
# Subtractive synthesis starts with a waveform and applies filters to
# some frequency components.
#
# For the first example of subtractive synthesis, we apply
# a time-varying low-pass filter to white noise.
#
# First, we create white noise.
#

# Sampling rate and length of the generated audio.
SAMPLE_RATE = 16_000
duration = 4
num_frames = int(SAMPLE_RATE * duration)

# ``torch.rand`` draws from [0, 1); shifting by 0.5 centres the noise on zero.
noise = torch.rand(num_frames).sub(0.5)


######################################################################
#
def plot_input():
    """Plot the source noise and return an audio player for it.

    Reads the module-level ``noise``, ``duration``, ``num_frames`` and
    ``SAMPLE_RATE``. The upper axes show the waveform against time, the
    lower axes show its spectrogram.

    Returns:
        IPython.display.Audio: Player for ``noise`` at ``SAMPLE_RATE``.
    """
    _, axes = plt.subplots(2, 1, sharex=True)
    t = torch.linspace(0, duration, num_frames)
    axes[0].plot(t, noise)
    axes[0].grid(True)
    axes[1].specgram(noise, Fs=SAMPLE_RATE)
    # Return the player: an ``Audio`` object that is only constructed and
    # then dropped is never displayed.
    return Audio(noise, rate=SAMPLE_RATE)


plot_input()

######################################################################
# Windowed-sinc filter
# --------------------
#

######################################################################
#
# Sweeping cutoff frequency
# ~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We use :py:func:`~torchaudio.prototype.functional.sinc_impulse_response` to
# create a series of low-pass filters, while sweeping the cut-off
# frequency from zero towards the Nyquist frequency.
#

# One filter per cutoff value: 64 filters for every unit of ``duration``.
num_filters = 64 * duration
# Window length passed to ``sinc_impulse_response`` for every filter.
window_size = 2049

# Cutoff values rising linearly from 0.0 to 0.8, one per filter.
# NOTE(review): ``plot_sinc_ir`` multiplies cutoffs by the Nyquist frequency,
# so these appear to be fractions of Nyquist -- confirm against the API docs.
f_cutoff = torch.linspace(0.0, 0.8, num_filters)
kernel = sinc_impulse_response(f_cutoff, window_size)

######################################################################
#
# To apply a time-varying filter, we use
# :py:func:`~torchaudio.prototype.functional.filter_waveform`
#

# Filter the noise with the whole series of kernels built above.
filtered = filter_waveform(noise, kernel)

######################################################################
#
# Let's look at the spectrogram of the resulting audio and listen to it.
#


def plot_sinc_ir(waveform, cutoff, sample_rate, vol=0.2):
    """Plot a filtered waveform with its cutoff curve and return a player.

    The upper axes show the waveform against time. The lower axes show its
    spectrogram in dB with the cutoff curve drawn on top as a dashed line.

    Args:
        waveform (Tensor): Audio to plot and play; its length is read from
            dimension 0. The tensor is not modified.
        cutoff (Tensor): Cutoff values, one per filter. They are multiplied
            by ``sample_rate / 2`` to place them on the frequency axis.
        sample_rate: Sampling rate used for the time axis, the spectrogram
            and the audio player.
        vol (float, optional): Gain applied to the peak-normalized audio
            before playback. Defaults to ``0.2``.

    Returns:
        IPython.display.Audio: Player for the peak-normalized, scaled audio.
    """
    num_frames = waveform.size(0)
    duration = num_frames / sample_rate
    num_cutoff = cutoff.size(0)
    nyquist = sample_rate / 2

    _, axes = plt.subplots(2, 1, sharex=True)
    t = torch.linspace(0, duration, num_frames)
    axes[0].plot(t, waveform)
    axes[0].grid(True)
    axes[1].specgram(waveform, Fs=sample_rate, scale="dB")
    # The cutoff curve has its own, coarser time axis (one point per filter).
    t = torch.linspace(0, duration, num_cutoff)
    axes[1].plot(t, cutoff * nyquist, color="gray", linewidth=0.8, label="Cutoff Frequency", linestyle="--")
    axes[1].legend(loc="upper center")
    axes[1].set_ylim([0, nyquist])
    # Normalize out of place so the caller's tensor is left untouched, and
    # skip the division for an all-zero waveform to avoid producing NaNs.
    peak = waveform.abs().max()
    if peak > 0:
        waveform = waveform / peak
    # ``normalize=False`` keeps the ``vol`` scaling from being undone.
    return Audio(vol * waveform, rate=sample_rate, normalize=False)


######################################################################
#

plot_sinc_ir(filtered, f_cutoff, SAMPLE_RATE)

######################################################################
#
# Oscillating cutoff frequency
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# By oscillating the cutoff frequency, we can emulate the effect of
# low-frequency oscillation (LFO).
#

# NOTE(review): ``PI2`` is not referenced anywhere in the visible code --
# confirm whether it is still needed.
PI2 = torch.pi * 2
# 90 filters for every unit of ``duration``.
num_filters = 90 * duration

# Per-filter phase increment, falling linearly from 0.9 to 0.1; its running
# sum below is the phase of the sine, so the oscillation slows down over time.
f_lfo = torch.linspace(0.9, 0.1, num_filters)
# Oscillating part of the cutoff: a sine whose amplitude grows linearly
# from 0.01 to 0.03.
f_cutoff_osci = torch.linspace(0.01, 0.03, num_filters) * torch.sin(torch.cumsum(f_lfo, dim=0))
# Base cutoff: a line from 0.8 down to 0.03, raised to the power 1.7.
f_cutoff_base = torch.linspace(0.8, 0.03, num_filters) ** 1.7
# Final cutoff curve: base curve plus the oscillation.
f_cutoff = f_cutoff_base + f_cutoff_osci

######################################################################
#

# Build one kernel per cutoff value (reusing ``window_size``) and filter the
# same noise with them.
kernel = sinc_impulse_response(f_cutoff, window_size)
filtered = filter_waveform(noise, kernel)

######################################################################
#

plot_sinc_ir(filtered, f_cutoff, SAMPLE_RATE)

######################################################################
#
# Wah-wah effects
# ~~~~~~~~~~~~~~~
#
# Wah-wah effects are applications of a low-pass filter or a band-pass filter.
# They change the cut-off frequency or Q-factor quickly.

# Constant phase increment of 0.15 per filter, i.e. a fixed oscillation rate.
# ``num_filters`` keeps the value assigned in the previous section.
f_lfo = torch.linspace(0.15, 0.15, num_filters)
# Cutoff swings sinusoidally around 0.07 with an amplitude of 0.06.
f_cutoff = 0.07 + 0.06 * torch.sin(torch.cumsum(f_lfo, dim=0))

######################################################################
#

# Build one kernel per cutoff value and filter the same noise with them.
kernel = sinc_impulse_response(f_cutoff, window_size)
filtered = filter_waveform(noise, kernel)

######################################################################
#

plot_sinc_ir(filtered, f_cutoff, SAMPLE_RATE)

######################################################################
# Arbitrary frequency response
# ----------------------------
#
# By using
# :py:func:`~torchaudio.prototype.functional.frequency_impulse_response`,
# one can directly control the power distribution over frequency.
#


# 64 magnitude values: sin(x) ** 4 sampled at 64 points of x in [0, 10].
magnitudes = torch.sin(torch.linspace(0, 10, 64)) ** 4.0
# Turn the magnitude curve into a filter kernel.
kernel = frequency_impulse_response(magnitudes)
# A leading dimension is added to the single kernel before filtering.
# NOTE(review): presumably ``filter_waveform`` expects a stack of kernels --
# confirm against the API docs.
filtered = filter_waveform(noise, kernel.unsqueeze(0))

######################################################################
#


def plot_waveform(magnitudes, filtered, sample_rate):
    """Show the spectrogram of ``filtered`` with magnitude curves on top.

    Ten gray magnitude curves are drawn over the spectrogram, each anchored
    at a dashed vertical line placed between 5% and 95% of the duration.

    Args:
        magnitudes (Tensor): Magnitude values along the last dimension. When
            1-D, the same curve is drawn at every anchor. Otherwise the
            tensor is indexed along dimension 0 and the row at the matching
            relative position is drawn at each anchor.
        filtered (Tensor): Audio to display and play; the length of its last
            dimension determines the duration.
        sample_rate: Sampling rate used for the spectrogram and the player.

    Returns:
        IPython.display.Audio: Player for ``filtered``.
    """
    num_overlays = 10
    half_rate = sample_rate / 2
    total_seconds = filtered.size(-1) / sample_rate

    # Relative positions of the overlays and the matching time offsets.
    fractions = torch.linspace(0.05, 0.95, num_overlays)
    anchors = total_seconds * fractions

    # Pick one magnitude curve per overlay.
    if magnitudes.ndim == 1:
        rows = [magnitudes] * num_overlays
    else:
        num_rows = magnitudes.size(0)
        rows = [magnitudes[int(frac * num_rows)] for frac in fractions]
    curves = torch.stack(rows)

    # Each curve is scaled by 0.1 and drawn sideways from its anchor, with
    # the vertical axis spanning 0 to ``half_rate``.
    curve_x = anchors[:, None] + 0.1 * curves
    curve_y = torch.linspace(0, half_rate, magnitudes.size(-1)).tile((num_overlays, 1))

    _, ax = plt.subplots(1, 1, sharex=True)
    ax.vlines(anchors, 0, half_rate, color="gray", linestyle="--", linewidth=0.8)
    ax.plot(curve_x.T.numpy(), curve_y.T.numpy(), color="gray", linewidth=0.8)
    ax.specgram(filtered, Fs=sample_rate)
    return Audio(filtered, rate=sample_rate)


######################################################################
#
plot_waveform(magnitudes, filtered, SAMPLE_RATE)

######################################################################
#
# It is also possible to make a non-stationary filter.

# 250 rows of 1000 values each: row k runs linearly from 0.0 to ``w``, with
# ``w`` itself stepping linearly from 4.0 to 40.0 across the rows.
magnitudes = torch.stack([torch.linspace(0.0, w, 1000) for w in torch.linspace(4.0, 40.0, 250)])
# Element-wise sin(x) ** 4 turns every row into a magnitude curve.
magnitudes = torch.sin(magnitudes) ** 4.0

######################################################################
#
# ``magnitudes`` is already 2-D here, so the resulting kernel is passed to
# ``filter_waveform`` without adding a leading dimension.
kernel = frequency_impulse_response(magnitudes)
filtered = filter_waveform(noise, kernel)

######################################################################
#
plot_waveform(magnitudes, filtered, SAMPLE_RATE)

######################################################################
#
# Of course it is also possible to emulate a simple low-pass filter.

# 64 magnitude values: the first 32 are one, the last 32 are zero.
# NOTE(review): ``plot_waveform`` draws magnitudes from 0 up to Nyquist, so
# this presumably keeps the lower half of the band -- confirm.
magnitudes = torch.concat([torch.ones((32,)), torch.zeros((32,))])

######################################################################
#
# A leading dimension is added to the single kernel before filtering.
kernel = frequency_impulse_response(magnitudes)
filtered = filter_waveform(noise, kernel.unsqueeze(0))

######################################################################
#
plot_waveform(magnitudes, filtered, SAMPLE_RATE)

######################################################################
# References
# ----------
#
# - https://en.wikipedia.org/wiki/Additive_synthesis
# - https://computermusicresource.com/Simple.bell.tutorial.html
# - https://computermusicresource.com/Definitions/additive.synthesis.html