File: torchtext_test_case.py

# -*- coding: utf-8 -*-
import json
import logging
import os
import shutil
import subprocess
import tempfile

import torch  # noqa: F401
from torch.testing._internal.common_utils import TestCase

logger = logging.getLogger(__name__)


class TorchtextTestCase(TestCase):
    """Base class for torchtext tests: creates a fresh temporary directory per
    test and provides helpers that write small dataset fixtures into it."""

    def setUp(self) -> None:
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            level=logging.INFO,
        )
        # Root of the repository checkout (two levels above this file)
        self.project_root = os.path.abspath(
            os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir))
        )
        # Directory where everything temporary and test-related is written
        self.test_dir = tempfile.mkdtemp()
        self.test_ppid_dataset_path = os.path.join(self.test_dir, "test_ppid_dataset")
        self.test_numerical_features_dataset_path = os.path.join(self.test_dir, "test_numerical_features_dataset")
        self.test_newline_dataset_path = os.path.join(self.test_dir, "test_newline_dataset")
        self.test_has_header_dataset_path = os.path.join(self.test_dir, "test_has_header_dataset")
        self.test_missing_field_dataset_path = os.path.join(self.test_dir, "test_msg_field_dst")
        self.test_dataset_splitting_path = os.path.join(self.test_dir, "test_dataset_split")
        self.test_nested_key_json_dataset_path = os.path.join(self.test_dir, "test_nested_key_json")

    def tearDown(self) -> None:
        try:
            shutil.rmtree(self.test_dir)
        except OSError:
            # shutil.rmtree can fail on some platforms/filesystems; fall back
            # to a best-effort shell removal.
            subprocess.call(["rm", "-rf", self.test_dir])

    def write_test_ppid_dataset(self, data_format="csv"):
        """Write a small question-pair (ppid) fixture in csv, tsv, or json format."""
        data_format = data_format.lower()
        if data_format == "csv":
            delim = ","
        elif data_format == "tsv":
            delim = "\t"
        dict_dataset = [
            {
                "id": "0",
                "question1": "When do you use シ instead of し?",
                "question2": 'When do you use "&" instead of "and"?',
                "label": "0",
            },
            {
                "id": "1",
                "question1": "Where was Lincoln born?",
                "question2": "Which location was Abraham Lincoln born?",
                "label": "1",
            },
            {"id": "2", "question1": "What is 2+2", "question2": "2+2=?", "label": "1"},
        ]
        with open(self.test_ppid_dataset_path, "w", encoding="utf-8") as test_ppid_dataset_file:
            for example in dict_dataset:
                if data_format == "json":
                    test_ppid_dataset_file.write(json.dumps(example) + "\n")
                elif data_format == "csv" or data_format == "tsv":
                    test_ppid_dataset_file.write(
                        "{}\n".format(
                            delim.join([example["id"], example["question1"], example["question2"], example["label"]])
                        )
                    )
                else:
                    raise ValueError("Invalid format {}".format(data_format))

    def write_test_nested_key_json_dataset(self) -> None:
        """
        Used only to test nested key parsing of Example.fromJSON()
        """
        dict_dataset = [
            {"foods": {"fruits": ["Apple", "Banana"], "vegetables": [{"name": "Broccoli"}, {"name": "Cabbage"}]}},
            {
                "foods": {
                    "fruits": ["Cherry", "Grape", "Lemon"],
                    "vegetables": [{"name": "Cucumber"}, {"name": "Lettuce"}],
                }
            },
            {
                "foods": {
                    "fruits": ["Orange", "Pear", "Strawberry"],
                    "vegetables": [{"name": "Marrow"}, {"name": "Spinach"}],
                }
            },
        ]
        with open(self.test_nested_key_json_dataset_path, "w") as test_nested_key_json_dataset_file:
            for example in dict_dataset:
                test_nested_key_json_dataset_file.write(json.dumps(example) + "\n")

    def write_test_numerical_features_dataset(self) -> None:
        with open(self.test_numerical_features_dataset_path, "w") as test_numerical_features_dataset_file:
            test_numerical_features_dataset_file.write("0.1\t1\tteststring1\n")
            test_numerical_features_dataset_file.write("0.5\t12\tteststring2\n")
            test_numerical_features_dataset_file.write("0.2\t0\tteststring3\n")
            test_numerical_features_dataset_file.write("0.4\t12\tteststring4\n")
            test_numerical_features_dataset_file.write("0.9\t9\tteststring5\n")

    def make_mock_dataset(self, num_examples=30, num_labels=3):
        num_repetitions = int(round(num_examples / num_labels)) + 1

        texts = [str(i) for i in range(num_examples)]
        labels = list(range(num_labels)) * num_repetitions
        labels = [str(label) for label in labels[:num_examples]]

        dict_dataset = [{"text": text, "label": label} for text, label in zip(texts, labels)]
        return dict_dataset

    def write_test_splitting_dataset(self, num_examples=30, num_labels=3):
        dict_dataset = self.make_mock_dataset(num_examples, num_labels)
        delim = ","

        with open(self.test_dataset_splitting_path, "w") as test_splitting_dataset_file:
            for example in dict_dataset:
                test_splitting_dataset_file.write("{}\n".format(delim.join([example["text"], example["label"]])))


def verify_numericalized_example(
    field, test_example_data, test_example_numericalized, test_example_lengths=None, batch_first=False, train=True
):
    """
    Function to verify that numericalized example is correct
    with respect to the Field's Vocab.
    """
    if isinstance(test_example_numericalized, tuple):
        test_example_numericalized, lengths = test_example_numericalized
        assert test_example_lengths == lengths.tolist()
    # Normalize the tensor to (time, batch), then transpose so the loop below
    # iterates over individual examples.
    if batch_first:
        test_example_numericalized.t_()
    for example_idx, numericalized_single_example in enumerate(test_example_numericalized.t()):
        assert len(test_example_data[example_idx]) == len(numericalized_single_example)
        # Variable.volatile was removed in PyTorch 0.4 and now always reads as
        # False, so the legacy ``volatile is not train`` check would fail for
        # train=False; assert the absence of autograd state instead.
        assert not numericalized_single_example.requires_grad
        for token_idx, numericalized_token in enumerate(numericalized_single_example):
            # Convert from Variable to int
            numericalized_token = numericalized_token.item()  # Pytorch v4 compatibility
            test_example_token = test_example_data[example_idx][token_idx]
            # Check if the numericalized example is correct, taking into
            # account unknown tokens.
            if field.vocab.stoi[test_example_token] != 0:
                # token is in-vocabulary
                assert field.vocab.itos[numericalized_token] == test_example_token
            else:
                # token is OOV and <unk> always has an index of 0
                assert numericalized_token == 0
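

# ---------------------------------------------------------------------------
# Illustrative self-check (hypothetical, not part of the upstream test suite).
# ``verify_numericalized_example`` targets the legacy torchtext Field/Vocab
# API (removed upstream in torchtext 0.12), so a minimal stand-in exposing
# only the attributes the helper reads (``vocab.stoi``, ``vocab.itos``) is
# sketched here; the stand-in objects and toy data are assumptions made for
# demonstration, not the real API.
if __name__ == "__main__":
    from collections import defaultdict
    from types import SimpleNamespace

    itos = ["<unk>", "hello", "world"]
    # The real Vocab.stoi maps out-of-vocabulary tokens to index 0;
    # defaultdict(int) reproduces that behaviour.
    stoi = defaultdict(int, {token: index for index, token in enumerate(itos)})
    fake_field = SimpleNamespace(vocab=SimpleNamespace(stoi=stoi, itos=itos))

    toy_data = [["hello", "world"]]  # one example with two tokens
    toy_numericalized = torch.tensor([[1], [2]])  # shape (time=2, batch=1)
    verify_numericalized_example(fake_field, toy_data, toy_numericalized)
    print("verify_numericalized_example passed on the toy example")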