import torch
from pytorch_lightning import LightningDataModule
from torchaudio.datasets.librispeech import LIBRISPEECH
from ._utils import BucketizeBatchSampler, CollateFnWav2Vec2, DistributedBatchSampler
class Wav2Vec2DataModule(LightningDataModule):
    """LightningDataModule serving length-bucketed LibriSpeech batches for wav2vec 2.0.

    Training draws from the three LibriSpeech train splits; validation draws
    from the two dev splits. Batches are sized by total audio duration
    (``seconds_per_batch``) rather than a fixed example count, with waveforms
    bucketed by length to limit padding.

    Args:
        dataset_path: Root directory containing the LibriSpeech dataset.
        seconds_per_batch: Target total audio duration per batch, in seconds
            (converted to a sample budget assuming 16 kHz audio).
        train_shuffle: Whether the distributed training sampler shuffles
            batches. (Default: ``True``)
        num_workers: Number of DataLoader worker processes. (Default: 10)
    """

    # Class attribute so subclasses/tests can substitute a stub dataset class.
    librispeech_cls = LIBRISPEECH

    def __init__(
        self,
        *,
        dataset_path,
        seconds_per_batch,
        train_shuffle=True,
        num_workers=10,
    ):
        super().__init__()
        self.dataset_path = dataset_path
        self.seconds_per_batch = seconds_per_batch
        self.train_shuffle = train_shuffle
        self.num_workers = num_workers

    def _bucketize_sampler(self, dataset, *, num_buckets, shuffle):
        """Build a ``BucketizeBatchSampler`` keyed on waveform length for ``dataset``."""
        # Each dataset item starts with a waveform tensor shaped (channel, time);
        # the time dimension is the bucketing key.
        len_list = [d[0].size(1) for d in dataset]
        return BucketizeBatchSampler(
            len_list,
            num_buckets=num_buckets,
            # seconds -> samples at LibriSpeech's 16 kHz sample rate.
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,  # skip clips shorter than 2 s
            max_len=250000,  # skip clips longer than ~15.6 s
            shuffle=shuffle,
        )

    def _make_dataloader(self, dataset, sampler):
        """Wrap ``dataset`` and a batch sampler in a DataLoader with the wav2vec2 collate."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnWav2Vec2(pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the training DataLoader over the three LibriSpeech train splits."""
        dataset = torch.utils.data.ConcatDataset(
            [
                self.librispeech_cls(self.dataset_path, url="train-clean-360"),
                self.librispeech_cls(self.dataset_path, url="train-clean-100"),
                self.librispeech_cls(self.dataset_path, url="train-other-500"),
            ]
        )
        sampler = self._bucketize_sampler(dataset, num_buckets=10000, shuffle=True)
        # Shard batches across distributed replicas; seed from the current
        # epoch so each epoch shuffles differently.
        sampler = DistributedBatchSampler(sampler, shuffle=self.train_shuffle)
        sampler.set_epoch(self.trainer.current_epoch)
        return self._make_dataloader(dataset, sampler)

    def val_dataloader(self):
        """Return the validation DataLoader over the LibriSpeech dev splits."""
        dataset = torch.utils.data.ConcatDataset(
            [
                # Bug fix: was ``self.librispeech_path``, which is never set;
                # ``__init__`` stores the dataset root as ``dataset_path``.
                self.librispeech_cls(self.dataset_path, url="dev-clean"),
                self.librispeech_cls(self.dataset_path, url="dev-other"),
            ]
        )
        sampler = self._bucketize_sampler(dataset, num_buckets=1000, shuffle=False)
        return self._make_dataloader(dataset, sampler)