File: losses.py

package info (click to toggle)
pytorch-audio 0.13.1-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 8,592 kB
  • sloc: python: 41,137; cpp: 8,016; sh: 3,538; makefile: 24
file content (111 lines) | stat: -rw-r--r-- 3,818 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import math

import torch
from torch import nn as nn
from torch.nn import functional as F


class LongCrossEntropyLoss(nn.Module):
    r"""Cross-entropy loss over per-timestep class logits.

    ``forward`` swaps dims 1 and 2 of ``output``, so the class axis (dim 2 on input)
    lands on dim 1, which is the layout :func:`torch.nn.functional.cross_entropy`
    expects. It then casts ``target`` to ``torch.long`` class indices and returns the
    mean cross-entropy (the default reduction).

    Inputs (``forward``):
        output (Tensor): Logits, presumably (n_batch x n_time x n_classes).
        target (Tensor): Class indices (n_batch x n_time); any dtype castable to long.

    Returns
        Tensor: scalar loss
    """

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        # cross_entropy wants (N, C, d1, ...) logits, so bring classes to dim 1.
        logits = torch.transpose(output, 1, 2)
        # Targets may arrive as floats; cross_entropy requires integer class indices.
        labels = target.to(torch.long)
        return F.cross_entropy(logits, labels)


class MoLLoss(nn.Module):
    r"""Discretized mixture of logistic distributions loss

    Adapted from wavenet vocoder
    (https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py)
    Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155)

    Args:
        num_classes (int): Number of quantization classes. Each target bin has
            half-width ``1 / (num_classes - 1)``.
        log_scale_min (float or None): Lower clamp applied to the predicted log scales.
            ``None`` means ``math.log(1e-14)``.
        reduce (bool): If True, return the mean negative log-likelihood over all
            positions (a scalar). If False, return the per-position negative
            log-likelihood with shape (n_batch x n_time x 1).

    Inputs (``forward``):
        y_hat (Tensor): Predicted mixture parameters (n_batch x n_time x 3 * num_mixtures),
            laid out along the last axis as ``[logit_probs | means | log_scales]``.
        y (Tensor): Target (n_batch x n_time) or (n_batch x n_time x 1). Values are
            presumably scaled to [-1, 1]; the edge bins are detected at +/-0.999.

    Returns
        Tensor: loss
    """

    def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
        super(MoLLoss, self).__init__()
        self.num_classes = num_classes
        self.log_scale_min = log_scale_min
        self.reduce = reduce

    def forward(self, y_hat, y):
        # Accept (n_batch, n_time) as well as the (n_batch, n_time, 1) layout.
        if y.dim() != 3 or y.size(-1) != 1:
            y = y.unsqueeze(-1)

        # Resolve the default locally rather than writing it back onto the module,
        # so forward() does not mutate the module's configuration.
        log_scale_min = math.log(1e-14) if self.log_scale_min is None else self.log_scale_min

        assert y_hat.dim() == 3, "y_hat must be 3-D (n_batch, n_time, 3 * num_mixtures)"
        assert y_hat.size(-1) % 3 == 0, "last dimension of y_hat must be divisible by 3"

        nr_mix = y_hat.size(-1) // 3

        # unpack parameters (n_batch, n_time, num_mixtures) x 3
        logit_probs = y_hat[:, :, :nr_mix]
        means = y_hat[:, :, nr_mix : 2 * nr_mix]
        log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)

        # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures)
        y = y.expand_as(means)

        half_bin = 1.0 / (self.num_classes - 1)
        centered_y = y - means
        inv_stdv = torch.exp(-log_scales)
        plus_in = inv_stdv * (centered_y + half_bin)
        cdf_plus = torch.sigmoid(plus_in)
        min_in = inv_stdv * (centered_y - half_bin)
        cdf_min = torch.sigmoid(min_in)

        # Lowest bin (y < -0.999): all mass below the upper bin edge.
        # Stable form of log(sigmoid(plus_in)).
        log_cdf_plus = plus_in - F.softplus(plus_in)

        # Highest bin (y > 0.999): all mass above the lower bin edge.
        # Stable form of log(1 - sigmoid(min_in)).
        log_one_minus_cdf_min = -F.softplus(min_in)

        # Interior bins: probability mass between the two bin edges.
        cdf_delta = cdf_plus - cdf_min

        mid_in = inv_stdv * centered_y
        # Log density of the logistic at the bin center, used when cdf_delta underflows.
        log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

        inner_inner_cond = (cdf_delta > 1e-5).float()

        # When cdf_delta is tiny, approximate log(mass) by log(pdf) + log(bin width),
        # where the bin width is 2 / (num_classes - 1).
        inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
            log_pdf_mid - math.log((self.num_classes - 1) / 2)
        )
        inner_cond = (y > 0.999).float()
        inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
        cond = (y < -0.999).float()
        log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

        # Weight each component by its mixture log-probability.
        log_probs = log_probs + F.log_softmax(logit_probs, -1)

        # Marginalize over mixture components (last axis) -> (n_batch, n_time).
        nll = -torch.logsumexp(log_probs, dim=-1)

        if self.reduce:
            return torch.mean(nll)
        return nll.unsqueeze(-1)


def _log_sum_exp(x):
    r"""Numerically stable log_sum_exp implementation that prevents overflow"""

    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))