1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
|
import random
import numpy as np
import torch
from torch.utils.data import Dataset
class DummyData(Dataset):
    r"""
    A map-style dataset of random integer input rows and random targets.

    Each input sample is a row of ``sample_length`` integers drawn (with
    replacement) from a pool of distinct values in ``[0, max_val)``.
    ``sparsity_percentage`` controls the pool size. Targets are independent
    random integers in ``[0, max_val)``.
    """

    def __init__(
        self,
        max_val: int,
        sample_count: int,
        sample_length: int,
        sparsity_percentage: int,
    ):
        r"""
        Generate the random inputs and targets.

        Args:
            max_val (int): exclusive upper bound for every element; values
                come from ``range(max_val)``.
            sample_count (int): number of samples in the dataset.
            sample_length (int): number of elements in a sample.
            sparsity_percentage (int): percentage (0-100) of the ``max_val``
                possible values that are NOT used by the input data. The
                remaining ``int(max_val * (100 - sparsity_percentage) / 100)``
                randomly chosen values form the pool every element is drawn
                from.

        Raises:
            ValueError: if ``sparsity_percentage`` is outside ``[0, 100]``,
                or if it leaves an empty pool while elements must be drawn.
        """
        if not 0 <= sparsity_percentage <= 100:
            raise ValueError(
                f"sparsity_percentage must be in [0, 100], got {sparsity_percentage}"
            )
        self.max_val = max_val
        self.input_samples = sample_count
        self.input_dim = sample_length
        self.sparsity_percentage = sparsity_percentage
        # Inputs use Python's `random` module; targets use torch's RNG.
        self.input = self._generate_input()
        self.target = torch.randint(0, max_val, [sample_count])

    def _generate_input(self) -> torch.Tensor:
        """Return an (input_samples, input_dim) int64 tensor of pooled values."""
        used_fraction = (100 - self.sparsity_percentage) / 100.0
        pool_size = int(self.max_val * used_fraction)
        pool = list(range(self.max_val))
        random.shuffle(pool)
        pool = pool[:pool_size]
        # Only an error if we actually have to draw something from the pool.
        if not pool and self.input_samples > 0 and self.input_dim > 0:
            raise ValueError(
                f"sparsity_percentage={self.sparsity_percentage} leaves no usable "
                f"values out of max_val={self.max_val}"
            )
        data = [
            [random.choice(pool) for _ in range(self.input_dim)]
            for _ in range(self.input_samples)
        ]
        # Explicit dtype + reshape keep the result a 2-D integer tensor even
        # when there are no samples or sample_length is 0 (np.array([]) would
        # otherwise be 1-D float64), and avoid platform-dependent int width.
        array = np.array(data, dtype=np.int64).reshape(
            self.input_samples, self.input_dim
        )
        return torch.from_numpy(array)

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.input)

    def __getitem__(self, index):
        """Return the ``(input_row, target)`` pair at ``index``."""
        return self.input[index], self.target[index]
|