File: hdemucs_test_impl.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (160 lines) | stat: -rw-r--r-- 5,864 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import itertools
from typing import List

import torch
from parameterized import parameterized
from torchaudio.models._hdemucs import _HDecLayer, _HEncLayer, HDemucs, hdemucs_high, hdemucs_low
from torchaudio_unittest.common_utils import skipIfNoModule, TestBaseMixin, TorchaudioTestCase


def _get_hdemucs_model(sources: List[str], n_fft: int = 4096, depth: int = 6):
    """Construct an ``HDemucs`` model for the requested sources.

    Args:
        sources: Source names handed to ``HDemucs`` as its first argument.
        n_fft: Forwarded to ``HDemucs`` as the ``nfft`` keyword.
        depth: Forwarded to ``HDemucs`` as the ``depth`` keyword.

    Returns:
        The newly constructed ``HDemucs`` instance.
    """
    # The helper spells the FFT size ``n_fft`` while HDemucs expects ``nfft``.
    model_kwargs = {"nfft": n_fft, "depth": depth}
    return HDemucs(sources, **model_kwargs)


def _get_inputs(sample_rate: int, device: torch.device, batch_size: int = 1, duration: int = 10, channels: int = 2):
    sample = torch.rand(batch_size, channels, duration * sample_rate, dtype=torch.float32, device=device)
    return sample


# Source-name combinations used to parameterize the tests in this module.
# Each entry is a one-element tuple wrapping a list of source names, i.e. the
# argument tuple for a test method that takes a single ``sources`` parameter.
SOURCE_OPTIONS = [
    (["bass", "drums", "other", "vocals"],),
    (["bass", "drums", "other"],),
    (["bass", "vocals"],),
    (["vocals"],),
]

# Reusable decorator: applied below as ``@SOURCES_OUTPUT_CONFIG`` to test methods
# that accept one ``sources`` argument.
SOURCES_OUTPUT_CONFIG = parameterized.expand(SOURCE_OPTIONS)


class HDemucsTests(TestBaseMixin):
    r"""Output-shape checks for ``HDemucs`` and its encoder/decoder layers.

    The tests read ``self.device`` and ``self.dtype``; neither is defined in this
    class, so they are presumably provided by the concrete test case that mixes
    this class in.
    """

    @parameterized.expand(
        [
            # Every SOURCE_OPTIONS entry is a 1-tuple wrapping the source list. Unwrap it
            # so the test receives the list itself; otherwise ``len(sources)`` is always 1
            # and the different source counts are never exercised.
            (sources, nfft_bundle)
            for (sources,), nfft_bundle in itertools.product(SOURCE_OPTIONS, [(1024, 5), (2048, 6), (4096, 6)])
        ]
    )
    def test_hdemucs_output_shape(self, sources, nfft_bundle):
        r"""Feed tensors with specific shape to HDemucs and validate
        that it outputs a tensor with the expected shape.

        Args:
            sources: List of source names the model is built for.
            nfft_bundle: ``(nfft, depth)`` pair used to build the model.
        """
        duration = 10
        channels = 2
        batch_size = 1
        sample_rate = 44100
        nfft, depth = nfft_bundle

        model = _get_hdemucs_model(sources, nfft, depth).to(self.device).eval()
        inputs = _get_inputs(sample_rate, self.device, batch_size, duration, channels)

        split_sample = model(inputs)

        # One separated waveform per source, each with the input's channel count and length.
        assert split_sample.shape == (batch_size, len(sources), channels, duration * sample_rate)

    def test_encoder_output_shape_frequency(self):
        r"""Feed tensors with specific shape to HDemucs Encoder and validate
        that it outputs a tensor with the expected shape for frequency domain.
        """
        batch_size = 1
        chin, chout = 4, 48
        f_bins = 2048
        t = 800
        stride = 4

        model = _HEncLayer(chin, chout).to(self.device).eval()

        x = torch.rand(batch_size, chin, f_bins, t, device=self.device, dtype=self.dtype)
        out = model(x)

        # Integer division keeps the expected dimension an int rather than a float.
        assert out.size() == (batch_size, chout, f_bins // stride, t)

    def test_decoder_output_shape_frequency(self):
        r"""Feed tensors with specific shape to HDemucs Decoder and validate
        that it outputs tensors with the expected shapes for frequency domain.
        """
        batch_size = 1
        chin, chout = 96, 48
        f_bins = 128
        t = 800
        stride = 4

        model = _HDecLayer(chin, chout).to(self.device).eval()

        x = torch.rand(batch_size, chin, f_bins, t, device=self.device, dtype=self.dtype)
        skip = torch.rand(batch_size, chin, f_bins, t, device=self.device, dtype=self.dtype)
        z, y = model(x, skip, t)

        # The first output is expected to grow by ``stride`` along the frequency axis,
        # while the second output is expected to keep the input's shape.
        assert z.size() == (batch_size, chout, f_bins * stride, t)
        assert y.size() == (batch_size, chin, f_bins, t)

    def test_encoder_output_shape_time(self):
        r"""Feed tensors with specific shape to HDemucs Encoder and validate
        that it outputs a tensor with the expected shape for time domain.
        """
        batch_size = 1
        chin, chout = 4, 48
        t = 800
        stride = 4

        model = _HEncLayer(chin, chout, freq=False).to(self.device).eval()

        x = torch.rand(batch_size, chin, t, device=self.device, dtype=self.dtype)
        out = model(x)

        # Integer division keeps the expected dimension an int rather than a float.
        assert out.size() == (batch_size, chout, t // stride)

    def test_decoder_output_shape_time(self):
        r"""Feed tensors with specific shape to HDemucs Decoder and validate
        that it outputs tensors with the expected shapes for time domain.
        """
        batch_size = 1
        chin, chout = 96, 48
        t = 800
        stride = 4

        model = _HDecLayer(chin, chout, freq=False).to(self.device).eval()

        x = torch.rand(batch_size, chin, t, device=self.device, dtype=self.dtype)
        skip = torch.rand(batch_size, chin, t, device=self.device, dtype=self.dtype)
        z, y = model(x, skip, t * stride)

        # The first output is expected to grow by ``stride`` along the time axis,
        # while the second output is expected to keep the input's shape.
        assert z.size() == (batch_size, chout, t * stride)
        assert y.size() == (batch_size, chin, t)


@skipIfNoModule("demucs")
class CompareHDemucsOriginal(TorchaudioTestCase):
    """Compare torchaudio's HDemucs factory functions with the ``demucs`` package.

    Each test builds one model through a torchaudio factory function and one
    through ``demucs.hdemucs.HDemucs``, both after seeding the RNG with 0, and
    requires the two models to produce equal outputs for the same input.
    """

    def _get_original_model(self, sources: List[str], nfft: int, depth: int):
        """Build the reference ``HDemucs`` model from the ``demucs`` package."""
        # Imported inside the method; the class is skipped when demucs is unavailable.
        from demucs import hdemucs as reference_module

        return reference_module.HDemucs(sources, nfft=nfft, depth=depth)

    def _assert_equal_models(self, factory_hdemucs, depth, nfft, sample_rate, sources):
        """Assert that ``factory_hdemucs`` and a freshly built reference model agree.

        Args:
            factory_hdemucs: Model produced by a torchaudio factory function.
            depth: ``depth`` used to build the reference model.
            nfft: ``nfft`` used to build the reference model.
            sample_rate: Sample rate used to size the random input.
            sources: Source names used to build the reference model.
        """
        # Seed before constructing the reference model, mirroring the seeding the
        # callers perform before constructing the factory model.
        torch.random.manual_seed(0)
        reference = self._get_original_model(sources, nfft, depth)
        reference = reference.to(self.device).eval()

        waveform = _get_inputs(sample_rate=sample_rate, device=self.device)
        actual = factory_hdemucs(waveform)
        expected = reference(waveform)
        self.assertEqual(expected, actual)

    @SOURCES_OUTPUT_CONFIG
    def test_import_recreate_low_model(self, sources):
        """``hdemucs_low`` must match the reference model built with nfft=1024, depth=5."""
        torch.random.manual_seed(0)
        factory_hdemucs = hdemucs_low(sources).to(self.device).eval()
        self._assert_equal_models(
            factory_hdemucs,
            depth=5,
            nfft=1024,
            sample_rate=8000,
            sources=sources,
        )

    @SOURCES_OUTPUT_CONFIG
    def test_import_recreate_high_model(self, sources):
        """``hdemucs_high`` must match the reference model built with nfft=4096, depth=6."""
        torch.random.manual_seed(0)
        factory_hdemucs = hdemucs_high(sources).to(self.device).eval()
        self._assert_equal_models(
            factory_hdemucs,
            depth=6,
            nfft=4096,
            sample_rate=44100,
            sources=sources,
        )