import torch
from pytorch_lightning import LightningDataModule
from ._utils import BucketizeBatchSampler, CollateFnHubert, DistributedBatchSampler, HuBERTDataSet


class HuBERTDataModule(LightningDataModule):
    """Lightning ``DataModule`` wrapping :class:`HuBERTDataSet` for HuBERT pre-training.

    Args:
        dataset_path: Root directory of the pre-processed dataset.
        dataset: Name of the dataset to load.
        feature_type: Feature type passed to ``CollateFnHubert``.
        seconds_per_batch: Audio duration (in seconds) per batch; converted to a
            per-batch sample budget assuming a 16 kHz sample rate.
        train_shuffle: Whether to shuffle training batches across distributed
            workers. (Default: ``True``)
        num_workers: Number of ``DataLoader`` worker processes. (Default: ``10``)
    """

    hubert_cls = HuBERTDataSet

    def __init__(
        self,
        *,
        dataset_path,
        dataset,
        feature_type,
        seconds_per_batch,
        train_shuffle=True,
        num_workers=10,
    ):
        super().__init__()
        self.dataset_path = dataset_path
        self.dataset = dataset
        self.feature_type = feature_type
        self.seconds_per_batch = seconds_per_batch
        self.train_shuffle = train_shuffle
        self.num_workers = num_workers

    def train_dataloader(self):
        dataset = self.hubert_cls(self.dataset_path, self.dataset, "train")
        # Bucket utterances by length so that each batch packs roughly
        # ``seconds_per_batch`` seconds of 16 kHz audio; seeding with the
        # current epoch changes the shuffling order every epoch.
        sampler = BucketizeBatchSampler(
            dataset.len_list,
            num_buckets=10000,
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,
            max_len=250000,
            shuffle=True,
            seed=self.trainer.current_epoch,
        )
        # Shard whole batches across distributed workers.
        sampler = DistributedBatchSampler(sampler, shuffle=self.train_shuffle)
        sampler.set_epoch(self.trainer.current_epoch)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )
        return dataloader

    def val_dataloader(self):
        dataset = self.hubert_cls(self.dataset_path, self.dataset, "valid")
        # Validation uses fewer buckets and a fixed batch order (no shuffling,
        # no distributed wrapper) so evaluation is deterministic.
        sampler = BucketizeBatchSampler(
            dataset.len_list,
            num_buckets=1000,
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,
            max_len=250000,
            shuffle=False,
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )
        return dataloader
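

# A minimal usage sketch. The path, dataset name, and feature type below are
# illustrative assumptions, not values defined by this module, and ``model``
# is assumed to be a LightningModule defined elsewhere:
#
#   import pytorch_lightning as pl
#
#   data_module = HuBERTDataModule(
#       dataset_path="/path/to/preprocessed/data",  # hypothetical data root
#       dataset="librispeech",                      # hypothetical dataset name
#       feature_type="mfcc",                        # hypothetical feature type
#       seconds_per_batch=87.5,
#   )
#   trainer = pl.Trainer(max_epochs=150)
#   trainer.fit(model, datamodule=data_module)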