import argparse
import sys
import torch
from torchtext.data.utils import get_tokenizer, ngrams_iterator
from torchtext.prototype.transforms import load_sp_model, PRETRAINED_SP_MODEL, SentencePieceTokenizer
from torchtext.utils import download_from_url


def predict(text, model, dictionary, tokenizer, ngrams):
    r"""Predict the class label of a sample text string.

    The input text is tokenized, expanded with n-grams, numericalized with
    the vocab, and then sent to the model for inference.

    Args:
        text: a sample text string
        model: the trained text-classification model
        dictionary: a Vocab object that maps tokens to indices
        tokenizer: tokenizer object used to split text into tokens
        ngrams: the highest n-gram size to generate (must match training)
    """
    with torch.no_grad():
        # Tokenize, add n-grams, and map every token to its vocab index.
        text = torch.tensor(dictionary(list(ngrams_iterator(tokenizer(text), ngrams))))
        # The model takes (text, offsets); a single concatenated sample starts at offset 0.
        output = model(text, torch.tensor([0]))
        # Shift the 0-based argmax to the dataset's 1-based labels.
        return output.argmax(1).item() + 1
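

# A minimal sketch of the numericalization done in predict(), assuming the
# default basic_english tokenizer and ngrams=2 (the sample strings are
# illustrative only):
#   tokenizer("hello world")                  -> ["hello", "world"]
#   ngrams_iterator(["hello", "world"], 2)    -> "hello", "world", "hello world"
#   dictionary([...])                         -> list of integer indices, one per token/n-gram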


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Predict text from stdin given model and dictionary")
    parser.add_argument("model", help="path to the saved model")
    parser.add_argument("dictionary", help="path to the saved vocab")
    parser.add_argument("--ngrams", type=int, default=2, help="ngrams (default=2)")
    parser.add_argument(
        "--use-sp-tokenizer", action="store_true", help="use sentencepiece tokenizer (default=False)"
    )
    args = parser.parse_args()

    # Load the trained model and the vocab saved by the training script.
    model = torch.load(args.model)
    dictionary = torch.load(args.dictionary)

    if args.use_sp_tokenizer:
        # Download a pretrained unigram SentencePiece model and wrap it as a tokenizer.
        sp_model_path = download_from_url(PRETRAINED_SP_MODEL["text_unigram_15000"])
        sp_model = load_sp_model(sp_model_path)
        tokenizer = SentencePieceTokenizer(sp_model)
    else:
        tokenizer = get_tokenizer("basic_english")

    # Classify each line of standard input and print the predicted label.
    for line in sys.stdin:
        print(predict(line, model, dictionary, tokenizer, args.ngrams))
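
# Example invocation (a hedged sketch; "model.pt" and "vocab.pt" are
# illustrative names for the artifacts saved by the companion training script):
#   echo "Wall St. stocks rallied on upbeat earnings." | \
#       python predict.py model.pt vocab.pt --ngrams 2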