File: DummyData.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (53 lines) | stat: -rw-r--r-- 1,676 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import random

import numpy as np

import torch
from torch.utils.data import Dataset


class DummyData(Dataset):
    r"""Map-style dataset of random index sequences paired with random targets.

    Every sample is a row of ``sample_length`` integers drawn (with
    replacement) from one random subset of ``[0, max_val)`` that is shared by
    all samples. The subset size is controlled by ``sparsity_percentage``.
    Every target is an independent random integer in ``[0, max_val)``.
    """

    def __init__(
        self,
        max_val: int,
        sample_count: int,
        sample_length: int,
        sparsity_percentage: int,
    ):
        r"""
        A data class that generates random data.
        Args:
            max_val (int): exclusive upper bound for an element; elements
                and targets are drawn from ``[0, max_val)``
            sample_count (int): count of training samples
            sample_length (int): number of elements in a sample
            sparsity_percentage (int): the percentage of the ``max_val``
                indices that are NOT used by the input data. The remaining
                ``100 - sparsity_percentage`` percent (rounded down, but never
                fewer than one index) form the pool all samples draw from.
        Raises:
            ValueError: if ``max_val < 1`` or ``sparsity_percentage`` is
                outside ``[0, 100]``.
        """
        if max_val < 1:
            raise ValueError(f"max_val must be at least 1, got {max_val}")
        if not 0 <= sparsity_percentage <= 100:
            raise ValueError(
                "sparsity_percentage must be in [0, 100], "
                f"got {sparsity_percentage}"
            )
        self.max_val = max_val
        self.input_samples = sample_count
        self.input_dim = sample_length
        self.sparsity_percentage = sparsity_percentage

        self.input = self._generate_input()
        self.target = torch.randint(0, max_val, [sample_count])

    def _generate_input(self) -> torch.Tensor:
        """Build the ``(input_samples, input_dim)`` int64 tensor of indices."""
        fraction_used = (100 - self.sparsity_percentage) / 100.0
        # Round down, but keep at least one index: with a small max_val and a
        # high sparsity the product can truncate to 0, leaving nothing to
        # sample from.
        index_count = max(1, int(self.max_val * fraction_used))
        elements = list(range(self.max_val))
        random.shuffle(elements)
        pool = elements[:index_count]
        data = [
            [random.choice(pool) for _ in range(self.input_dim)]
            for _ in range(self.input_samples)
        ]
        # Explicit int64: numpy's default integer can be int32 on some
        # platforms, and an empty list would otherwise become float64. The
        # reshape keeps the result 2-D even when sample_count is 0.
        array = np.array(data, dtype=np.int64).reshape(
            self.input_samples, self.input_dim
        )
        return torch.from_numpy(array)

    def __len__(self) -> int:
        """Return the number of samples (rows of ``self.input``)."""
        return len(self.input)

    def __getitem__(self, index):
        """Return ``(input_row, target)`` for the sample at ``index``."""
        return self.input[index], self.target[index]