File: benchmark_dataset.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (56 lines) | stat: -rw-r--r-- 1,700 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch.utils.data import Dataset


def collate_sentences_lm(samples):
    """Collate equal-length LM samples into a single batch dictionary.

    Args:
        samples (list[dict]): items with keys ``"id"`` (int), ``"source"``
            and ``"target"`` (tensors), e.g. as produced by
            ``BenchmarkLMDataset.__getitem__``. All ``"source"`` tensors must
            share one shape, and likewise all ``"target"`` tensors, because
            they are combined with ``torch.stack``.

    Returns:
        dict: ``{}`` when ``samples`` is empty; otherwise a dict with
        ``"id"`` (int64 tensor of the sample ids), ``"nsentences"`` (number
        of samples), ``"ntokens"`` (number of samples times the length of the
        first target), ``"input"`` (stacked sources, new leading batch dim)
        and ``"target"`` (stacked targets, new leading batch dim).
    """
    if not samples:
        return {}

    # Named ``ids`` (not ``id``) so the builtin is not shadowed.
    ids = torch.tensor([s["id"] for s in samples], dtype=torch.long)
    src_tokens = torch.stack([s["source"] for s in samples], 0)
    tgt_tokens = torch.stack([s["target"] for s in samples], 0)
    # torch.stack above requires every target to have the same shape, so the
    # token count is batch size times the length of the first target.
    ntokens = len(samples) * len(samples[0]["target"])

    return {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "input": src_tokens,
        "target": tgt_tokens,
    }


class BenchmarkLMDataset(Dataset):
    """Synthetic dataset of random token rows for benchmarking an LM-style task.

    Every row contains exactly ``max_source_positions`` token ids drawn with
    ``torch.randint`` from ``[1, vocab_size)``; id 0 is never produced
    (presumably reserved, e.g. for padding -- TODO confirm). The target of
    each row is an independent copy of its source.

    Args:
        vocab_size (int, optional): size of the vocabulary; sampled ids lie
            in ``[1, vocab_size - 1]`` (default: 10000).
        max_source_positions (int, optional): number of tokens in every row
            (default: 1024).
        total_samples (int, optional): the total number of rows in the
            dataset (default: 10000).
    """

    def __init__(
        self, vocab_size=10000, max_source_positions=1024, total_samples=10000,
    ):
        self.total_samples = total_samples
        self.max_source_positions = max_source_positions
        self.vocab_size = vocab_size
        # Per-row lengths; every row has the same length.
        self.sizes = [max_source_positions] * total_samples

    def __getitem__(self, index):
        """Return row ``index`` as a dict with ``"id"``, ``"source"``, ``"target"``.

        New random tokens are drawn from the global torch RNG on every call.
        """
        tokens = torch.randint(1, self.vocab_size, (self.sizes[index],))
        return {"id": index, "source": tokens, "target": tokens.clone()}

    def __len__(self):
        """Return the configured number of rows."""
        return self.total_samples