File: tacotron2_loss_impl.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (96 lines) | stat: -rw-r--r-- 3,491 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from pipeline_tacotron2.loss import Tacotron2Loss
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script


class Tacotron2LossInputMixin(TestBaseMixin):
    def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300):
        mel_specgram = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
        mel_specgram_postnet = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
        gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)
        truth_mel_specgram = torch.rand(n_batch, n_mel, max_mel_specgram_length, dtype=self.dtype, device=self.device)
        truth_gate_out = torch.rand(n_batch, dtype=self.dtype, device=self.device)

        truth_mel_specgram.requires_grad = False
        truth_gate_out.requires_grad = False

        return (
            mel_specgram,
            mel_specgram_postnet,
            gate_out,
            truth_mel_specgram,
            truth_gate_out,
        )


class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
    def test_tacotron2_loss_shape(self):
        """Validate the output shape of Tacotron2Loss."""
        n_batch = 16

        (
            mel_specgram,
            mel_specgram_postnet,
            gate_out,
            truth_mel_specgram,
            truth_gate_out,
        ) = self._get_inputs(n_batch=n_batch)

        mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()(
            (mel_specgram, mel_specgram_postnet, gate_out), (truth_mel_specgram, truth_gate_out)
        )

        self.assertEqual(mel_loss.size(), torch.Size([]))
        self.assertEqual(mel_postnet_loss.size(), torch.Size([]))
        self.assertEqual(gate_loss.size(), torch.Size([]))


class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin):
    def _assert_torchscript_consistency(self, fn, tensors):
        ts_func = torch_script(fn)

        output = fn(tensors[:3], tensors[3:])
        ts_output = ts_func(tensors[:3], tensors[3:])

        self.assertEqual(ts_output, output)

    def test_tacotron2_loss_torchscript_consistency(self):
        """Validate the torchscript consistency of Tacotron2Loss."""

        loss_fn = Tacotron2Loss()
        self._assert_torchscript_consistency(loss_fn, self._get_inputs())


class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin):
    def test_tacotron2_loss_gradcheck(self):
        """Performing gradient check on Tacotron2Loss."""
        (
            mel_specgram,
            mel_specgram_postnet,
            gate_out,
            truth_mel_specgram,
            truth_gate_out,
        ) = self._get_inputs()

        mel_specgram.requires_grad_(True)
        mel_specgram_postnet.requires_grad_(True)
        gate_out.requires_grad_(True)

        def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
            loss_fn = Tacotron2Loss()
            return loss_fn(
                (mel_specgram, mel_specgram_postnet, gate_out),
                (truth_mel_specgram, truth_gate_out),
            )

        gradcheck(
            _fn,
            (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
            fast_mode=True,
        )
        gradgradcheck(
            _fn,
            (mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out),
            fast_mode=True,
        )