1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
|
import os
import sys
import time
from collections import Counter, OrderedDict
from typing import List, Union
# This is needed because we want to add the '../examples/vocab' directory (relative to this file) to the
# `sys.path` variable in order to import pytext_vocab (since it's not a module on the default import path).
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "examples", "vocab"))
import torch
from pytext.data.utils import Vocabulary as PytextVocabulary
from pytext.torchscript.vocab import ScriptVocabulary as PytextScriptVocabulary
from pytext_vocab import ScriptVocab as ExperimentalScriptVocabulary
from torchtext.prototype.datasets import AG_NEWS
def _run_benchmark_lookup(tokens, vocab, num_iters=1):
    """Time token lookups against ``vocab`` and print the elapsed time.

    The whole lookup pass over ``tokens`` is repeated ``num_iters`` times and the
    total wall-clock duration (measured with ``time.monotonic``) is printed as
    ``Lookup time: <seconds>``.

    Args:
        tokens: Either a flat list of tokens or a list of token lists. For a
            ``PytextVocabulary`` every element is passed to ``lookup_all`` as-is.
        vocab: A ``PytextVocabulary``, ``PytextScriptVocabulary``,
            ``ExperimentalScriptVocabulary`` or a
            ``torch.jit._script.RecursiveScriptModule``.
        num_iters: Number of times the lookup pass is repeated.

    Raises:
        RuntimeError: If ``vocab`` is of an unsupported type, or if a script
            vocab is benchmarked with ``tokens`` that is not a list.
    """

    def _run_benchmark_pytext_vocab(toks, v: PytextVocabulary):
        for token_or_tokens_list in toks:
            v.lookup_all(token_or_tokens_list)

    def _run_benchmark_pytext_script_vocab(toks, v: PytextScriptVocabulary):
        if not isinstance(toks, list):
            raise RuntimeError(f"Received tokens of incorrect type {type(toks)}.")
        # An empty list has no first element to inspect; checking emptiness first
        # avoids an IndexError and lets it fall through to the (no-op) token loop.
        if toks and isinstance(toks[0], list):
            # list lookup
            for tokens_list in toks:
                v.lookup_indices_1d(tokens_list)
        else:
            # single token lookup
            for token in toks:
                v.lookup_indices_1d([token])

    def _run_benchmark_experimental_script_vocab(toks, v: ExperimentalScriptVocabulary):
        if not isinstance(toks, list):
            raise RuntimeError(f"Received tokens of incorrect type {type(toks)}.")
        # Same emptiness guard as above: never index into an empty list.
        if toks and isinstance(toks[0], list):
            # list lookup
            for tokens_list in toks:
                v.lookup_indices_1d(tokens_list)
        else:
            # single token lookup
            for token in toks:
                v[token]

    t0 = time.monotonic()
    if isinstance(vocab, PytextVocabulary):
        for _ in range(num_iters):
            _run_benchmark_pytext_vocab(tokens, vocab)
    elif isinstance(vocab, PytextScriptVocabulary):
        for _ in range(num_iters):
            _run_benchmark_pytext_script_vocab(tokens, vocab)
    elif isinstance(vocab, (ExperimentalScriptVocabulary, torch.jit._script.RecursiveScriptModule)):
        for _ in range(num_iters):
            _run_benchmark_experimental_script_vocab(tokens, vocab)
    else:
        raise RuntimeError(f"Received vocab of incorrect type {type(vocab)}.")
    print("Lookup time:", time.monotonic() - t0)
def _run_benchmark_lookup_jit_for_loop(tokens: Union[List[str], List[List[str]]], vocab, num_iters=1):
    """Time token lookups where the iteration loop itself is TorchScript-compiled.

    Four helpers decorated with ``torch.jit.script`` are defined before the timer
    starts; one of them is then selected from the shape of ``tokens`` and the type
    of ``vocab`` and invoked ``num_iters`` times. The total wall-clock duration
    (measured with ``time.monotonic``) is printed as ``Lookup time: <seconds>``.

    Args:
        tokens: Either a flat list of tokens or a list of token lists. An empty
            list is treated as a flat list of tokens.
        vocab: A ``PytextScriptVocabulary``, ``ExperimentalScriptVocabulary`` or a
            ``torch.jit._script.RecursiveScriptModule``.
        num_iters: Number of times the selected helper is invoked.

    Raises:
        RuntimeError: If ``tokens`` is not a list, or if ``vocab`` is of an
            unsupported type.
    """

    # Single-token lookups, one call per token.
    @torch.jit.script
    def _run_benchmark_pytext_script_vocab(toks: List[str], v: PytextScriptVocabulary):
        for token in toks:
            v.lookup_indices_1d([token])

    @torch.jit.script
    def _run_benchmark_experimental_script_vocab(toks: List[str], v: ExperimentalScriptVocabulary):
        for token in toks:
            v[token]

    # Batched lookups, one call per token list.
    @torch.jit.script
    def _run_benchmark_lists_pytext_script_vocab(tok_lists: List[List[str]], v: PytextScriptVocabulary):
        for tokens_list in tok_lists:
            v.lookup_indices_1d(tokens_list)

    @torch.jit.script
    def _run_benchmark_lists_experimental_script_vocab(tok_lists: List[List[str]], v: ExperimentalScriptVocabulary):
        for tokens_list in tok_lists:
            v.lookup_indices_1d(tokens_list)

    t0 = time.monotonic()
    if not isinstance(tokens, list):
        raise RuntimeError(f"Received tokens of incorrect type {type(tokens)}.")

    # Check emptiness before peeking at the first element so that an empty list
    # does not raise IndexError; it is routed to the single-token helpers.
    is_list_lookup = bool(tokens) and isinstance(tokens[0], list)

    if isinstance(vocab, PytextScriptVocabulary):
        if is_list_lookup:
            run = _run_benchmark_lists_pytext_script_vocab
        else:
            run = _run_benchmark_pytext_script_vocab
    elif isinstance(vocab, (ExperimentalScriptVocabulary, torch.jit._script.RecursiveScriptModule)):
        if is_list_lookup:
            run = _run_benchmark_lists_experimental_script_vocab
        else:
            run = _run_benchmark_experimental_script_vocab
    else:
        raise RuntimeError(f"Received vocab of incorrect type {type(vocab)}.")

    for _ in range(num_iters):
        run(tokens, vocab)
    print("Lookup time:", time.monotonic() - t0)
def benchmark_experimental_vocab():
    """Benchmark construction and lookup of the pytext and experimental vocabs.

    Loads the AG_NEWS ``train`` split, converts every example's ids back to token
    strings, builds each vocabulary while printing its construction time, and then
    runs the lookup benchmarks three ways per vocabulary: over the flat token list,
    over that flat list wrapped in a single outer list, and over the per-example
    token lists.
    """
    (train,) = AG_NEWS(data_select="train")
    vocab = train.get_vocab()

    # Rebuild the token strings: one list per example plus one flat list of all tokens.
    tokens: List[str] = []
    tokens_lists: List[List[str]] = []
    for _, ids in train:
        example_tokens = [vocab.itos[token_id] for token_id in ids.tolist()]
        tokens_lists.append(example_tokens)
        tokens.extend(example_tokens)
    print("Tokens size:", len(tokens))
    print("Tokens list size:", len(tokens_lists))

    # (token, count) pairs, most frequent first; ties keep first-seen order.
    freq_pairs = Counter(tokens).most_common()
    vocab_list = ["<unk>"] + [token for token, _ in freq_pairs]
    ordered_dict = OrderedDict(freq_pairs)

    # pytext vocab construction
    print("Pytext Vocabulary")
    start = time.monotonic()
    pytext_vocab = PytextVocabulary(vocab_list)
    print("Construction time:", time.monotonic() - start)

    # pytext ScriptVocab construction (scripting happens outside the timed region)
    print("Pytext Script Vocabulary")
    start = time.monotonic()
    pytext_script_vocab = PytextScriptVocabulary(vocab_list)
    print("Construction time:", time.monotonic() - start)
    jit_pytext_script_vocab = torch.jit.script(pytext_script_vocab)

    # experimental ScriptVocab construction (scripting happens outside the timed region)
    print("Experimental Script Vocabulary")
    start = time.monotonic()
    experimental_script_vocab = ExperimentalScriptVocabulary(ordered_dict, unk_token="<unk>")
    print("Construction time:", time.monotonic() - start)
    jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab)

    # Each entry: (heading to print, benchmark runner, vocab under test), in execution order.
    lookup_suites = (
        ("Pytext Vocabulary - Eager Mode", _run_benchmark_lookup, pytext_vocab),
        ("Pytext ScriptVocab - Eager Mode", _run_benchmark_lookup, pytext_script_vocab),
        ("Experimental ScriptVocab - Eager Mode", _run_benchmark_lookup, experimental_script_vocab),
        ("Pytext ScriptVocab - Jit Mode", _run_benchmark_lookup, jit_pytext_script_vocab),
        ("Experimental ScriptVocab - Jit Mode", _run_benchmark_lookup, jit_experimental_script_vocab),
        ("Pytext ScriptVocab - Jit For Loop", _run_benchmark_lookup_jit_for_loop, jit_pytext_script_vocab),
        (
            "Experimental ScriptVocab - Jit For Loop",
            _run_benchmark_lookup_jit_for_loop,
            jit_experimental_script_vocab,
        ),
    )
    for heading, run_benchmark, target_vocab in lookup_suites:
        print(heading)
        run_benchmark(tokens, target_vocab)
        run_benchmark([tokens], target_vocab)
        run_benchmark(tokens_lists, target_vocab)
# Run the full benchmark only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    benchmark_experimental_vocab()
|