from pathlib import Path
from typing import Tuple, Union
import lightning.pytorch as pl
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from utils import CollateFnL3DAS22
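
# CollateFnL3DAS22 is defined in the local utils module; judging from its use
# below, it is assumed to crop or pad each waveform to `audio_length` samples
# (at a random offset when `rand_crop=True`) before batching.
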
_PREFIX = "L3DAS22_Task1_"
_SUBSETS = {
    "train360": ["train360_1", "train360_2"],
    "train100": ["train100"],
    "dev": ["dev"],
    "test": ["test"],
}
_SAMPLE_RATE = 16000


class L3DAS22(Dataset):
    """L3DAS22 Task 1 dataset for 3D speech enhancement.

    Each sample pairs the recordings of the two 4-channel Ambisonics
    microphones (``*_A.wav`` and ``*_B.wav``) with the clean mono reference
    and its transcript.

    Args:
        root: Path to the directory containing the extracted subsets.
        subset: One of ``"train360"``, ``"train100"``, ``"dev"``, ``"test"``.
        min_len: Minimum number of frames; shorter recordings are skipped.
    """

    def __init__(
        self,
        root: Union[str, Path],
        subset: str = "train360",
        min_len: int = 64000,
    ):
        if subset not in _SUBSETS:
            raise ValueError(f"Expected subset to be one of ('train360', 'train100', 'dev', 'test'). Found {subset!r}.")
        self._walker = []
        for sub_dir in _SUBSETS[subset]:
            path = Path(root) / f"{_PREFIX}{sub_dir}" / "data"
            # Index only the microphone-A files; the matching B, clean, and
            # transcript paths are derived in __getitem__.
            files = [str(p) for p in path.glob("*_A.wav") if torchaudio.info(p).num_frames >= min_len]
            if len(files) == 0:
                raise RuntimeError(
                    f"No audio files of at least {min_len} frames found under {path}. "
                    "Please check that the zip file has been downloaded and extracted."
                )
            self._walker += files

    def __len__(self):
        return len(self._walker)

    def __getitem__(self, n: int) -> Tuple[Tensor, Tensor, int, str]:
        noisy_path_A = Path(self._walker[n])
        noisy_path_B = str(noisy_path_A).replace("_A.wav", "_B.wav")
        clean_path = noisy_path_A.parent.parent / "labels" / noisy_path_A.name.replace("_A.wav", ".wav")
        # Use with_suffix rather than str.replace("wav", "txt"), which would
        # also rewrite any other occurrence of "wav" in the path.
        transcript_path = clean_path.with_suffix(".txt")
        waveform_noisy_A, sample_rate1 = torchaudio.load(noisy_path_A)
        waveform_noisy_B, sample_rate2 = torchaudio.load(noisy_path_B)
        # Stack the two 4-channel microphones into a single 8-channel tensor.
        waveform_noisy = torch.cat((waveform_noisy_A, waveform_noisy_B), dim=0)
        waveform_clean, sample_rate3 = torchaudio.load(clean_path)
        assert sample_rate1 == _SAMPLE_RATE and sample_rate2 == _SAMPLE_RATE and sample_rate3 == _SAMPLE_RATE
        with open(transcript_path, "r") as f:
            transcript = f.readline().strip()
        return waveform_noisy, waveform_clean, _SAMPLE_RATE, transcript
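

# A minimal usage sketch for the dataset on its own ("/path/to/L3DAS22" is a
# placeholder for the extracted dataset root):
#
#   dataset = L3DAS22("/path/to/L3DAS22", subset="dev")
#   noisy, clean, sample_rate, transcript = dataset[0]
#   # noisy: (8, num_frames) from both Ambisonics mics; clean: (1, num_frames)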


class L3DAS22DataModule(pl.LightningDataModule):
    """Lightning wrapper bundling the L3DAS22 train/dev/test dataloaders."""

    def __init__(
        self,
        dataset_path: str,
        batch_size: int,
    ):
        super().__init__()
        self.dataset_path = dataset_path
        self.batch_size = batch_size

    def train_dataloader(self):
        # Train on the combined 360h and 100h training splits.
        dataset = torch.utils.data.ConcatDataset(
            [
                L3DAS22(self.dataset_path, "train360"),
                L3DAS22(self.dataset_path, "train100"),
            ]
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=CollateFnL3DAS22(audio_length=64000, rand_crop=True),
            shuffle=True,
            num_workers=20,
        )

    def val_dataloader(self):
        dataset = L3DAS22(self.dataset_path, "dev")
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=CollateFnL3DAS22(audio_length=64000, rand_crop=True),
            shuffle=False,
            num_workers=1,
        )

    def test_dataloader(self):
        # Evaluate on full-length recordings: no minimum length and no collate
        # function, since batch_size=1 needs no padding or cropping.
        dataset = L3DAS22(self.dataset_path, "test", min_len=0)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1,
        )
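

if __name__ == "__main__":
    # Smoke test: build the datamodule and pull one validation batch.
    # "/path/to/L3DAS22" is a placeholder; the exact batch layout depends on
    # CollateFnL3DAS22 in utils.py.
    datamodule = L3DAS22DataModule("/path/to/L3DAS22", batch_size=4)
    batch = next(iter(datamodule.val_dataloader()))
    print(type(batch))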