File: model.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (72 lines) | stat: -rw-r--r-- 3,023 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import ci_sdr
import lightning.pytorch as pl
import torch
from asteroid.losses.stoi import NegSTOILoss
from asteroid.masknn import TDConvNet
from torchaudio.transforms import InverseSpectrogram, PSD, SoudenMVDR, Spectrogram


class DNNBeamformer(torch.nn.Module):
    """Mask-based MVDR beamformer operating on multi-channel waveforms.

    The log-magnitude spectrogram of the reference channel is fed to a
    ``TDConvNet`` that is configured to emit two outputs, used here as a
    speech mask and a noise mask. Each mask is combined with the
    multi-channel STFT by ``PSD``; the two results drive ``SoudenMVDR``, and
    the beamformed STFT is converted back to a waveform with
    ``InverseSpectrogram``.

    Args:
        n_fft: FFT size shared by the STFT and the inverse STFT. The mask
            network is built with ``n_fft // 2 + 1`` input and output channels.
        hop_length: Hop length shared by the STFT and the inverse STFT.
        ref_channel: Index of the channel whose spectrogram is fed to the mask
            network; it is also passed to the beamformer as reference channel.
    """

    def __init__(self, n_fft: int = 1024, hop_length: int = 256, ref_channel: int = 0):
        super().__init__()
        # power=None keeps the complex-valued STFT instead of a power spectrogram.
        self.stft = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
        self.istft = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
        self.mask_net = TDConvNet(
            n_fft // 2 + 1,
            2,
            out_chan=n_fft // 2 + 1,
            causal=False,
            # Linear output here; forward() applies the ReLU itself.
            mask_act="linear",
            norm_type="gLN",
        )
        self.beamformer = SoudenMVDR()
        self.psd = PSD()
        self.ref_channel = ref_channel

    def forward(self, mixture: torch.Tensor) -> torch.Tensor:
        """Enhance a multi-channel mixture.

        Args:
            mixture: Multi-channel waveform. Its STFT must be 4-dimensional,
                i.e. the input is expected to be ``(batch, channel, time)``.

        Returns:
            The enhanced waveform returned by the inverse STFT, which is called
            with ``length=mixture.shape[-1]``.

        Raises:
            ValueError: If the STFT of ``mixture`` is not 4-dimensional.
        """
        spectrum = self.stft(mixture)  # (batch, channel, freq, time)
        # Explicit rank check: indexing the channel axis below would silently
        # pick the wrong dimension for an input without a batch axis.
        if spectrum.dim() != 4:
            raise ValueError(
                "Expected the STFT of `mixture` to have shape "
                "(batch, channel, freq, time), but got shape "
                f"{tuple(spectrum.shape)}; `mixture` must be (batch, channel, time)."
            )
        # The small offset keeps the log finite for zero-magnitude bins.
        input_feature = torch.log(spectrum[:, self.ref_channel].abs() + 1e-8)  # (batch, freq, time)
        # ReLU keeps both masks non-negative.
        mask = torch.nn.functional.relu(self.mask_net(input_feature))  # (batch, 2, freq, time)
        mask_speech = mask[:, 0]
        mask_noise = mask[:, 1]
        psd_speech = self.psd(spectrum, mask_speech)
        psd_noise = self.psd(spectrum, mask_noise)
        enhanced_stft = self.beamformer(spectrum, psd_speech, psd_noise, self.ref_channel)
        enhanced_waveform = self.istft(enhanced_stft, length=mixture.shape[-1])
        return enhanced_waveform


class DNNBeamformerLightningModule(pl.LightningModule):
    """Lightning wrapper that trains an enhancement model with CI-SDR and STOI losses.

    The objective is ``ci_sdr_loss + 10 * neg_stoi_loss``, both averaged over
    the batch. Training and validation compute the same quantities and differ
    only in the prefix (``train/`` or ``val/``) of the logged metric names.

    Args:
        model: Network called as ``model(mixture)`` to produce the estimate
            that is compared against the clean target.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model
        # 16000 is presumably the sample rate in Hz expected by NegSTOILoss;
        # verify against the dataset used for training.
        self.loss_stoi = NegSTOILoss(16000)

    def _shared_step(self, batch, stage: str) -> torch.Tensor:
        """Compute and log the losses for one batch.

        Args:
            batch: A ``(mixture, clean)`` pair.
            stage: Prefix for the logged metric names, e.g. ``"train"``.

        Returns:
            The combined loss ``loss_cisdr + loss_stoi * 10``.
        """
        mixture, clean = batch
        estimate = self.model(mixture)
        loss_cisdr = ci_sdr.pt.ci_sdr_loss(estimate, clean, compute_permutation=False, filter_length=512).mean()
        loss_stoi = self.loss_stoi(estimate, clean).mean()
        loss = loss_cisdr + loss_stoi * 10
        # Plain Python floats are logged; the returned loss stays a tensor.
        self.log(f"{stage}/loss_cisdr", loss_cisdr.item())
        self.log(f"{stage}/loss_stoi", loss_stoi.item())
        self.log(f"{stage}/loss", loss.item())
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log it under ``train/``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log it under ``val/``."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Create an Adam optimizer over the wrapped model's parameters."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=1e-8)
        return {
            "optimizer": optimizer,
        }