File: benchmark_vocab.py

Package: pytorch-text 0.14.1-2
import argparse
import time
from collections import Counter, OrderedDict

import torch
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import DATASETS
from torchtext.prototype.transforms import basic_english_normalize
from torchtext.prototype.vocab_factory import build_vocab_from_text_file, load_vocab_from_file
from torchtext.vocab import build_vocab_from_iterator, vocab as VocabNew
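
# This script benchmarks the prototype Vocab in two ways:
#   1. construction, either from a raw text file (tokenized with a scripted
#      basic_english_normalize tokenizer) or from a saved vocab file, and
#   2. token lookup on a torchtext dataset, in eager and TorchScript (jit)
#      mode, for single tokens and for lists of tokens.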


def build_vocab(data, transforms):
    def apply_transforms(data):
        for _, line in data:
            yield transforms(line)

    vocab = build_vocab_from_iterator(apply_transforms(data), specials=["<unk>", "<pad>"])
    vocab.set_default_index(vocab["<unk>"])
    return vocab
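
# A minimal usage sketch for build_vocab (not exercised by __main__; the
# dataset choice below is illustrative):
#
#   tokenizer = get_tokenizer("basic_english")
#   v = build_vocab(DATASETS["AG_NEWS"](split="train"), tokenizer)
#   print(len(v), v["the"])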


def benchmark_new_vocab_construction(vocab_file_path, is_raw_text=True, num_iters=1):
    t0 = time.monotonic()
    if is_raw_text:
        print("Loading from raw text file with basic_english_normalize tokenizer")
        for _ in range(num_iters):
            # Script the tokenizer before passing it to build_vocab_from_text_file.
            tokenizer = basic_english_normalize()
            jited_tokenizer = torch.jit.script(tokenizer)
            build_vocab_from_text_file(vocab_file_path, jited_tokenizer, num_cpus=1)
    else:
        for _ in range(num_iters):
            # Reopen the vocab file each iteration so every pass reads it from the start.
            with open(vocab_file_path, "r") as f:
                load_vocab_from_file(f)
    print("Construction time:", time.monotonic() - t0)
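
# benchmark_new_vocab_construction can also be called directly, e.g. to vary
# num_iters, which is not exposed on the command line (the path below is a
# placeholder):
#
#   benchmark_new_vocab_construction("vocab.txt", is_raw_text=False, num_iters=5)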


def benchmark_new_vocab_lookup(vocab_file_path=None, dataset="AG_NEWS"):
    def _run_benchmark_lookup(tokens, vocab):
        t0 = time.monotonic()
        # list lookup
        if isinstance(tokens, list) and isinstance(tokens[0], list):
            for tokens_list in tokens:
                vocab.lookup_indices(tokens_list)
        # single token lookup
        elif isinstance(tokens, list):
            for token in tokens:
                vocab[token]
        else:
            raise RuntimeError("Received tokens of incorrect type {}.".format(type(tokens)))
        print("Lookup time:", time.monotonic() - t0)

    tokens = []
    tokens_lists = []
    tokenizer = get_tokenizer("basic_english")
    for (_, text) in DATASETS[dataset](split="train"):
        cur_tokens = tokenizer(text)
        tokens_lists.append(cur_tokens)
        tokens += cur_tokens

    if vocab_file_path:
        print("Loading Vocab from file {}".format(vocab_file_path))

        # Currently unused helper for iterating over a vocab file token by token.
        def token_iterator(file_path):
            with open(file_path, "r") as f:
                for token in f:
                    yield token

        # new Vocab construction
        print("Vocab New")
        t0 = time.monotonic()
        with open(vocab_file_path, "r") as f:
            v_new = load_vocab_from_file(f)
        print("Construction time:", time.monotonic() - t0)
    else:
        print("Loading Vocab from {}".format(dataset))
        counter = Counter(tokens)
        sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)

        # new Vocab construction
        print("Vocab New")
        t0 = time.monotonic()
        v_new = VocabNew(ordered_dict)
        print("Construction time:", time.monotonic() - t0)

    # new Vocab eager lookup
    print("Vocab New - Eager Mode")
    _run_benchmark_lookup(tokens, v_new)
    _run_benchmark_lookup([tokens], v_new)
    _run_benchmark_lookup(tokens_lists, v_new)

    jit_v_new = torch.jit.script(v_new)
    # new Vocab jit lookup
    print("Vocab New - Jit Mode")
    _run_benchmark_lookup(tokens, jit_v_new)
    _run_benchmark_lookup([tokens], jit_v_new)
    _run_benchmark_lookup(tokens_lists, jit_v_new)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Data processing pipelines")
    parser.add_argument(
        "--run-construction-benchmark",
        action="store_true",
        help="run benchmark for constructing a vocab (default=False)",
    )
    parser.add_argument(
        "--is-raw-text",
        # argparse's type=bool would treat any non-empty string (including "False")
        # as True, so expose an explicit on/off flag pair instead (Python 3.9+).
        action=argparse.BooleanOptionalAction,
        default=True,
        help="construct vocab from raw text file (default=True)",
    )
    parser.add_argument(
        "--vocab-filename-construction",
        type=str,
        default="vocab.txt",
        help="The name of vocab file used for construction",
    )
    parser.add_argument(
        "--vocab-filename-lookup", type=str, default=None, help="The name of vocab file used for lookup"
    )
    parser.add_argument("--dataset", type=str, default="AG_NEWS", help="dataset used for the lookup benchmark (default=AG_NEWS)")
    args = parser.parse_args()

    if args.run_construction_benchmark:
        benchmark_new_vocab_construction(args.vocab_filename_construction, is_raw_text=args.is_raw_text)
    else:
        benchmark_new_vocab_lookup(args.vocab_filename_lookup, args.dataset)
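
# Example invocations (file names are placeholders):
#   python benchmark_vocab.py                      # lookup benchmark on AG_NEWS (default)
#   python benchmark_vocab.py --vocab-filename-lookup vocab.txt
#   python benchmark_vocab.py --run-construction-benchmark --no-is-raw-text \
#       --vocab-filename-construction vocab.txt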