# Dataset construction and batch-collation helpers for the WSJ0Mix and LibriMix datasets.
from collections import namedtuple
from functools import partial
from typing import List
import torch
from torchaudio.datasets import LibriMix
from . import wsj0mix
# Container returned by the collate functions: stacked mixtures, sources and validity masks.
Batch = namedtuple("Batch", "mix src mask")
def get_dataset(dataset_type, root_dir, num_speakers, sample_rate, task=None, librimix_tr_split=None):
    """Build the train, validation and evaluation datasets.

    Args:
        dataset_type: Either ``"wsj0mix"`` or ``"librimix"``.
        root_dir: Dataset root. For ``"wsj0mix"`` it is joined with the
            sub-directories ``tr``, ``cv`` and ``tt`` using the ``/`` operator.
        num_speakers: Forwarded to the dataset constructor.
        sample_rate: Forwarded to the dataset constructor.
        task: Forwarded to ``LibriMix`` only.
        librimix_tr_split: Name of the ``LibriMix`` training subset; the
            validation and evaluation subsets are always ``"dev"`` and ``"test"``.

    Returns:
        A ``(train, validation, evaluation)`` tuple of datasets.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported names.
    """
    if dataset_type == "wsj0mix":
        train, validation, evaluation = [
            wsj0mix.WSJ0Mix(root_dir / subdir, num_speakers, sample_rate)
            for subdir in ("tr", "cv", "tt")
        ]
        return train, validation, evaluation

    if dataset_type == "librimix":
        train, validation, evaluation = [
            LibriMix(root_dir, subset, num_speakers, sample_rate, task)
            for subset in (librimix_tr_split, "dev", "test")
        ]
        return train, validation, evaluation

    raise ValueError(f"Unexpected dataset: {dataset_type}")
def _fix_num_frames(sample: wsj0mix.SampleType, target_num_frames: int, sample_rate: int, random_start=False):
"""Ensure waveform has exact number of frames by slicing or padding"""
mix = sample[1] # [1, time]
src = torch.cat(sample[2], 0) # [num_sources, time]
num_channels, num_frames = src.shape
num_seconds = torch.div(num_frames, sample_rate, rounding_mode="floor")
target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode="floor")
if num_frames >= target_num_frames:
if random_start and num_frames > target_num_frames:
start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate
mix = mix[:, start_frame:]
src = src[:, start_frame:]
mix = mix[:, :target_num_frames]
src = src[:, :target_num_frames]
mask = torch.ones_like(mix)
else:
num_padding = target_num_frames - num_frames
pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
mix = torch.cat([mix, pad], 1)
src = torch.cat([src, pad.expand(num_channels, -1)], 1)
mask = torch.ones_like(mix)
mask[..., num_frames:] = 0
return mix, src, mask
def collate_fn_wsj0mix_train(samples: List[wsj0mix.SampleType], sample_rate, duration):
    """Collate training samples into a fixed-length ``Batch``.

    Every sample is cropped (from a random start) or zero-padded to
    ``int(duration * sample_rate)`` frames, then the mixtures, sources and
    masks are each stacked along a new leading dimension.

    Args:
        samples: Samples to collate.
        sample_rate: Frames per second of the waveforms.
        duration: Length of each training example in seconds.

    Returns:
        A ``Batch`` of stacked ``mix``, ``src`` and ``mask`` tensors.
    """
    frames_per_item = int(duration * sample_rate)
    fixed = [
        _fix_num_frames(item, frames_per_item, sample_rate, random_start=True)
        for item in samples
    ]
    # Each entry of ``fixed`` is (mix, src, mask); stack field by field in that order.
    stacked = [torch.stack([entry[field] for entry in fixed], 0) for field in range(3)]
    return Batch(*stacked)
def collate_fn_wsj0mix_test(samples: List[wsj0mix.SampleType], sample_rate):
    """Collate evaluation samples into a ``Batch`` padded to the longest item.

    The target length is the largest last-dimension size among the mixtures
    (``sample[1]``); every sample is brought to that length without random
    cropping, then mixtures, sources and masks are each stacked along a new
    leading dimension.

    Args:
        samples: Samples to collate.
        sample_rate: Forwarded to the frame-fixing helper.

    Returns:
        A ``Batch`` of stacked ``mix``, ``src`` and ``mask`` tensors.
    """
    longest = max(item[1].shape[-1] for item in samples)
    fixed = [
        _fix_num_frames(item, longest, sample_rate, random_start=False)
        for item in samples
    ]
    # Each entry of ``fixed`` is (mix, src, mask); stack field by field in that order.
    stacked = [torch.stack([entry[field] for entry in fixed], 0) for field in range(3)]
    return Batch(*stacked)
def get_collate_fn(dataset_type, mode, sample_rate=None, duration=4):
    """Return the collate function for the given dataset and mode.

    Args:
        dataset_type: Either ``"wsj0mix"`` or ``"librimix"``; both use the
            same collate functions.
        mode: ``"train"`` or ``"test"``.
        sample_rate: Frames per second; required when ``mode`` is ``"train"``.
        duration: Length in seconds of each training example (train mode only).

    Returns:
        A ``functools.partial`` wrapping the train or test collate function.

    Raises:
        ValueError: If ``dataset_type`` is unsupported, or if ``mode`` is
            ``"train"`` and ``sample_rate`` is ``None``.
    """
    assert mode in ("train", "test")

    if dataset_type not in ("wsj0mix", "librimix"):
        raise ValueError(f"Unexpected dataset: {dataset_type}")

    if mode != "train":
        return partial(collate_fn_wsj0mix_test, sample_rate=sample_rate)

    if sample_rate is None:
        raise ValueError("sample_rate is not given.")
    return partial(collate_fn_wsj0mix_train, sample_rate=sample_rate, duration=duration)
|