File: datamodule.py
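
"""Data loading for the L3DAS22 Task 1 (3D speech enhancement) dataset.

Defines a ``Dataset`` that pairs noisy two-microphone recordings with their
clean references and transcripts, and a ``LightningDataModule`` that builds
the train/val/test dataloaders.
"""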

from pathlib import Path
from typing import Tuple, Union

import lightning.pytorch as pl
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset

from utils import CollateFnL3DAS22

# Directory-name prefix and subset-to-folder mapping for the extracted
# L3DAS22 Task 1 archives; all audio is sampled at 16 kHz.
_PREFIX = "L3DAS22_Task1_"
_SUBSETS = {
    "train360": ["train360_1", "train360_2"],
    "train100": ["train100"],
    "dev": ["dev"],
    "test": ["test"],
}
_SAMPLE_RATE = 16000


class L3DAS22(Dataset):
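    """L3DAS22 Task 1 (3D speech enhancement) dataset.

    Each subset directory contains a ``data`` folder with paired ``*_A.wav``
    and ``*_B.wav`` ambisonics recordings, and a ``labels`` folder with the
    clean reference waveform and a matching ``.txt`` transcript.

    Args:
        root (str or Path): Root directory containing the extracted subsets.
        subset (str, optional): One of ``"train360"``, ``"train100"``,
            ``"dev"``, or ``"test"``. (Default: ``"train360"``)
        min_len (int, optional): Minimum number of frames; shorter recordings
            are skipped. 64000 frames is 4 seconds at 16 kHz.
            (Default: ``64000``)
    """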
    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train360",
        min_len: int = 64000,
    ):
        self._walker = []
        if subset not in _SUBSETS:
            raise ValueError(
                f"Expected `subset` to be one of ('train360', 'train100', 'dev', 'test'); found {subset!r}."
            )
        for sub_dir in _SUBSETS[subset]:
            path = Path(root) / f"{_PREFIX}{sub_dir}" / "data"
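            # Keep only recordings long enough for fixed-length training crops.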
            files = [str(p) for p in path.glob("*_A.wav") if torchaudio.info(p).num_frames >= min_len]
            if len(files) == 0:
                raise RuntimeError(
                    f"No audio files of at least {min_len} frames found under {path}. "
                    "Please check that the zip file has been downloaded and extracted."
                )
            self._walker += files

    def __len__(self):
        return len(self._walker)

    def __getitem__(self, n: int) -> Tuple[Tensor, Tensor, int, str]:
        # Recordings come in pairs from microphones A and B under "data"; the
        # clean reference and transcript live in the sibling "labels" directory.
        noisy_path_A = Path(self._walker[n])
        noisy_path_B = noisy_path_A.with_name(noisy_path_A.name.replace("_A.wav", "_B.wav"))
        clean_path = noisy_path_A.parent.parent / "labels" / noisy_path_A.name.replace("_A.wav", ".wav")
        # Swap only the suffix; a plain `str.replace("wav", "txt")` could corrupt
        # other parts of the path that happen to contain "wav".
        transcript_path = clean_path.with_suffix(".txt")
        waveform_noisy_A, sample_rate1 = torchaudio.load(noisy_path_A)
        waveform_noisy_B, sample_rate2 = torchaudio.load(noisy_path_B)
        # Concatenate the two microphone recordings along the channel dimension.
        waveform_noisy = torch.cat((waveform_noisy_A, waveform_noisy_B), dim=0)
        waveform_clean, sample_rate3 = torchaudio.load(clean_path)
        assert sample_rate1 == sample_rate2 == sample_rate3 == _SAMPLE_RATE
        with open(transcript_path, "r") as f:
            transcript = f.readline()
        return waveform_noisy, waveform_clean, _SAMPLE_RATE, transcript


class L3DAS22DataModule(pl.LightningDataModule):
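    """LightningDataModule that builds train/val/test dataloaders for L3DAS22.

    Training combines the ``train360`` and ``train100`` subsets; validation
    uses ``dev`` and testing uses ``test``. Train and validation batches are
    collated to a fixed length by ``CollateFnL3DAS22`` from the local
    ``utils`` module.
    """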
    def __init__(
        self,
        dataset_path: str,
        batch_size: int,
    ):
        super().__init__()
        self.dataset_path = dataset_path
        self.batch_size = batch_size

    def train_dataloader(self):
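        # Combine the train360 and train100 subsets; the collate function
        # randomly crops each clip to 64000 samples (4 s at 16 kHz).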
        dataset = torch.utils.data.ConcatDataset(
            [
                L3DAS22(self.dataset_path, "train360"),
                L3DAS22(self.dataset_path, "train100"),
            ]
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=CollateFnL3DAS22(audio_length=64000, rand_crop=True),
            shuffle=True,
            num_workers=20,
        )

    def val_dataloader(self):
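        # Validation uses the dev subset with the same fixed-length collate
        # function, so clips of different durations can be stacked.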
        dataset = L3DAS22(self.dataset_path, "dev")
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=CollateFnL3DAS22(audio_length=64000, rand_crop=True),
            shuffle=False,
            num_workers=1,
        )

    def test_dataloader(self):
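        # min_len=0 keeps every clip, and batch_size=1 lets variable-length
        # clips pass through the default collate without cropping or padding.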
        dataset = L3DAS22(self.dataset_path, "test", min_len=0)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1,
        )
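

# A minimal usage sketch. The path "./L3DAS22" is a placeholder: it assumes
# the L3DAS22 archives have been extracted there, so that directories such as
# "L3DAS22_Task1_dev/data" exist.
if __name__ == "__main__":
    dataset = L3DAS22("./L3DAS22", subset="dev")
    waveform_noisy, waveform_clean, sample_rate, transcript = dataset[0]
    print(waveform_noisy.shape, waveform_clean.shape, sample_rate, transcript)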