import math
from typing import Optional
from itertools import permutations

import torch


def sdr(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    epsilon: float = 1e-8,
) -> torch.Tensor:
"""Computes source-to-distortion ratio.
1. scale the reference signal with power(s_est * s_ref) / powr(s_ref * s_ref)
2. compute SNR between adjusted estimate and reference.
Args:
estimate (torch.Tensor): Estimtaed signal.
Shape: [batch, speakers (can be 1), time frame]
reference (torch.Tensor): Reference signal.
Shape: [batch, speakers, time frame]
mask (Optional[torch.Tensor]): Binary mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
epsilon (float): constant value used to stabilize division.
Returns:
torch.Tensor: scale-invariant source-to-distortion ratio.
Shape: [batch, speaker]
References:
- Single-channel multi-speaker separation using deep clustering
Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
Notes:
This function is tested to produce the exact same result as
https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L34-L56
"""
    # Project the estimate onto the reference and rescale the reference,
    # which makes the resulting ratio invariant to the scale of the estimate.
    reference_pow = reference.pow(2).mean(dim=2, keepdim=True)
    mix_pow = (estimate * reference).mean(dim=2, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    reference = scale * reference
    error = estimate - reference

    reference_pow = reference.pow(2)
    error_pow = error.pow(2)

    if mask is None:
        reference_pow = reference_pow.mean(dim=2)
        error_pow = error_pow.mean(dim=2)
    else:
        # Average only over the valid (unpadded) time frames.
        denom = mask.sum(dim=2)
        reference_pow = (mask * reference_pow).sum(dim=2) / denom
        error_pow = (mask * error_pow).sum(dim=2) / denom
    return 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
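

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal demonstration of ``sdr`` on synthetic data. The shapes, the helper
# name ``_example_sdr``, and the expected magnitudes noted in the comments are
# assumptions for illustration only.
def _example_sdr() -> None:
    batch, speakers, frames = 4, 2, 16000
    reference = torch.randn(batch, speakers, frames)
    mask = torch.ones(batch, 1, frames)  # all time frames are valid (no padding)

    # A rescaled copy of the reference: the projection step makes the metric
    # insensitive to scaling, so the result is very large.
    print(sdr(0.5 * reference, reference, mask=mask))

    # Uncorrelated noise: the result is strongly negative.
    print(sdr(torch.randn(batch, speakers, frames), reference, mask=mask))

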
class PIT(torch.nn.Module):
"""Applies utterance-level speaker permutation
Computes the maxium possible value of the given utility function
over the permutations of the speakers.
Args:
utility_func (function):
Function that computes the utility (opposite of loss) with signature of
(extimate: torch.Tensor, reference: torch.Tensor) -> torch.Tensor
where input Tensors are shape of [batch, speakers, frame] and
the output Tensor is shape of [batch, speakers].
References:
- Multi-talker Speech Separation with Utterance-level Permutation Invariant Training of
Deep Recurrent Neural Networks
Morten Kolbæk, Dong Yu, Zheng-Hua Tan and Jesper Jensen
https://arxiv.org/abs/1703.06284
"""

    def __init__(self, utility_func):
        super().__init__()
        self.utility_func = utility_func

    def forward(
        self,
        estimate: torch.Tensor,
        reference: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        epsilon: float = 1e-8,
    ) -> torch.Tensor:
"""Compute utterance-level PIT Loss
Args:
estimate (torch.Tensor): Estimated source signals.
Shape: [bacth, speakers, time frame]
reference (torch.Tensor): Reference (original) source signals.
Shape: [batch, speakers, time frame]
mask (Optional[torch.Tensor]): Binary mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
epsilon (float): constant value used to stabilize division.
Returns:
torch.Tensor: Maximum criterion over the speaker permutation.
Shape: [batch, ]
"""
        assert estimate.shape == reference.shape

        batch_size, num_speakers = reference.shape[:2]
        num_permute = math.factorial(num_speakers)

        util_mat = torch.zeros(
            batch_size, num_permute, dtype=estimate.dtype, device=estimate.device
        )
        # Evaluate the utility for every permutation of the reference speakers,
        # then keep the best one for each item in the batch.
        for i, idx in enumerate(permutations(range(num_speakers))):
            util = self.utility_func(estimate, reference[:, idx, :], mask=mask, epsilon=epsilon)
            util_mat[:, i] = util.mean(dim=1)  # take the average over speaker dimension
        return util_mat.max(dim=1).values
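

# --- Illustrative usage sketch (not part of the original module) ---
# Wrapping ``sdr`` in ``PIT`` makes the score invariant to the ordering of the
# reference speakers. The tensors and the helper name ``_example_pit`` are
# assumptions for illustration only.
def _example_pit() -> None:
    pit_sdr = PIT(utility_func=sdr)
    estimate = torch.randn(4, 2, 16000)
    reference = torch.randn(4, 2, 16000)
    # Swapping the reference speaker channels only reorders the candidate
    # permutations, so the maximum (and hence the output) is unchanged.
    swapped = reference[:, [1, 0], :]
    assert torch.allclose(pit_sdr(estimate, reference), pit_sdr(estimate, swapped))

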
_sdr_pit = PIT(utility_func=sdr)


def sdr_pit(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    epsilon: float = 1e-8,
) -> torch.Tensor:
"""Computes scale-invariant source-to-distortion ratio.
1. adjust both estimate and reference to have 0-mean
2. scale the reference signal with power(s_est * s_ref) / powr(s_ref * s_ref)
3. compute SNR between adjusted estimate and reference.
Args:
estimate (torch.Tensor): Estimtaed signal.
Shape: [batch, speakers (can be 1), time frame]
reference (torch.Tensor): Reference signal.
Shape: [batch, speakers, time frame]
mask (Optional[torch.Tensor]): Binary mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
epsilon (float): constant value used to stabilize division.
Returns:
torch.Tensor: scale-invariant source-to-distortion ratio.
Shape: [batch, speaker]
References:
- Single-channel multi-speaker separation using deep clustering
Y. Isik, J. Le Roux, Z. Chen, S. Watanabe, and J. R. Hershey,
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
Notes:
This function is tested to produce the exact same result as the reference implementation,
*when the inputs have 0-mean*
https://github.com/naplab/Conv-TasNet/blob/e66d82a8f956a69749ec8a4ae382217faa097c5c/utility/sdr.py#L107-L153
"""
    return _sdr_pit(estimate, reference, mask, epsilon)
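

# --- Illustrative usage sketch (not part of the original module) ---
# ``sdr_pit`` scores the estimate against the best-matching speaker ordering,
# which matters because separation models emit speakers in arbitrary order.
# The tensors and the helper name ``_example_sdr_pit`` are assumptions for
# illustration only.
def _example_sdr_pit() -> None:
    reference = torch.randn(4, 2, 16000)
    # A near-perfect estimate whose speaker channels come out swapped.
    estimate = reference[:, [1, 0], :].clone()
    score = sdr_pit(estimate, reference)  # Shape: [batch, ]
    print(score)  # large positive values despite the channel swap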


def sdri(
    estimate: torch.Tensor,
    reference: torch.Tensor,
    mix: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    epsilon: float = 1e-8,
) -> torch.Tensor:
"""Compute the improvement of SDR (SDRi).
This function compute how much SDR is improved if the estimation is changed from
the original mixture signal to the actual estimated source signals. That is,
``SDR(estimate, reference) - SDR(mix, reference)``.
For computing ``SDR(estimate, reference)``, PIT (permutation invariant training) is applied,
so that best combination of sources between the reference signals and the esimate signals
are picked.
Args:
estimate (torch.Tensor): Estimated source signals.
Shape: [batch, speakers, time frame]
reference (torch.Tensor): Reference (original) source signals.
Shape: [batch, speakers, time frame]
mix (torch.Tensor): Mixed souce signals, from which the setimated signals were generated.
Shape: [batch, speakers == 1, time frame]
mask (Optional[torch.Tensor]): Binary mask to indicate padded value (0) or valid value (1).
Shape: [batch, 1, time frame]
epsilon (float): constant value used to stabilize division.
Returns:
torch.Tensor: Improved SDR. Shape: [batch, ]
References:
- Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
Luo, Yi and Mesgarani, Nima
https://arxiv.org/abs/1809.07454
"""
    sdr_ = sdr_pit(estimate, reference, mask=mask, epsilon=epsilon)  # Shape: [batch, ]
    base_sdr = sdr(mix, reference, mask=mask, epsilon=epsilon)  # Shape: [batch, speaker]
    return (sdr_.unsqueeze(1) - base_sdr).mean(dim=1)
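

# --- Illustrative usage sketch (not part of the original module) ---
# ``sdri`` measures how much an estimate improves over using the raw mixture
# itself as the estimate. The synthetic data and the helper name
# ``_example_sdri`` are assumptions for illustration only.
def _example_sdri() -> None:
    reference = torch.randn(4, 2, 16000)
    mix = reference.sum(dim=1, keepdim=True)  # Shape: [batch, 1, time frame]
    # Pretend the model recovered the sources well, up to a channel swap.
    estimate = reference[:, [1, 0], :] + 0.01 * torch.randn(4, 2, 16000)
    print(sdri(estimate, reference, mix))  # positive: better than the raw mixture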