import math
import torch
from torch import nn
from torch.nn import functional as F


class LongCrossEntropyLoss(nn.Module):
    r"""Cross-entropy loss for outputs of shape (n_batch, n_time, n_classes)
    and integer targets of shape (n_batch, n_time)."""

    def __init__(self):
        super(LongCrossEntropyLoss, self).__init__()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, output, target):
        # nn.CrossEntropyLoss expects class scores in dimension 1.
        output = output.transpose(1, 2)
        target = target.long()
        return self.criterion(output, target)
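
# A minimal usage sketch under assumed shapes (hypothetical values, not part of
# the original module):
#
#   loss_fn = LongCrossEntropyLoss()
#   logits = torch.randn(4, 100, 30)           # (n_batch, n_time, n_classes)
#   targets = torch.randint(0, 30, (4, 100))   # (n_batch, n_time)
#   loss = loss_fn(logits, targets)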


class MoLLoss(nn.Module):
    r"""Discretized mixture of logistic distributions loss.

    Adapted from the WaveNet vocoder
    (https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py).
    Explanation of the loss: https://github.com/Rayhane-mamah/Tacotron-2/issues/155

    Args:
        y_hat (Tensor): Predicted output (n_batch x n_time x 3 * num_mixtures),
            packing the mixture logits, means, and log scales along the last axis.
        y (Tensor): Target (n_batch x n_time), scaled to [-1, 1].
        num_classes (int): Number of discretized classes.
        log_scale_min (float): Minimum value allowed for the predicted log scales.
        reduce (bool): If True, return the mean loss over the minibatch;
            otherwise return the per-element losses.

    Returns:
        Tensor: loss
    """

    def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
        super(MoLLoss, self).__init__()
        self.num_classes = num_classes
        # Fall back to a very small floor for the log scales if none is given.
        self.log_scale_min = math.log(1e-14) if log_scale_min is None else log_scale_min
        self.reduce = reduce

    def forward(self, y_hat, y):
        # (n_batch, n_time) -> (n_batch, n_time, 1)
        y = y.unsqueeze(-1)

        assert y_hat.dim() == 3
        assert y_hat.size(-1) % 3 == 0
        nr_mix = y_hat.size(-1) // 3

        # unpack parameters: (n_batch, n_time, num_mixtures) x 3
        logit_probs = y_hat[:, :, :nr_mix]
        means = y_hat[:, :, nr_mix : 2 * nr_mix]
        log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=self.log_scale_min)

        # (n_batch, n_time, 1) -> (n_batch, n_time, num_mixtures)
        y = y.expand_as(means)
        centered_y = y - means
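
        # Evaluate the logistic CDF at the upper and lower edges of the
        # discretization bin around each target value; 1 / (num_classes - 1)
        # is half the bin width once targets are scaled to [-1, 1].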
        inv_stdv = torch.exp(-log_scales)
        plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1))
        cdf_plus = torch.sigmoid(plus_in)
        min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1))
        cdf_min = torch.sigmoid(min_in)

        # log probability for the left edge case where y is the minimum class
        # (0 before scaling); equivalent to torch.log(torch.sigmoid(plus_in))
        log_cdf_plus = plus_in - F.softplus(plus_in)

        # log probability for the right edge case where y is the maximum class
        # (num_classes - 1 before scaling); equivalent to torch.log(1 - torch.sigmoid(min_in))
        log_one_minus_cdf_min = -F.softplus(min_in)

        # probability for all other cases
        cdf_delta = cdf_plus - cdf_min

        mid_in = inv_stdv * centered_y
        # log probability in the center of the bin, to be used in extreme cases
        log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
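
        # Per-element selection of the appropriate log-probability:
        #   y < -0.999 -> left-edge term (all mass below the first bin edge)
        #   y >  0.999 -> right-edge term (all mass above the last bin edge)
        #   otherwise  -> mass inside the bin, falling back to the (scaled)
        #                 mid-bin PDF when cdf_delta underflows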
        inner_inner_cond = (cdf_delta > 1e-5).float()
        inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
            log_pdf_mid - math.log((self.num_classes - 1) / 2)
        )
        inner_cond = (y > 0.999).float()
        inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
        cond = (y < -0.999).float()
        log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
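
        # Add the log mixture weights; log-sum-exp over the mixture axis then
        # gives the total log-likelihood under the mixture.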
        log_probs = log_probs + F.log_softmax(logit_probs, -1)

        if self.reduce:
            return -torch.mean(_log_sum_exp(log_probs))
        else:
            return -_log_sum_exp(log_probs).unsqueeze(-1)
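

# A minimal usage sketch under assumed shapes (hypothetical values, not part of
# the original module): y_hat packs [logits, means, log_scales] along the last
# axis and y holds targets scaled to [-1, 1].
#
#   loss_fn = MoLLoss(num_classes=65536)
#   y_hat = torch.randn(4, 100, 30)            # 10 mixture components
#   y = torch.empty(4, 100).uniform_(-1, 1)
#   loss = loss_fn(y_hat, y)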

def _log_sum_exp(x):
    r"""Numerically stable log-sum-exp implementation that prevents overflow."""
    # Subtract the per-row maximum before exponentiating to avoid overflow,
    # then add it back outside the log.
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
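
# Sanity-check sketch (assumption: this matches PyTorch's built-in reduction
# over the last axis):
#
#   x = torch.randn(2, 5, 10)
#   assert torch.allclose(_log_sum_exp(x), torch.logsumexp(x, dim=-1))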