"""Dataset wrappers (in-memory caching, per-item transforms), train/validation
splitting for LJSpeech and LibriTTS, and a random-crop batch collate function."""
import random
import torch
from processing import bits_to_normalized_waveform, normalized_waveform_to_bits
from torch.utils.data.dataset import random_split
from torchaudio.datasets import LIBRITTS, LJSPEECH
from torchaudio.transforms import MuLawEncoding
class MapMemoryCache(torch.utils.data.Dataset):
    r"""Memoize a map-style dataset in RAM.

    Every index is fetched from the wrapped dataset the first time it is
    requested and the result is stored; later requests for that index are
    served from the stored copy. An item equal to ``None`` is never stored,
    so it is re-fetched on every access.

    NOTE(review): if this wrapper is used with multi-process DataLoader
    workers, each worker presumably fills its own copy of the cache — confirm.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # One slot per index; ``None`` means "not fetched yet".
        self._cache = [None] * len(dataset)

    def __getitem__(self, n):
        # Reading the slot first also surfaces an out-of-range ``n`` as an
        # IndexError before the wrapped dataset is touched.
        cached = self._cache[n]
        if cached is None:
            cached = self.dataset[n]
            self._cache[n] = cached
        return cached

    def __len__(self):
        return len(self.dataset)
class Processed(torch.utils.data.Dataset):
    """Wrap a dataset so each item becomes ``(waveform, transforms(waveform))``.

    ``waveform`` is element 0 of the wrapped dataset's item. The transform is
    applied to it as-is, and the returned waveform has ``squeeze(0)`` applied.
    All other elements of the original item are dropped.
    """

    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, key):
        return self.process_datapoint(self.dataset[key])

    def process_datapoint(self, item):
        """Return ``(item[0].squeeze(0), self.transforms(item[0]))``."""
        waveform = item[0]
        features = self.transforms(waveform)
        return waveform.squeeze(0), features
def split_process_dataset(args, transforms):
    """Build the cached, transformed train and validation datasets.

    ``args.dataset`` selects the corpus:

    * ``"ljspeech"`` — one LJSpeech corpus at ``args.file_path`` is split at
      random, with ``int(len(corpus) * args.val_ratio)`` items going to
      validation and the rest to training.
    * ``"libritts"`` — LibriTTS ``train-clean-100`` is used for training and
      ``dev-clean`` for validation, both rooted at ``args.file_path``.

    Nothing is downloaded (``download=False``). Each split is wrapped first in
    ``Processed`` (applying ``transforms``) and then in ``MapMemoryCache``.

    Returns:
        ``(train_dataset, val_dataset)``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``"ljspeech"`` nor ``"libritts"``.
    """
    if args.dataset not in ("ljspeech", "libritts"):
        raise ValueError(f"Expected dataset: `ljspeech` or `libritts`, but found {args.dataset}")

    if args.dataset == "libritts":
        splits = (
            LIBRITTS(root=args.file_path, url="train-clean-100", download=False),
            LIBRITTS(root=args.file_path, url="dev-clean", download=False),
        )
    else:
        corpus = LJSPEECH(root=args.file_path, download=False)
        n_val = int(len(corpus) * args.val_ratio)
        splits = random_split(corpus, [len(corpus) - n_val, n_val])

    train_dataset, val_dataset = (MapMemoryCache(Processed(split, transforms)) for split in splits)
    return train_dataset, val_dataset
def collate_factory(args):
    """Return a collate function that randomly crops aligned waveform/spectrogram windows.

    The returned ``raw_collate(batch)`` reads ``args.kernel_size``,
    ``args.hop_length``, ``args.seq_len_factor``, ``args.loss``, ``args.mulaw``
    and ``args.n_bits`` on every call.

    Each element ``x`` of ``batch`` is indexed as ``x[0]`` (waveform, cropped
    along its first axis) and ``x[1]`` (spectrogram, cropped along its last
    axis), e.g. the pairs produced by ``Processed``.

    ``raw_collate`` returns ``(waveform, specgram, target)``, each stacked over
    the batch with a singleton axis inserted at position 1. ``target`` is the
    input waveform shifted one sample ahead. When ``args.loss == "crossentropy"``:

    * with ``args.mulaw``, input and target are encoded with
      ``MuLawEncoding(2**args.n_bits)``, and the input is then mapped back with
      ``bits_to_normalized_waveform``;
    * without it, only the target is converted with ``normalized_waveform_to_bits``.

    Raises:
        ValueError: (from ``raw_collate``) if a spectrogram has fewer than
            ``args.seq_len_factor + 4 * pad`` frames, where
            ``pad = (args.kernel_size - 1) // 2``.
    """

    def raw_collate(batch):
        pad = (args.kernel_size - 1) // 2

        # Input waveform length, in samples.
        wave_length = args.hop_length * args.seq_len_factor
        # Input spectrogram length, in frames: the model window plus `pad` frames on each side.
        spec_length = args.seq_len_factor + pad * 2
        # Frames each example must have: the crop itself plus another `pad * 2` frames of room.
        min_frames = spec_length + pad * 2

        # Random start position (inclusive bounds) in each spectrogram. Fail with a clear
        # message instead of random.randint's opaque "empty range" error on short clips.
        spec_offsets = []
        for x in batch:
            n_frames = x[1].shape[-1]
            if n_frames < min_frames:
                raise ValueError(
                    f"Spectrogram has {n_frames} frames but at least {min_frames} are required "
                    f"(seq_len_factor={args.seq_len_factor}, kernel_size={args.kernel_size}); "
                    "the utterance is too short to crop."
                )
            spec_offsets.append(random.randint(0, n_frames - min_frames))

        # Matching start position in the waveform: skip `pad` frames, converted to samples.
        wave_offsets = [(offset + pad) * args.hop_length for offset in spec_offsets]

        # Crop one extra sample so the input and its one-step-ahead target share a slice.
        waveform_combine = [
            x[0][w_off : w_off + wave_length + 1] for x, w_off in zip(batch, wave_offsets)
        ]
        # Crop on the last axis -- the same axis `n_frames` was measured on above.
        specgram = [x[1][..., s_off : s_off + spec_length] for x, s_off in zip(batch, spec_offsets)]

        specgram = torch.stack(specgram)
        waveform_combine = torch.stack(waveform_combine)

        waveform = waveform_combine[:, :wave_length]
        target = waveform_combine[:, 1:]

        # waveform: [-1, 1], target: [0, 2**bits-1] if loss = 'crossentropy'
        if args.loss == "crossentropy":
            if args.mulaw:
                mulaw_encode = MuLawEncoding(2**args.n_bits)
                waveform = mulaw_encode(waveform)
                target = mulaw_encode(target)

                waveform = bits_to_normalized_waveform(waveform, args.n_bits)
            else:
                target = normalized_waveform_to_bits(target, args.n_bits)

        return waveform.unsqueeze(1), specgram.unsqueeze(1), target.unsqueeze(1)

    return raw_collate