File: hybrid_demucs_tutorial.py

"""
Music Source Separation with Hybrid Demucs
==========================================

**Author**: `Sean Kim <https://github.com/skim0514>`__

This tutorial shows how to use the Hybrid Demucs model to perform music
source separation.

"""

######################################################################
# 1. Overview
# -----------
#
# Music source separation consists of the following steps
#
# 1. Build the Hybrid Demucs pipeline.
# 2. Format the waveform into chunks of the expected size, then loop through
#    the chunks (with overlap) and feed them into the pipeline.
# 3. Collect the output chunks and combine them according to the way they
#    were overlapped.
#
# The Hybrid Demucs [`Défossez, 2021 <https://arxiv.org/abs/2111.03600>`__]
# model is an evolved version of the
# `Demucs <https://github.com/facebookresearch/demucs>`__ model, a
# waveform-based model that separates music into its
# constituent sources, such as vocals, bass, and drums.
# Hybrid Demucs additionally uses spectrograms to learn in the
# frequency domain, combining them with time-domain convolutions.
#


######################################################################
# 2. Preparation
# --------------
#
# First, we install the necessary dependencies. The first requirements are
# ``torch`` and ``torchaudio``.
#

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

import matplotlib.pyplot as plt

######################################################################
# In addition to ``torchaudio``, ``mir_eval`` is required to perform
# signal-to-distortion ratio (SDR) calculations. To install ``mir_eval``,
# please use ``pip3 install mir_eval``.
#
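# A minimal, optional sketch, assuming an environment where installing packages
# from within Python is acceptable: import ``mir_eval`` if it is available and
# otherwise install it with pip. Running ``pip3 install mir_eval`` in a shell
# works just as well.

try:
    import mir_eval  # noqa: F401
except ImportError:  # assumption: pip is available in this environment
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "mir_eval"])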

from IPython.display import Audio
from mir_eval import separation
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.utils import download_asset

######################################################################
# 3. Construct the pipeline
# -------------------------
#
# Pre-trained model weights and related pipeline components are bundled as
# :py:func:`torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS`. This is a
# :py:class:`torchaudio.models.HDemucs` model trained on
# `MUSDB18-HQ <https://zenodo.org/record/3338373>`__ and additional
# internal extra training data.
# This specific model is suited for higher sample rates, around 44.1 kHz,
# and has an ``nfft`` value of 4096 with a depth of 6 in the model implementation.

bundle = HDEMUCS_HIGH_MUSDB_PLUS

model = bundle.get_model()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)

sample_rate = bundle.sample_rate

print(f"Sample rate: {sample_rate}")
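
# The bundled model separates a mixture into a fixed, ordered set of sources;
# as a quick check, ``model.sources`` (used later to label the outputs) lists them.
print(f"Sources: {model.sources}")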

######################################################################
# 4. Configure the application function
# -------------------------------------
#
# Because ``HDemucs`` is a large and memory-consuming model, it is
# difficult to apply it to an entire song at once. To work around this
# limitation, we obtain the separated sources of a full song by chunking
# the song into smaller segments, running them through the model piece by
# piece, and then stitching the results back together.
#
# When doing this, it is important to ensure some
# overlap between the chunks to compensate for artifacts at the
# edges. Due to the nature of the model, the edges sometimes contain
# inaccurate or undesired sounds.
#
# We provide a sample implementation of chunking and arrangement below. The
# implementation takes an overlap of 1 second on each side, and then does
# a linear fade in and fade out on each side. Using the faded overlaps, we
# add these segments together to ensure a constant volume throughout.
# This reduces the artifacts by using less of the edges of the
# model outputs.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/HDemucs_Drawing.jpg

from torchaudio.transforms import Fade


def separate_sources(
    model,
    mix,
    segment=10.0,
    overlap=0.1,
    device=None,
):
    """
    Apply model to a given mixture. Process the mixture in overlapping segments,
    fading each segment and summing them so that the full-length output is
    reconstructed segment by segment.

    Args:
        model (torch.nn.Module): source separation model
        mix (torch.Tensor): mixture waveform of shape `(batch, channels, length)`
        segment (float): segment length in seconds
        overlap (float): amount of overlap between consecutive segments
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    # ``sample_rate`` here is the bundle sample rate defined at module level above.
    chunk_len = int(sample_rate * segment * (1 + overlap))
    start = 0
    end = chunk_len
    overlap_frames = overlap * sample_rate
    fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")

    # Output buffer: one full-length track per source.
    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    while start < length - overlap_frames:
        chunk = mix[:, :, start:end]
        with torch.no_grad():
            out = model.forward(chunk)
        out = fade(out)
        # Overlap-add the faded chunk into the output buffer.
        final[:, :, :, start:end] += out
        if start == 0:
            # After the first chunk, also fade in so overlapping regions sum to a constant level.
            fade.fade_in_len = int(overlap_frames)
            start += int(chunk_len - overlap_frames)
        else:
            start += chunk_len
        end += chunk_len
        if end >= length:
            # The last chunk should not fade out at the end of the track.
            fade.fade_out_len = 0
    return final
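

# A small, hypothetical sanity check (not part of the pipeline): a linear
# fade-out on one chunk and a matching linear fade-in on the next sum back to
# the original constant level over the overlapping region, which is why the
# overlap-add above keeps a constant volume.
_overlap_len = 100
_fade_out = Fade(fade_in_len=0, fade_out_len=_overlap_len, fade_shape="linear")
_fade_in = Fade(fade_in_len=_overlap_len, fade_out_len=0, fade_shape="linear")
_ones = torch.ones(1, 1, 300)
_summed = _fade_out(_ones)[..., -_overlap_len:] + _fade_in(_ones)[..., :_overlap_len]
print(torch.allclose(_summed, torch.ones(1, 1, _overlap_len)))  # expected: True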


def plot_spectrogram(stft, title="Spectrogram"):
    magnitude = stft.abs()
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    _, axis = plt.subplots(1, 1)
    axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
    axis.set_title(title)
    plt.tight_layout()


######################################################################
# 5. Run Model
# ------------
#
# Finally, we run the model and collect the separated sources; a sketch of
# optionally storing them as files in a directory follows the separation step.
#
# As a test song, we will be using A Classic Education by NightOwl from
# MedleyDB (Creative Commons BY-NC-SA 4.0). This song is also located in
# the `MUSDB18-HQ <https://zenodo.org/record/3338373>`__ dataset, within
# the ``train`` sources.
#
# To test with a different song, the variable names and URLs below can be
# changed, along with the parameters, to exercise the song separator in
# different ways.
#

# We download the audio file from our storage. Feel free to download another file and use audio from a specific path.
SAMPLE_SONG = download_asset("tutorial-assets/hdemucs_mix.wav")
waveform, sample_rate = torchaudio.load(SAMPLE_SONG)  # replace SAMPLE_SONG with desired path for different song
waveform = waveform.to(device)
mixture = waveform

# parameters
segment: int = 10
overlap = 0.1

print("Separating track")

ref = waveform.mean(0)
waveform = (waveform - ref.mean()) / ref.std()  # normalize with the mono reference statistics

sources = separate_sources(
    model,
    waveform[None],
    device=device,
    segment=segment,
    overlap=overlap,
)[0]
sources = sources * ref.std() + ref.mean()  # undo the normalization

sources_list = model.sources
sources = list(sources)

audios = dict(zip(sources_list, sources))
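
# A minimal sketch of saving each separated source to disk; this is optional,
# and the ``separated_sources`` directory name is an assumption, not part of
# the pipeline.
import os

output_dir = "separated_sources"
os.makedirs(output_dir, exist_ok=True)
for name, audio in audios.items():
    torchaudio.save(os.path.join(output_dir, f"{name}.wav"), audio.cpu(), sample_rate)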

######################################################################
# 5.1 Separate Track
# ^^^^^^^^^^^^^^^^^^
#
# The default set of pretrained weights separates the mixture into 4 sources:
# drums, bass, other, and vocals, in that order. They have been stored in the
# dict ``audios`` and can be accessed there. Each of the four sources has a
# separate cell below that will create the audio, plot the spectrogram, and
# calculate the SDR score. SDR is the signal-to-distortion ratio, essentially
# a measure of the "quality" of an audio track.
#

N_FFT = 4096
N_HOP = 4
stft = torchaudio.transforms.Spectrogram(
    n_fft=N_FFT,
    hop_length=N_HOP,
    power=None,
)


######################################################################
# 5.2 Audio Segmenting and Processing
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Below are the processing steps and the segmenting of 5 seconds of the tracks
# in order to feed them into the spectrogram and to calculate the respective
# SDR scores.
#


def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor, source: str):
    print(
        "SDR score is:",
        separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
    )
    plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
    return Audio(predicted_source, rate=sample_rate)


segment_start = 150
segment_end = 155

frame_start = segment_start * sample_rate
frame_end = segment_end * sample_rate

drums_original = download_asset("tutorial-assets/hdemucs_drums_segment.wav")
bass_original = download_asset("tutorial-assets/hdemucs_bass_segment.wav")
vocals_original = download_asset("tutorial-assets/hdemucs_vocals_segment.wav")
other_original = download_asset("tutorial-assets/hdemucs_other_segment.wav")

drums_spec = audios["drums"][:, frame_start:frame_end].cpu()
drums, sample_rate = torchaudio.load(drums_original)

bass_spec = audios["bass"][:, frame_start:frame_end].cpu()
bass, sample_rate = torchaudio.load(bass_original)

vocals_spec = audios["vocals"][:, frame_start:frame_end].cpu()
vocals, sample_rate = torchaudio.load(vocals_original)

other_spec = audios["other"][:, frame_start:frame_end].cpu()
other, sample_rate = torchaudio.load(other_original)

mix_spec = mixture[:, frame_start:frame_end].cpu()


######################################################################
# 5.3 Spectrograms and Audio
# ^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# In the next 5 cells, you can see the spectrograms with the respective
# audios. The audios can be clearly visualized using the spectrograms.
#
# The mixture clip comes from the original track, and the remaining
# tracks are the model output.
#

# Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate)

######################################################################
# Drums SDR, Spectrogram, and Audio
#

# Drums Clip
output_results(drums, drums_spec, "drums")

######################################################################
# Bass SDR, Spectrogram, and Audio
#

# Bass Clip
output_results(bass, bass_spec, "bass")

######################################################################
# Vocals SDR, Spectrogram, and Audio
#

# Vocals Audio
output_results(vocals, vocals_spec, "vocals")

######################################################################
# Other SDR, Spectrogram, and Audio
#

# Other Clip
output_results(other, other_spec, "other")

######################################################################
# Optionally, the full audios can be heard by running the next 5
# cells. They will take a bit longer to load, so simply uncomment
# the ``Audio`` lines for the respective track to produce the audio
# for the full song.
#

# Full Audio
# Audio(mixture, rate=sample_rate)

# Drums Audio
# Audio(audios["drums"], rate=sample_rate)

# Bass Audio
# Audio(audios["bass"], rate=sample_rate)

# Vocals Audio
# Audio(audios["vocals"], rate=sample_rate)

# Other Audio
# Audio(audios["other"], rate=sample_rate)