import argparse
import time
from collections import Counter, OrderedDict
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import DATASETS
from torchtext.prototype.transforms import basic_english_normalize
from torchtext.prototype.vocab_factory import build_vocab_from_text_file, load_vocab_from_file
from torchtext.vocab import build_vocab_from_iterator, vocab as VocabNew
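
# Note: build_vocab_from_text_file and load_vocab_from_file come from the
# experimental torchtext.prototype namespace, so their APIs may change
# between releases.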
def build_vocab(data, transforms):
    def apply_transforms(data):
        for _, line in data:
            yield transforms(line)

    vocab = build_vocab_from_iterator(apply_transforms(data), specials=["<unk>", "<pad>"])
    vocab.set_default_index(vocab["<unk>"])
    return vocab
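
# Example usage of build_vocab (a sketch; assumes a (label, text) dataset
# iterator such as torchtext.datasets.AG_NEWS and a callable tokenizer):
#
#   tokenizer = get_tokenizer("basic_english")
#   train_iter = DATASETS["AG_NEWS"](split="train")
#   v = build_vocab(train_iter, tokenizer)
#   v["out-of-vocab-token"]  # falls back to the index of "<unk>"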
def benchmark_new_vocab_construction(vocab_file_path, is_raw_text=True, num_iters=1):
    t0 = time.monotonic()
    if is_raw_text:
        print("Loading from raw text file with basic_english_normalize tokenizer")
        for _ in range(num_iters):
            tokenizer = basic_english_normalize()
            jited_tokenizer = torch.jit.script(tokenizer)
            build_vocab_from_text_file(vocab_file_path, jited_tokenizer, num_cpus=1)
        print("Construction time:", time.monotonic() - t0)
    else:
        with open(vocab_file_path, "r") as f:
            for _ in range(num_iters):
                f.seek(0)  # rewind so every iteration reads the whole file
                load_vocab_from_file(f)
        print("Construction time:", time.monotonic() - t0)
def benchmark_new_vocab_lookup(vocab_file_path=None, dataset="AG_NEWS"):
    def _run_benchmark_lookup(tokens, vocab):
        t0 = time.monotonic()
        # list of token lists: batched lookup
        if isinstance(tokens, list) and isinstance(tokens[0], list):
            for tokens_list in tokens:
                vocab.lookup_indices(tokens_list)
        # flat list: single token lookup
        elif isinstance(tokens, list):
            for token in tokens:
                vocab[token]
        else:
            raise RuntimeError("Received tokens of incorrect type {}.".format(type(tokens)))
        print("Lookup time:", time.monotonic() - t0)
    tokens = []
    tokens_lists = []
    tokenizer = get_tokenizer("basic_english")
    for (_, text) in DATASETS[dataset](split="train"):
        cur_tokens = tokenizer(text)
        tokens_lists.append(cur_tokens)
        tokens += cur_tokens
    if vocab_file_path:
        print("Loading Vocab from file {}".format(vocab_file_path))
        # new Vocab construction
        print("Vocab New")
        t0 = time.monotonic()
        with open(vocab_file_path, "r") as f:
            v_new = load_vocab_from_file(f)
        print("Construction time:", time.monotonic() - t0)
    else:
        print("Loading Vocab from {}".format(dataset))
        counter = Counter(tokens)
        sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)
        # new Vocab construction
        print("Vocab New")
        t0 = time.monotonic()
        v_new = VocabNew(ordered_dict)
        print("Construction time:", time.monotonic() - t0)
    # new Vocab eager lookup
    print("Vocab New - Eager Mode")
    _run_benchmark_lookup(tokens, v_new)
    _run_benchmark_lookup([tokens], v_new)
    _run_benchmark_lookup(tokens_lists, v_new)

    jit_v_new = torch.jit.script(v_new)
    # new Vocab jit lookup
    print("Vocab New - Jit Mode")
    _run_benchmark_lookup(tokens, jit_v_new)
    _run_benchmark_lookup([tokens], jit_v_new)
    _run_benchmark_lookup(tokens_lists, jit_v_new)
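
# A quick consistency sketch (hypothetical, not part of the benchmark): the
# scripted vocab should return the same indices as the eager one.
#
#   v = VocabNew(OrderedDict([("hello", 1)]))
#   v.set_default_index(0)
#   jit_v = torch.jit.script(v)
#   assert v["hello"] == jit_v["hello"]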
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Data processing pipelines")
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so parse the boolean flags explicitly instead.
    parser.add_argument(
        "--run-construction-benchmark",
        action="store_true",
        help="run benchmark for constructing a vocab (default=False)",
    )
    parser.add_argument(
        "--is-raw-text",
        type=lambda s: s.lower() in ("true", "1", "yes"),
        default=True,
        help="construct vocab from raw text file (default=True)",
    )
    parser.add_argument(
        "--vocab-filename-construction",
        type=str,
        default="vocab.txt",
        help="The name of vocab file used for construction",
    )
    parser.add_argument(
        "--vocab-filename-lookup", type=str, default=None, help="The name of vocab file used for lookup"
    )
parser.add_argument("--dataset", type=str, default="AG_NEWS", help="The name of vocab file used for lookup")
    args = parser.parse_args()

    if args.run_construction_benchmark:
        benchmark_new_vocab_construction(args.vocab_filename_construction, is_raw_text=args.is_raw_text)
    else:
        benchmark_new_vocab_lookup(args.vocab_filename_lookup, args.dataset)
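
# Example invocations (a sketch; the script name and file paths are assumptions):
#   python benchmark_vocab.py --run-construction-benchmark --vocab-filename-construction vocab.txt
#   python benchmark_vocab.py --vocab-filename-lookup vocab.txt --dataset AG_NEWS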