File: test_tutorial_train_from_iterators.py

package info (click to toggle)
tokenizers 0.20.3+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 5,480 kB
  • sloc: python: 4,499; javascript: 419; makefile: 124
file content (104 lines) | stat: -rw-r--r-- 3,538 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# flake8: noqa
import gzip
import os

import datasets
import pytest

from ..utils import data_dir, train_files


class TestTrainFromIterators:
    """Exercise ``Tokenizer.train_from_iterator`` with several kinds of iterables.

    The ``# START <name>`` / ``# END <name>`` comment pairs delimit named
    snippets (presumably extracted into the tutorial documentation), so the
    code between each pair is kept self-contained, including its own imports.
    """

    @staticmethod
    def get_tokenizer_trainer():
        """Build a fresh Unigram tokenizer and its matching trainer.

        Returns:
            A ``(tokenizer, trainer)`` tuple. The tokenizer uses NFKC
            normalization with byte-level pre-tokenization and decoding; the
            trainer targets a 20000-entry vocabulary seeded with the
            byte-level alphabet and three special tokens.
        """
        # START init_tokenizer_trainer
        from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, trainers

        tokenizer = Tokenizer(models.Unigram())
        tokenizer.normalizer = normalizers.NFKC()
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
        tokenizer.decoder = decoders.ByteLevel()

        trainer = trainers.UnigramTrainer(
            vocab_size=20000,
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=["<PAD>", "<BOS>", "<EOS>"],
        )
        # END init_tokenizer_trainer
        # Kept outside the snippet: progress display is only turned off for the tests.
        trainer.show_progress = False

        return tokenizer, trainer

    @staticmethod
    def load_dummy_dataset():
        """Load every split of ``wikitext-103-raw-v1`` as a single dataset.

        No test in this class calls this helper; it carries the
        ``load_dataset`` snippet.

        Returns:
            The object returned by ``datasets.load_dataset`` for the combined
            ``train+test+validation`` split.
        """
        # START load_dataset
        import datasets

        dataset = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split="train+test+validation")
        # END load_dataset
        # Hand the dataset back so the helper is usable, not just illustrative.
        return dataset

    @pytest.fixture(scope="class")
    def setup_gzip_files(self, train_files):
        """Write three gzip copies of the ``"small"`` training file.

        Creates ``data/my-file.0.gz`` through ``data/my-file.2.gz``, each
        holding the full text of ``train_files["small"]``.
        """
        # Read the source once: calling read() again on the same handle returns
        # "" at EOF, which would leave every file after the first one empty.
        with open(train_files["small"], "rt") as small:
            content = small.read()

        for n in range(3):
            path = f"data/my-file.{n}.gz"
            with gzip.open(path, "wt") as f:
                f.write(content)

    def test_train_basic(self):
        """Train from a plain in-memory list of sentences."""
        tokenizer, trainer = self.get_tokenizer_trainer()

        # START train_basic
        # First few lines of the "Zen of Python" https://www.python.org/dev/peps/pep-0020/
        data = [
            "Beautiful is better than ugly.",
            "Explicit is better than implicit.",
            "Simple is better than complex.",
            "Complex is better than complicated.",
            "Flat is better than nested.",
            "Sparse is better than dense.",
            "Readability counts.",
        ]
        tokenizer.train_from_iterator(data, trainer=trainer)
        # END train_basic

    def test_datasets(self):
        """Train from batches of the ``text`` column of a ``datasets`` dataset."""
        tokenizer, trainer = self.get_tokenizer_trainer()

        # In order to keep tests fast, we only use the first 100 examples
        # NOTE: this environment variable is set process-wide and is not
        # restored afterwards, so it remains set for any test that runs later.
        os.environ["TOKENIZERS_PARALLELISM"] = "true"
        dataset = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split="train[0:100]")

        # START def_batch_iterator
        def batch_iterator(batch_size=1000):
            # Only keep the text column to avoid decoding the rest of the columns unnecessarily
            tok_dataset = dataset.select_columns("text")
            for batch in tok_dataset.iter(batch_size):
                yield batch["text"]

        # END def_batch_iterator

        # START train_datasets
        tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(dataset))
        # END train_datasets

    def test_gzip(self, setup_gzip_files):
        """Train from one gzip file, then from a generator chaining several.

        Relies on ``setup_gzip_files`` having written the ``data/my-file.N.gz``
        files read below.
        """
        tokenizer, trainer = self.get_tokenizer_trainer()

        # START single_gzip
        import gzip

        # The text-mode file object is passed directly as the iterator of lines.
        with gzip.open("data/my-file.0.gz", "rt") as f:
            tokenizer.train_from_iterator(f, trainer=trainer)
        # END single_gzip
        # START multi_gzip
        files = ["data/my-file.0.gz", "data/my-file.1.gz", "data/my-file.2.gz"]

        def gzip_iterator():
            for path in files:
                with gzip.open(path, "rt") as f:
                    for line in f:
                        yield line

        tokenizer.train_from_iterator(gzip_iterator(), trainer=trainer)
        # END multi_gzip