File: _wav2vec2_datamodule.py

package info (click to toggle)
pytorch-audio 2.9.1-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 108,884 kB
  • sloc: python: 44,403; cpp: 3,384; sh: 126; makefile: 32
file content (75 lines) | stat: -rw-r--r-- 2,510 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
from pytorch_lightning import LightningDataModule
from torchaudio.datasets.librispeech import LIBRISPEECH

from ._utils import BucketizeBatchSampler, CollateFnWav2Vec2, DistributedBatchSampler


class Wav2Vec2DataModule(LightningDataModule):
    """Lightning data module for wav2vec 2.0 pre-training on LibriSpeech.

    Batches are formed by bucketing waveforms of similar length so that each
    batch contains roughly ``seconds_per_batch`` seconds of audio in total.
    Training uses all three LibriSpeech training subsets (960 h); validation
    uses ``dev-clean`` + ``dev-other``.
    """

    # Class attribute (not instance attribute) so subclasses/tests can swap in
    # a lightweight dataset implementation without touching the dataloaders.
    librispeech_cls = LIBRISPEECH

    # LibriSpeech audio is sampled at 16 kHz; used to convert a seconds-based
    # batch budget into a sample-count budget for the bucketing sampler.
    _SAMPLE_RATE = 16000

    def __init__(
        self,
        *,
        dataset_path,
        seconds_per_batch,
        train_shuffle=True,
        num_workers=10,
    ):
        """
        Args:
            dataset_path (str or Path): Root directory under which the
                LibriSpeech subsets are stored (passed to ``librispeech_cls``).
            seconds_per_batch (float): Target total duration of audio per batch.
            train_shuffle (bool, optional): Whether the distributed sampler
                reshuffles training batches each epoch. (Default: ``True``)
            num_workers (int, optional): Number of DataLoader worker
                processes. (Default: ``10``)
        """
        super().__init__()
        self.dataset_path = dataset_path
        self.seconds_per_batch = seconds_per_batch
        self.train_shuffle = train_shuffle
        self.num_workers = num_workers

    @staticmethod
    def _waveform_lengths(dataset):
        """Return the length (in samples) of every waveform in ``dataset``.

        NOTE: this iterates and loads the full dataset once up front, which is
        what the bucketing sampler needs to group similar-length utterances.
        """
        return [sample[0].size(1) for sample in dataset]

    def train_dataloader(self):
        """Build the training DataLoader over the full 960 h training set."""
        dataset = torch.utils.data.ConcatDataset(
            [
                self.librispeech_cls(self.dataset_path, url="train-clean-360"),
                self.librispeech_cls(self.dataset_path, url="train-clean-100"),
                self.librispeech_cls(self.dataset_path, url="train-other-500"),
            ]
        )
        len_list = self._waveform_lengths(dataset)

        sampler = BucketizeBatchSampler(
            len_list,
            num_buckets=10000,
            max_token_count=self.seconds_per_batch * self._SAMPLE_RATE,
            min_len=32000,
            max_len=250000,
            shuffle=True,
        )
        # Wrap for multi-GPU training and seed the shuffle with the current
        # epoch so each epoch sees a different batch ordering per rank.
        sampler = DistributedBatchSampler(sampler, shuffle=self.train_shuffle)
        sampler.set_epoch(self.trainer.current_epoch)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnWav2Vec2(pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )
        return dataloader

    def val_dataloader(self):
        """Build the validation DataLoader over dev-clean + dev-other."""
        dataset = torch.utils.data.ConcatDataset(
            [
                # BUG FIX: was ``self.librispeech_path``, an attribute that is
                # never set (``__init__`` stores ``dataset_path``), which made
                # validation raise AttributeError.
                self.librispeech_cls(self.dataset_path, url="dev-clean"),
                self.librispeech_cls(self.dataset_path, url="dev-other"),
            ]
        )
        len_list = self._waveform_lengths(dataset)
        # No distributed wrapper and no shuffling: validation batches are
        # deterministic and identical across epochs.
        sampler = BucketizeBatchSampler(
            len_list,
            num_buckets=1000,
            max_token_count=self.seconds_per_batch * self._SAMPLE_RATE,
            min_len=32000,
            max_len=250000,
            shuffle=False,
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnWav2Vec2(pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )
        return dataloader