import time
from collections import namedtuple
from typing import Tuple
import torch
import torch.distributed as dist
from utils import dist_utils, metrics
_LG = dist_utils.getLogger(__name__)
# Pair of separation-quality numbers reported during training/evaluation.
# NOTE: the typename "SNR" (visible in repr()) is kept as-is for compatibility.
Metric = namedtuple("SNR", ["si_snri", "sdri"])


def _format_metric(self):
    # Fixed-width scientific notation keeps the log columns aligned.
    return f"SI-SNRi: {self.si_snri:10.3e}, SDRi: {self.sdri:10.3e}"


Metric.__str__ = _format_metric
def si_sdr_improvement(
    estimate: torch.Tensor, reference: torch.Tensor, mix: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the improvement of scale-invariant SDR (SI-SNRi) and bare SDR (SDRi).

    Args:
        estimate (torch.Tensor): Estimated source signals.
            Shape: [batch, speakers, time frame]
        reference (torch.Tensor): Reference (original) source signals.
            Shape: [batch, speakers, time frame]
        mix (torch.Tensor): Mixed source signals, from which the estimated signals were generated.
            Shape: [batch, speakers == 1, time frame]
        mask (torch.Tensor): Mask to indicate padded value (0) or valid value (1).
            Shape: [batch, 1, time frame]

    Returns:
        torch.Tensor: SI-SDR improvement (SI-SNRi). Shape: [batch, ]
        torch.Tensor: SDR improvement (SDRi). Shape: [batch, ]

    References:
        - Conv-TasNet: Surpassing Ideal Time--Frequency Magnitude Masking for Speech Separation
          Luo, Yi and Mesgarani, Nima
          https://arxiv.org/abs/1809.07454
    """
    # Plain SDRi is only reported, never back-propagated, so skip autograd bookkeeping.
    with torch.no_grad():
        sdri = metrics.sdri(estimate, reference, mix, mask=mask)

    # Zero-mean each signal along the time axis; this makes the subsequent SDR
    # computation scale/offset-invariant (the "SI" part). Kept outside no_grad
    # because the SI-SNRi result is used as the training loss by callers.
    estimate = estimate - estimate.mean(dim=2, keepdim=True)
    reference = reference - reference.mean(dim=2, keepdim=True)
    mix = mix - mix.mean(dim=2, keepdim=True)

    si_sdri = metrics.sdri(estimate, reference, mix, mask=mask)
    return si_sdri, sdri
class OccasionalLogger:
    """Rate-limited logger: emits at most once per time window or progress step.

    A message is logged when forced, when ``time_interval`` seconds have passed
    since the last emission, or when training progress advanced by more than
    ``progress_interval`` (a fraction in [0, 1]).
    """

    def __init__(self, time_interval=180, progress_interval=0.1):
        # Thresholds between consecutive log lines.
        self.time_interval = time_interval
        self.progress_interval = progress_interval
        # Timestamps/progress of the previous emission; zeros guarantee
        # that the very first call logs.
        self.last_time = 0.0
        self.last_progress = 0.0

    def log(self, metric, progress, force=False):
        """Log ``metric`` with ``progress`` (0..1) if any trigger fires."""
        now = time.monotonic()
        time_elapsed = now > self.last_time + self.time_interval
        progress_made = progress > self.last_progress + self.progress_interval
        if force or time_elapsed or progress_made:
            self.last_time = now
            self.last_progress = progress
            _LG.info_on_master("train: %s [%3d%%]", metric, 100 * progress)
class Trainer:
    """Runs the train / validate / evaluate loops for a separation model.

    The training objective is the negated mean SI-SNRi; SDRi is tracked
    alongside for reporting only.
    """

    def __init__(
        self,
        model,
        optimizer,
        train_loader,
        valid_loader,
        eval_loader,
        grad_clip,
        device,
        *,
        debug,
    ):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.eval_loader = eval_loader
        self.grad_clip = grad_clip
        self.device = device
        # When True, each loop processes a single batch (fast smoke test).
        self.debug = debug

    def train_one_epoch(self):
        """One full pass over ``train_loader`` with gradient updates."""
        self.model.train()
        occasional = OccasionalLogger()
        num_batches = len(self.train_loader)
        for step, batch in enumerate(self.train_loader, start=1):
            mix = batch.mix.to(self.device)
            src = batch.src.to(self.device)
            mask = batch.mask.to(self.device)

            estimate = self.model(mix)
            si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask)
            si_snri = si_snri.mean()
            sdri = sdri.mean()

            # Maximize SI-SNRi by minimizing its negation.
            loss = -si_snri
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip, norm_type=2.0)
            self.optimizer.step()

            occasional.log(
                Metric(si_snri.item(), sdri.item()),
                progress=step / num_batches,
                force=step == num_batches,
            )
            if self.debug:
                break

    def evaluate(self):
        """Score the model on the evaluation set (no gradients)."""
        with torch.no_grad():
            return self._test(self.eval_loader)

    def validate(self):
        """Score the model on the validation set (no gradients)."""
        with torch.no_grad():
            return self._test(self.valid_loader)

    def _test(self, loader):
        """Average SI-SNRi / SDRi over ``loader``, reduced across all ranks."""
        self.model.eval()
        total_si_snri = torch.zeros(1, dtype=torch.float32, device=self.device)
        total_sdri = torch.zeros(1, dtype=torch.float32, device=self.device)
        for batch in loader:
            mix = batch.mix.to(self.device)
            src = batch.src.to(self.device)
            mask = batch.mask.to(self.device)

            estimate = self.model(mix)
            si_snri, sdri = si_sdr_improvement(estimate, src, mix, mask)
            total_si_snri += si_snri.sum()
            total_sdri += sdri.sum()
            if self.debug:
                break

        # Sum the per-rank totals, then normalize by the full dataset size.
        dist.all_reduce(total_si_snri, dist.ReduceOp.SUM)
        dist.all_reduce(total_sdri, dist.ReduceOp.SUM)
        num_samples = len(loader.dataset)
        return Metric(total_si_snri.item() / num_samples, total_sdri.item() / num_samples)