File: datasets.py

From the Debian package pytorch-audio 0.13.1-1 (bookworm).

import torch
from torchaudio.datasets import LIBRISPEECH


class MapMemoryCache(torch.utils.data.Dataset):
    """
    Wrap a dataset so that, whenever a new item is returned, it is saved to memory.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        # Serve previously fetched items straight from the cache.
        if self._cache[n] is not None:
            return self._cache[n]

        # First access: read from the wrapped dataset and remember the result.
        item = self.dataset[n]
        self._cache[n] = item

        return item

    def __len__(self):
        return len(self.dataset)
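

# A minimal usage sketch of MapMemoryCache.  The "./data" root and the
# "dev-clean" subset are placeholder choices for illustration, not part of
# this module.
def _example_memory_cache():
    base = LIBRISPEECH("./data", "dev-clean", download=True)
    cached = MapMemoryCache(base)
    item = cached[0]  # first access reads from disk and fills the cache
    item = cached[0]  # second access is served from memory
    return item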


class Processed(torch.utils.data.Dataset):
    """
    Apply a feature transform to each item's waveform and encode its
    transcript, yielding (features, target) pairs.
    """

    def __init__(self, dataset, transforms, encode):
        self.dataset = dataset
        self.transforms = transforms
        self.encode = encode

    def __getitem__(self, key):
        item = self.dataset[key]
        return self.process_datapoint(item)

    def __len__(self):
        return len(self.dataset)

    def process_datapoint(self, item):
        # LIBRISPEECH items are (waveform, sample_rate, utterance, speaker_id,
        # chapter_id, utterance_id); keep the waveform and the utterance.
        transformed = item[0]
        target = item[2].lower()

        # Extract features and lay them out as (time, feature) so that
        # pad_sequence can later pad along the time axis.
        transformed = self.transforms(transformed)
        transformed = transformed[0, ...].transpose(0, -1)

        # Encode the transcript into integer labels on the features' device.
        target = self.encode(target)
        target = torch.tensor(target, dtype=torch.long, device=transformed.device)

        return transformed, target
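

# A minimal sketch of wiring up Processed.  The MelSpectrogram transform and
# the character-level encode below are illustrative choices, not part of this
# module.
def _example_processed():
    import torchaudio

    chars = "abcdefghijklmnopqrstuvwxyz' "

    def encode(text):
        return [chars.index(c) for c in text if c in chars]

    dataset = Processed(
        LIBRISPEECH("./data", "dev-clean", download=True),
        torchaudio.transforms.MelSpectrogram(sample_rate=16000),
        encode,
    )
    features, target = dataset[0]  # features: (time, n_mels); target: LongTensor
    return features, target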


def split_process_librispeech(
    datasets,
    transforms,
    language_model,
    root,
    folder_in_archive,
):
    def create(tags, cache=True):
        # Accept a single tag (e.g. "dev-clean") or a list of tags, and a
        # single transform or one transform per tag; zip() below pairs them
        # up and stops at the shorter sequence.
        if isinstance(tags, str):
            tags = [tags]
        if isinstance(transforms, list):
            transform_list = transforms
        else:
            transform_list = [transforms]

        data = torch.utils.data.ConcatDataset(
            [
                Processed(
                    LIBRISPEECH(
                        root,
                        tag,
                        folder_in_archive=folder_in_archive,
                        download=False,
                    ),
                    transform,
                    language_model.encode,
                )
                for tag, transform in zip(tags, transform_list)
            ]
        )

        # Keep items in memory after their first access unless caching is
        # disabled.
        if cache:
            data = MapMemoryCache(data)
        return data

    # For performance, we cache all datasets
    return tuple(create(dataset) for dataset in datasets)
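

# A minimal sketch of calling split_process_librispeech.  The root path and
# subset names are placeholders, and the language model is reduced to the
# only piece this module uses: an encode(str) -> list-of-ints method.
def _example_split():
    import torchaudio
    from types import SimpleNamespace

    chars = "abcdefghijklmnopqrstuvwxyz' "
    language_model = SimpleNamespace(
        encode=lambda text: [chars.index(c) for c in text if c in chars]
    )
    return split_process_librispeech(
        ["train-clean-100", "dev-clean"],
        torchaudio.transforms.MelSpectrogram(sample_rate=16000),
        language_model,
        root="./data",
        folder_in_archive="LibriSpeech",
    )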


def collate_factory(model_length_function, transforms=None):
    # An empty nn.Sequential acts as the identity transform.
    if transforms is None:
        transforms = torch.nn.Sequential()

    def collate_fn(batch):
        # Apply any batch-time transforms to each (time, feature) tensor.
        tensors = [transforms(b[0]) for b in batch if b]

        # Record each item's unpadded length, as defined by the model's
        # length function.
        tensors_lengths = torch.tensor(
            [model_length_function(t) for t in tensors],
            dtype=torch.long,
            device=tensors[0].device,
        )

        # Pad to the longest item in the batch, then move the feature axis in
        # front of time: (batch, time, feature) -> (batch, feature, time).
        tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
        tensors = tensors.transpose(1, -1)

        # Collect the label sequences, record their true lengths, then pad.
        targets = [b[1] for b in batch if b]
        target_lengths = torch.tensor(
            [target.shape[0] for target in targets],
            dtype=torch.long,
            device=tensors.device,
        )
        targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)

        return tensors, targets, tensors_lengths, target_lengths

    return collate_fn
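

# A minimal sketch of plugging the collate function into a DataLoader.  The
# length function is an illustrative choice for (time, feature) inputs, and
# "dataset" stands for any dataset of (features, target) pairs, e.g. one
# produced by split_process_librispeech above.
def _example_dataloader(dataset):
    def model_length_function(tensor):
        return tensor.shape[0]  # number of time frames before padding

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        collate_fn=collate_factory(model_length_function),
    )
    return next(iter(loader))  # tensors, targets, tensors_lengths, target_lengths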