File: sentencepiece_extractor.py

from argparse import ArgumentParser
from json import dump
from logging import INFO, basicConfig, getLogger
from os import remove
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple

from requests import get
from sentencepiece import SentencePieceProcessor
from tqdm import trange, tqdm

basicConfig(level=INFO)
logger = getLogger()


class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # Get SentencePiece
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

        # Merges: a pair (piece_l, piece_r) is a candidate merge whenever its
        # concatenation is itself a piece of the vocabulary. Note that this
        # scan is quadratic in the vocabulary size.
        merges = []
        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
            for piece_r in vocab.keys():
                merge = f"{piece_l}{piece_r}"
                piece_id = vocab.get(merge, None)
                if piece_id is not None:
                    merges += [(piece_l, piece_r, piece_id)]

        # Sort by the id of the merged piece (lower id = earlier merge), then
        # drop the id to keep only the (left, right) pairs.
        merges = sorted(merges, key=lambda val: val[2])
        merges = [(val[0], val[1]) for val in merges]

        return vocab, merges
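
# Illustrative usage of SentencePieceExtractor (a sketch; "spm.model" is a
# placeholder path, not a file shipped with this script):
#
#     extractor = SentencePieceExtractor("spm.model")
#     vocab, merges = extractor.extract()
#     # vocab maps each piece to its id, e.g. {"<unk>": 0, ...}
#     # merges is a list of (left, right) pairs ordered by the id of the merged piece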


class YouTokenToMeExtractor:
    """
    Extractor implementation for YouTokenToMe trained models.
    The model file is laid out as follows:
        vocab_size nb_merges
        piece piece_id
        ...(repeated vocab_size times)
        piece_id_left piece_id_right piece_id
        ...(repeated nb_merges times)
    """

    def __init__(self, model: str):
        self._model = model

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        with open(self._model, "r") as model_f:
            # Retrieve information
            nb_pieces, nb_merges = map(int, model_f.readline().split())
            vocab, merges = {}, []

            # Vocab: each line holds a piece (as an integer code point) and its id
            for _ in trange(nb_pieces):
                piece, piece_id = map(int, model_f.readline().split())
                vocab[piece_id] = chr(piece)

            # Merges: each merge introduces a new piece whose surface form is the
            # concatenation of the left and right pieces
            for _ in trange(nb_merges):
                piece_id_l, piece_id_r, piece = map(int, model_f.readline().split())
                piece_l, piece_r = vocab[piece_id_l], vocab[piece_id_r]
                vocab[piece] = f"{piece_l}{piece_r}"
                merges += [(piece_l, piece_r)]

            # Special tokens
            unk, pad, bos, eos = map(int, model_f.readline().split())
            vocab[unk] = "<unk>"
            vocab[pad] = "<pad>"
            vocab[bos] = "<bos>"
            vocab[eos] = "<eos>"

        # Invert the mapping: vocab was built as {id: piece}, but the extractor
        # contract expects {piece: id}
        vocab = dict(zip(vocab.values(), vocab.keys()))
        return vocab, merges
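
# Illustrative YouTokenToMe model file for the parser above (a sketch; the
# values are invented and only meant to show the layout):
#
#     3 1
#     97 0        <- piece "a" (code point 97) has id 0
#     98 1        <- piece "b" (code point 98) has id 1
#     99 2        <- piece "c" (code point 99) has id 2
#     0 1 3       <- merging ids 0 and 1 yields the piece "ab" with id 3
#     4 5 6 7     <- ids assigned to <unk>, <pad>, <bos>, <eos>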


if __name__ == "__main__":
    parser = ArgumentParser("SentencePiece vocab extractor")
    parser.add_argument(
        "--provider",
        type=str,
        required=True,
        choices=["sentencepiece", "youtokentome"],
        help="Indicate the format of the file.",
    )
    parser.add_argument("--model", type=str, required=True, help="SentencePiece model to extract vocab from.")
    parser.add_argument(
        "--vocab-output-path",
        type=str,
        required=True,
        help="Path where the vocab.json file will be extracted",
    )
    parser.add_argument(
        "--merges-output-path",
        type=str,
        required=True,
        help="Path where the merges file will be extracted",
    )

    # Parse cli arguments
    args = parser.parse_args()

    try:
        if args.model.startswith("http"):
            # Saving model
            with NamedTemporaryFile("wb", delete=False) as f:
                logger.info("Writing content from {} to {}".format(args.model, f.name))
                response = get(args.model, allow_redirects=True)
                f.write(response.content)

                args.remote_model = args.model
                args.model = f.name

        # Allocate extractor
        extractor_cls = SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
        extractor = extractor_cls(args.model)

        logger.info(f"Using {type(extractor).__name__}")

        # Open output files and let's extract model information
        with open(args.vocab_output_path, "w") as vocab_f:
            with open(args.merges_output_path, "w") as merges_f:
                # Do the extraction
                vocab, merges = extractor.extract()

                # Save content
                dump(vocab, vocab_f)
                # Write "\n": the file is opened in text mode, so Python already
                # translates it to the platform's line separator on write.
                merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}\n", merges))
    finally:
        # If the model was downloaded from the internet, clean up the temporary file.
        if hasattr(args, "remote_model") and exists(args.model):
            remove(args.model)
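
# Example invocation (file paths are illustrative):
#
#     python sentencepiece_extractor.py \
#         --provider sentencepiece \
#         --model spm.model \
#         --vocab-output-path vocab.json \
#         --merges-output-path merges.txt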