File: _hubert_loss.py

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def hubert_loss(
    logit_m: Optional[Tensor],
    logit_u: Optional[Tensor],
    feature_penalty: Tensor,
    label: Optional[Tensor] = None,
    masked_weight: float = 1.0,
    unmasked_weight: float = 0.0,
    feature_weight: float = 10.0,
    reduction: str = "sum",
) -> Tuple[Tensor, float]:
    """Compute the cross-entropy loss on HuBERT masked and non-masked logits.

    Args:
        logit_m (Tensor or None): The masked logit Tensor of dimension `(masked_frames, final_dim)`.
        logit_u (Tensor or None): The non-masked logit Tensor of dimension `(unmasked_frames, final_dim)`.
        feature_penalty (Tensor): The feature mean value for additional penalty loss.
        label (Tensor or None, optional): Not used in the loss computation. (Default: ``None``)
        masked_weight (float, optional): The weight for masked cross-entropy loss (Default: ``1.0``).
        unmasked_weight (float, optional): The weight for non-masked cross-entropy loss (Default: ``0.0``).
        feature_weight (float, optional): The weight for feature penalty loss (Default: ``10.0``).
        reduction (str, optional): The reduction method for cross-entropy loss (Default: ``"sum"``).

    Returns:
        (Tensor, float)
        Tensor: The desired loss Tensor.
        float: Number of frames used in loss computation.
    """
    num_frame = 0.0
    loss = 0.0
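    # Cross-entropy targets are all zeros: by construction, the logit at
    # index 0 is treated as the positive (correct) class for every frame.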
    if logit_m is not None:
        target_m = torch.zeros(logit_m.shape[0], dtype=torch.long, device=logit_m.device)
        loss_m = F.cross_entropy(logit_m, target_m, reduction=reduction)
        loss += loss_m * masked_weight
        num_frame += logit_m.shape[0]
    if logit_u is not None:
        target_u = torch.zeros(logit_u.shape[0], dtype=torch.long, device=logit_u.device)
        loss_u = F.cross_entropy(logit_u, target_u, reduction=reduction)
        loss += loss_u * unmasked_weight
        num_frame += logit_u.shape[0]
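    # Scale the feature penalty by the number of frames so it stays on the
    # same scale as the summed cross-entropy terms.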
    loss += feature_penalty * feature_weight * num_frame
    return loss, num_frame
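

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The shapes and
# values below are made up for demonstration: `final_dim` stands in for the
# number of candidate classes per frame, with the correct class at index 0,
# and `feature_penalty` is a dummy stand-in for the feature L2 penalty.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    num_masked, num_unmasked, final_dim = 8, 24, 100

    # Dummy logits for masked and unmasked frames.
    logit_m = torch.randn(num_masked, final_dim)
    logit_u = torch.randn(num_unmasked, final_dim)

    # Dummy scalar feature penalty (e.g. mean of squared intermediate features).
    feature_penalty = torch.randn(10, 16).pow(2).mean()

    loss, num_frame = hubert_loss(logit_m, logit_u, feature_penalty)
    print(f"loss={loss.item():.4f}, frames={num_frame}")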