File: eval.py

Package: pytorch-audio 2.6.0-1
from argparse import ArgumentParser
from pathlib import Path

import mir_eval
import torch
from lightning_train import _get_dataloader, _get_model, sisdri_metric


def _eval(model, data_loader, device):
    # Running sums of [SDRi, Si-SDRi, SIRi, SARi] over the evaluation set.
    results = torch.zeros(4)
    with torch.no_grad():
        for batch in data_loader:
            mix, src, mask = batch
            mix, src, mask = mix.to(device), src.to(device), mask.to(device)
            est = model(mix)
            sisdri = sisdri_metric(est, src, mix, mask)
            src = src.cpu().detach().numpy()
            est = est.cpu().detach().numpy()
            # Tile the single-channel mixture so it can be scored against every source.
            mix = mix.repeat(1, src.shape[1], 1).cpu().detach().numpy()
            # bss_eval_sources returns per-source SDR/SIR/SAR (and the best permutation)
            # for the model estimates and for the unprocessed mixture; their difference
            # is the improvement attributable to the model.
            sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(src[0], est[0])
            sdr_mix, sir_mix, sar_mix, _ = mir_eval.separation.bss_eval_sources(src[0], mix[0])
            results += torch.tensor(
                [sdr.mean() - sdr_mix.mean(), sisdri, sir.mean() - sir_mix.mean(), sar.mean() - sar_mix.mean()]
            )
    results /= len(data_loader)
    print("SDR improvement: ", results[0].item())
    print("Si-SDR improvement: ", results[1].item())
    print("SIR improvement: ", results[2].item())
    print("SAR improvement: ", results[3].item())

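# ----------------------------------------------------------------------------
# NOTE: ``sisdri_metric`` is imported from lightning_train.py and is not shown
# in this file.  The helper below is only a minimal sketch of the usual
# scale-invariant SDR improvement (Si-SDRi) computation such a metric performs.
# The function name, the (batch, sources, time) shape convention, the epsilon
# value and the omission of the padding mask are assumptions for illustration,
# not the actual implementation.
# ----------------------------------------------------------------------------
def _si_sdri_sketch(est, src, mix, eps=1e-8):
    """Mean Si-SDR(est, src) - Si-SDR(mix, src) for one batch (sketch only)."""

    def _si_sdr(estimate, reference):
        # Zero-mean both signals along the time axis.
        estimate = estimate - estimate.mean(dim=-1, keepdim=True)
        reference = reference - reference.mean(dim=-1, keepdim=True)
        # Scale-invariant projection of the estimate onto the reference.
        scale = (estimate * reference).sum(dim=-1, keepdim=True) / (
            reference.pow(2).sum(dim=-1, keepdim=True) + eps
        )
        target = scale * reference
        noise = estimate - target
        return 10 * torch.log10(
            target.pow(2).sum(dim=-1) / (noise.pow(2).sum(dim=-1) + eps) + eps
        )

    # Broadcast the single-channel mixture against every reference source.
    return (_si_sdr(est, src) - _si_sdr(mix.expand_as(src), src)).mean()
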

def cli_main():
    parser = ArgumentParser()
    parser.add_argument("--dataset", default="librimix", type=str, choices=["wsj0mix", "librimix"])
    parser.add_argument(
        "--root-dir",
        type=Path,
        help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
    )
    parser.add_argument(
        "--librimix-tr-split",
        default="train-360",
        choices=["train-360", "train-100"],
        help="The training partition of librimix dataset. (default: ``train-360``)",
    )
    parser.add_argument(
        "--librimix-task",
        default="sep_clean",
        type=str,
        choices=["sep_clean", "sep_noisy", "enh_single", "enh_both"],
        help="The task to perform (separation or enhancement, noisy or clean). (default: ``sep_clean``)",
    )
    parser.add_argument(
        "--num-speakers", default=2, type=int, help="The number of speakers in the mixture. (default: 2)"
    )
    parser.add_argument(
        "--sample-rate",
        default=8000,
        type=int,
        help="Sample rate of audio files in the given dataset. (default: 8000)",
    )
    parser.add_argument(
        "--exp-dir", default=Path("./exp"), type=Path, help="The directory to save checkpoints and logs."
    )
    parser.add_argument("--gpu-device", default=-1, type=int, help="The gpu device for model inference. (default: -1)")

    args = parser.parse_args()

    model = _get_model(num_sources=args.num_speakers)
    state_dict = torch.load(args.exp_dir / "best_model.pth", map_location="cpu")  # load on CPU first; moved to the target device below
    model.load_state_dict(state_dict)

    if args.gpu_device != -1:
        device = torch.device("cuda:" + str(args.gpu_device))
    else:
        device = torch.device("cpu")

    model = model.to(device)
    model.eval()  # ensure inference-time behaviour during evaluation

    _, _, eval_loader = _get_dataloader(
        args.dataset,
        args.root_dir,
        args.num_speakers,
        args.sample_rate,
        1,  # batch size is set to 1 to avoid masking
        0,  # set num_workers to 0
        args.librimix_task,
        args.librimix_tr_split,
    )

    _eval(model, eval_loader, device)


if __name__ == "__main__":
    cli_main()
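
# Example invocation (the paths and device index below are placeholders, not
# part of the original script):
#
#   python eval.py \
#       --dataset librimix \
#       --root-dir /path/to/LibriMix \
#       --librimix-task sep_clean \
#       --num-speakers 2 \
#       --sample-rate 8000 \
#       --exp-dir ./exp \
#       --gpu-device 0
#
# The script expects ``<exp-dir>/best_model.pth`` (presumably produced by the
# companion lightning_train.py training script) and prints the SDR, Si-SDR,
# SIR and SAR improvements of the separated estimates over the unprocessed
# mixture.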