File: eval.py

package info (click to toggle)
pytorch-audio 2.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 10,696 kB
  • sloc: python: 61,274; cpp: 10,031; sh: 128; ansic: 70; makefile: 34
file content (90 lines) | stat: -rw-r--r-- 2,735 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import logging
import pathlib
from argparse import ArgumentParser

import ci_sdr

import torch
from datamodule import L3DAS22DataModule
from model import DNNBeamformer
from pesq import pesq
from pystoi import stoi

# getLogger() without a name returns the root logger; this script logs its
# progress messages through it.
logger = logging.getLogger()


def run_eval(args):
    """Evaluate a DNNBeamformer checkpoint and print averaged metrics.

    Restores the weights found under ``checkpoint["state_dict"]``, runs the
    model over ``L3DAS22DataModule.test_dataloader()`` and prints the running
    means of Ci-SDR, STOI and PESQ.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``checkpoint_path``, ``dataset_path``, ``batch_size`` and
            ``use_cuda``.
    """
    model = DNNBeamformer()
    # Map every tensor to the CPU while loading so that a checkpoint written
    # from a GPU run can still be restored on a machine without CUDA; the
    # model is moved to the GPU further down when --use-cuda is given.
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    # The checkpoint keys carry a "model." wrapper prefix and also contain
    # loss-related entries the bare model does not own. Strip the prefix only
    # at the start of the key: str.replace would also delete "model." from
    # the middle of a parameter name (e.g. "submodel.weight").
    prefix = "model."
    model_state = {}
    for key, value in checkpoint["state_dict"].items():
        if "loss" in key:
            continue
        if key.startswith(prefix):
            key = key[len(prefix) :]
        model_state[key] = value
    model.load_state_dict(model_state)
    model.eval()

    data_module = L3DAS22DataModule(dataset_path=args.dataset_path, batch_size=args.batch_size)
    if args.use_cuda:
        model = model.to(device="cuda")

    sample_rate = 16000  # rate handed to both STOI and PESQ
    mean_ci_sdr = 0.0
    mean_stoi = 0.0
    mean_pesq = 0.0
    with torch.no_grad():
        for idx, batch in enumerate(data_module.test_dataloader()):
            mixture, clean, _, _ = batch
            if args.use_cuda:
                mixture = mixture.cuda()
            clean = clean[0]
            # Metrics are computed on the CPU.
            estimate = model(mixture).cpu()
            # NOTE(review): the value is negated while change_sign=False is
            # passed -- confirm the intended sign convention and argument
            # order against the ci_sdr documentation.
            ci_sdr_v = (
                -ci_sdr.pt.ci_sdr(estimate, clean, compute_permutation=False, filter_length=512, change_sign=False)
                .mean()
                .item()
            )
            # NOTE(review): only index 0 along the leading axis is scored for
            # STOI/PESQ -- presumably the test loader yields one utterance
            # per batch; confirm against the data module.
            clean_np, estimate_np = clean[0].numpy(), estimate[0].numpy()
            stoi_v = stoi(clean_np, estimate_np, sample_rate, extended=False)
            pesq_v = pesq(sample_rate, clean_np, estimate_np, "wb")

            # Incremental mean: m_k = m_{k-1} + (x_k - m_{k-1}) / k.
            weight = 1.0 / float(idx + 1)
            mean_ci_sdr += weight * (ci_sdr_v - mean_ci_sdr)
            mean_stoi += weight * (stoi_v - mean_stoi)
            mean_pesq += weight * (pesq_v - mean_pesq)
            if idx % 100 == 0:
                logger.warning(f"Processed elem {idx}; Ci-SDR: {mean_ci_sdr}, stoi: {mean_stoi}, pesq: {mean_pesq}")

    # Print the final averages.
    results = {"Ci-SDR": mean_ci_sdr, "stoi": mean_stoi, "pesq": mean_pesq}
    print("*******************************")
    print("RESULTS")
    for metric_name, metric_value in results.items():
        print(metric_name, metric_value)


def cli_main():
    """Parse the command-line options and run the evaluation.

    Options:
        --checkpoint-path: Checkpoint file to evaluate (required).
        --dataset-path: Location of the L3DAS22 datasets.
        --batch-size / --batch_size: Batch size handed to the data module
            (default 4). The underscore spelling is kept as an alias so that
            existing invocations continue to work.
        --use-cuda: Run the model on CUDA.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--checkpoint-path",
        type=pathlib.Path,
        required=True,
        help="Path to checkpoint to use for evaluation.",
    )
    parser.add_argument(
        "--dataset-path",
        type=pathlib.Path,
        help="Path to L3DAS22 datasets.",
    )
    parser.add_argument(
        # Hyphenated spelling matches the other flags; the original
        # underscore spelling stays accepted for backward compatibility.
        "--batch-size",
        "--batch_size",
        dest="batch_size",
        default=4,
        type=int,
        help="Batch size passed to the L3DAS22 data module. (Default: 4)",
    )
    parser.add_argument(
        "--use-cuda",
        action="store_true",
        default=False,
        help="Run using CUDA.",
    )
    args = parser.parse_args()
    run_eval(args)


# Run the command-line entry point only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    cli_main()