File: _hubert_datamodule.py

package info (click to toggle)
pytorch-audio 2.9.1-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 108,884 kB
  • sloc: python: 44,403; cpp: 3,384; sh: 126; makefile: 32
file content (65 lines) | stat: -rw-r--r-- 2,171 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from pytorch_lightning import LightningDataModule

from ._utils import BucketizeBatchSampler, CollateFnHubert, DistributedBatchSampler, HuBERTDataSet


class HuBERTDataModule(LightningDataModule):
    """Lightning data module that builds the train/validation loaders for HuBERT datasets.

    Both loaders create their dataset through ``hubert_cls``, group samples
    with a ``BucketizeBatchSampler`` and collate them with ``CollateFnHubert``.

    Args:
        dataset_path: Passed as the first positional argument to ``hubert_cls``.
        dataset: Passed as the second positional argument to ``hubert_cls``.
        feature_type: Forwarded to ``CollateFnHubert`` as ``feature_type``.
        seconds_per_batch: Multiplied by 16000 to obtain the
            ``max_token_count`` of the bucket samplers (presumably 16 kHz
            audio -- TODO confirm).
        train_shuffle: ``shuffle`` flag of the ``DistributedBatchSampler``
            that wraps the training sampler. The inner training
            ``BucketizeBatchSampler`` is always built with ``shuffle=True``.
        num_workers: ``num_workers`` given to every ``DataLoader``.
    """

    # Class used to instantiate the dataset for each split.
    hubert_cls = HuBERTDataSet

    def __init__(
        self,
        *,
        dataset_path,
        dataset,
        feature_type,
        seconds_per_batch,
        train_shuffle=True,
        num_workers=10,
    ):
        super().__init__()
        # Where the data lives / which dataset to load.
        self.dataset_path = dataset_path
        self.dataset = dataset
        # Batching and collation settings.
        self.feature_type = feature_type
        self.seconds_per_batch = seconds_per_batch
        self.train_shuffle = train_shuffle
        self.num_workers = num_workers

    def _bucket_limits(self):
        """Return the length-related keyword arguments shared by both bucket samplers."""
        return {
            "max_token_count": self.seconds_per_batch * 16000,
            "min_len": 32000,
            "max_len": 250000,
        }

    def _wrap_in_loader(self, data, batch_sampler):
        """Build a ``DataLoader`` over ``data`` driven by ``batch_sampler``.

        Both splits use the same collate configuration
        (``pad=False, rand_crop=True``).
        """
        return torch.utils.data.DataLoader(
            data,
            batch_sampler=batch_sampler,
            collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return the loader for the ``"train"`` split.

        The bucket sampler is seeded with the trainer's current epoch and then
        wrapped in a ``DistributedBatchSampler`` whose epoch is set to the same
        value.
        """
        train_set = self.hubert_cls(self.dataset_path, self.dataset, "train")
        bucketed = BucketizeBatchSampler(
            train_set.len_list,
            num_buckets=10000,
            **self._bucket_limits(),
            shuffle=True,
            seed=self.trainer.current_epoch,
        )
        # NOTE: ``train_shuffle`` only reaches the distributed wrapper; the
        # bucket sampler above is shuffled unconditionally.
        distributed = DistributedBatchSampler(bucketed, shuffle=self.train_shuffle)
        distributed.set_epoch(self.trainer.current_epoch)
        return self._wrap_in_loader(train_set, distributed)

    def val_dataloader(self):
        """Return the loader for the ``"valid"`` split.

        Uses an unshuffled bucket sampler with no distributed wrapper.
        """
        valid_set = self.hubert_cls(self.dataset_path, self.dataset, "valid")
        bucketed = BucketizeBatchSampler(
            valid_set.len_list,
            num_buckets=1000,
            **self._bucket_limits(),
            shuffle=False,
        )
        return self._wrap_in_loader(valid_set, bucketed)