File: datasets.py

import random

import torch
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LIBRITTS, LJSPEECH
from torchaudio.transforms import MuLawEncoding


class MapMemoryCache(torch.utils.data.Dataset):
    r"""Wrap a dataset so that, whenever a new item is returned, it is saved to memory."""

    def __init__(self, dataset):
        self.dataset = dataset
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        if self._cache[n] is not None:
            return self._cache[n]

        item = self.dataset[n]
        self._cache[n] = item

        return item

    def __len__(self):
        return len(self.dataset)
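

# A minimal usage sketch (assumption: `base` is any map-style dataset small
# enough to fit in RAM, e.g. one of the torchaudio datasets imported above):
#
#     base = LJSPEECH(root="./data", download=False)
#     cached = MapMemoryCache(base)
#     item = cached[0]  # first access reads from `base` and fills the cache slot
#     item = cached[0]  # second access is served from memory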


class Processed(torch.utils.data.Dataset):
    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __getitem__(self, key):
        item = self.dataset[key]
        return self.process_datapoint(item)

    def __len__(self):
        return len(self.dataset)

    def process_datapoint(self, item):
        specgram = self.transforms(item[0])
        return item[0].squeeze(0), specgram
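

# Sketch of what `Processed` yields per index (assumption: `transforms` is any
# callable on the raw waveform; torchaudio.transforms.MelSpectrogram is one choice):
#
#     import torchaudio
#
#     transforms = torchaudio.transforms.MelSpectrogram(sample_rate=22050)
#     processed = Processed(dataset, transforms)
#     waveform, specgram = processed[0]  # waveform is 1-D after the squeeze(0)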


def split_process_dataset(args, transforms):
    if args.dataset == "ljspeech":
        data = LJSPEECH(root=args.file_path, download=False)

        val_length = int(len(data) * args.val_ratio)
        lengths = [len(data) - val_length, val_length]
        train_dataset, val_dataset = random_split(data, lengths)

    elif args.dataset == "libritts":
        train_dataset = LIBRITTS(root=args.file_path, url="train-clean-100", download=False)
        val_dataset = LIBRITTS(root=args.file_path, url="dev-clean", download=False)

    else:
        raise ValueError(f"Expected dataset `ljspeech` or `libritts`, but got `{args.dataset}`")

    train_dataset = Processed(train_dataset, transforms)
    val_dataset = Processed(val_dataset, transforms)

    train_dataset = MapMemoryCache(train_dataset)
    val_dataset = MapMemoryCache(val_dataset)

    return train_dataset, val_dataset
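

# Example call (a sketch; these `args` fields are exactly the ones read above,
# with hypothetical values):
#
#     from argparse import Namespace
#
#     args = Namespace(dataset="ljspeech", file_path="./data", val_ratio=0.1)
#     train_dataset, val_dataset = split_process_dataset(args, transforms)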


def collate_factory(args):
    def raw_collate(batch):

        pad = (args.kernel_size - 1) // 2
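        # `pad` equals the frames a size-`kernel_size` valid convolution trims per
        # side (the model's conditioning network is assumed to do that trimming)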

        # input waveform length
        wave_length = args.hop_length * args.seq_len_factor
        # input spectrogram length
        spec_length = args.seq_len_factor + pad * 2

        # max start position in spectrogram
        max_offsets = [x[1].shape[-1] - (spec_length + pad * 2) for x in batch]

        # random start position in spectrogram
        spec_offsets = [random.randint(0, offset) for offset in max_offsets]
        # random start position in waveform
        wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets]

        waveform_combine = [x[0][wave_offsets[i] : wave_offsets[i] + wave_length + 1] for i, x in enumerate(batch)]
        specgram = [x[1][:, spec_offsets[i] : spec_offsets[i] + spec_length] for i, x in enumerate(batch)]

        specgram = torch.stack(specgram)
        waveform_combine = torch.stack(waveform_combine)

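        # next-sample prediction: each slice holds wave_length + 1 samples, so the
        # input (first wave_length samples) and the target (same window advanced
        # by one sample) line up one-to-one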
        waveform = waveform_combine[:, :wave_length]
        target = waveform_combine[:, 1:]

        # waveform stays in [-1, 1]; target becomes class indices in
        # [0, 2 ** n_bits - 1] when loss == "crossentropy"
        if args.loss == "crossentropy":

            if args.mulaw:
                mulaw_encode = MuLawEncoding(2**args.n_bits)
                waveform = mulaw_encode(waveform)
                target = mulaw_encode(target)

                waveform = bits_to_normalized_waveform(waveform, args.n_bits)

            else:
                target = normalized_waveform_to_bits(target, args.n_bits)

        return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)

    return raw_collate
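

# Typical wiring (a sketch; the hyperparameter values are hypothetical, but the
# `args` fields are exactly the ones `raw_collate` reads):
#
#     from argparse import Namespace
#
#     args = Namespace(kernel_size=5, hop_length=200, seq_len_factor=5,
#                      loss="crossentropy", mulaw=True, n_bits=8)
#     loader = torch.utils.data.DataLoader(
#         train_dataset, batch_size=32, shuffle=True,
#         collate_fn=collate_factory(args),
#     )
#     waveform, specgram, target = next(iter(loader))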