1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
|
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
def hubert_loss(
logit_m: Optional[Tensor],
logit_u: Optional[Tensor],
feature_penalty: Tensor,
label: Optional[Tensor] = None,
masked_weight: float = 1.0,
unmasked_weight: float = 0.0,
feature_weight: float = 10.0,
reduction: str = "sum",
) -> Tuple[Tensor, float]:
"""Compute the cross-entropy loss on HuBERT masked and non-masked logits.
Args:
logit_m (Tensor or None): The masked logit Tensor of dimension `(masked_frames, final_dim)`.
logit_u (Tensor or None): The non-masked logit Tensor of dimension `(unmasked_frames, final_dim)`.
feature_penalty (Tensor): The feature mean value for additional penalty loss.
masked_weight (float, optional): The weight for masked cross-entropy loss (Default: ``1.0``).
unmasked_weight (float, optional): The weight for non-masked cross-entropy loss (Default: ``0.0``).
feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``).
reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``).
Returns:
(Tensor, float)
Tensor: The desired loss Tensor.
float: Number of frames used in loss computation.
"""
num_frame = 0.0
loss = 0.0
if logit_m is not None:
target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction)
loss += loss_m * masked_weight
num_frame += logit_m.shape[0]
if logit_u is not None:
target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_m.device)
loss_u = F.cross_entropy(logit_u, target_u, reduction=reduction)
loss += loss_u * unmasked_weight
num_frame += logit_u.shape[0]
loss += feature_penalty * feature_weight * num_frame
return loss, num_frame
|