File: qembedding_bag_lookups_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (248 lines) | stat: -rw-r--r-- 9,652 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248

import operator_benchmark as op_bench
import torch
import numpy as np
from typing import Optional

from torch.testing._internal.common_quantization import (
    lengths_to_offsets
)

torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators")


# Boolean sweep axes shared by the short and the long configuration sets.
# They are spliced in at the same keyword position they occupied when they
# were spelled out inline, so the keyword order handed to
# cross_product_configs is unchanged.
_shared_flag_axes = dict(
    enable_per_sample_weights=(True, False),
    include_last_offset=(True, False),
    is_pruned_weights=(True, False),
    use_32bit_indices=(True, False),
    use_32bit_offsets=(True, False),
)

# 'short' sweep: one table size, two embedding widths, num_offsets in 2..9.
embedding_bag_rowwise_offsets_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256),
    num_offsets=range(2, 10),
    **_shared_flag_axes,
    tags=['short'],
)


# 'long' sweep: five table sizes, four embedding widths, num_offsets in 10..19.
embedding_bag_rowwise_offsets_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000, 10_000, 20_000),
    embedding_dim=(16, 64, 128, 256),
    num_offsets=range(10, 20),
    **_shared_flag_axes,
    tags=['long'],
)


# Every benchmark class in this file is registered against both sweeps.
full_configs = (
    embedding_bag_rowwise_offsets_short_configs
    + embedding_bag_rowwise_offsets_long_configs
)

# Each op list holds a single (op_name, op_func) row: the label used for the
# benchmark and the quantized operator callable that the benchmark invokes.
four_bit_rowwise_ops = op_bench.op_list(
    attr_names=('op_name', 'op_func'),
    attrs=(
        (
            'qembeddingbag_4bit_rowwise_offsets',
            torch.ops.quantized.embedding_bag_4bit_rowwise_offsets,
        ),
    ),
)

byte_rowwise_ops = op_bench.op_list(
    attr_names=('op_name', 'op_func'),
    attrs=(
        (
            'qembeddingbag_byte_rowwise_offsets',
            torch.ops.quantized.embedding_bag_byte_rowwise_offsets,
        ),
    ),
)


def get_pruned_weights_and_mapping(q_weights):
    """Prune ``q_weights`` row-wise using a randomly drawn indicator vector.

    One float32 value per row of ``q_weights`` is sampled uniformly from
    [-1.0, 1.0) with NumPy and passed, together with the constants ``0.01``
    and ``torch.int32``, to ``torch.ops.fb.embedding_bag_rowwise_prune``.
    (Presumably 0.01 is the pruning threshold and ``torch.int32`` the dtype
    of the returned mapping -- confirm against the operator's schema.)

    Args:
        q_weights: Tensor whose first dimension enumerates the rows to prune.

    Returns:
        The ``(pruned_weights, compressed_indices_mapping)`` pair returned by
        the prune operator.
    """
    row_count = q_weights.shape[0]
    random_scores = np.random.uniform(low=-1.0, high=1.0, size=[row_count])
    row_indicator = torch.from_numpy(random_scores.astype(np.float32))

    pruned, mapping = torch.ops.fb.embedding_bag_rowwise_prune(
        q_weights, row_indicator, 0.01, torch.int32)

    return pruned, mapping


class EmbedddingBag4BitRowwiseOffsetsTest(op_bench.TorchBenchmarkBase):
    """Benchmark case for the 4-bit row-wise quantized embedding-bag lookup.

    ``init`` builds random bag lengths, indices and offsets, prepacks a random
    float32 weight table with ``torch.ops.quantized.embedding_bag_4bit_prepack``
    (optionally pruning it afterwards) and stores the operator arguments in
    ``self.inputs``. ``forward`` hands those arguments to ``op_func``.
    """

    def init(self,
             num_embeddings: int,
             embedding_dim: int,
             num_offsets: int,
             enable_per_sample_weights: bool,
             include_last_offset: bool,
             is_pruned_weights: bool,
             use_32bit_indices: bool,
             use_32bit_offsets: bool,
             op_func):
        """Create the randomized inputs for one benchmark configuration.

        Args:
            num_embeddings: Number of rows in the weight table; indices are
                drawn from ``[0, num_embeddings)``.
            embedding_dim: Number of columns in the float weight table.
            num_offsets: Upper bound (inclusive) on the number of bags; the
                actual count is drawn uniformly from ``[1, num_offsets]``.
            enable_per_sample_weights: When true, a float32 weight in
                ``[0.01, 0.5)`` is generated for every index; otherwise
                ``None`` is passed to the operator.
            include_last_offset: When true, the total number of indices is
                appended to the offsets tensor.
            is_pruned_weights: When true, the prepacked weights are replaced
                by the output of ``get_pruned_weights_and_mapping``.
            use_32bit_indices: When true, indices are converted to int32.
            use_32bit_offsets: When true, offsets are converted to int32.
            op_func: Operator callable invoked by ``forward``.
        """
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_offsets = num_offsets
        self.enable_per_sample_weights = enable_per_sample_weights
        self.include_last_offset = include_last_offset
        self.is_pruned_weights = is_pruned_weights
        self.use_32bit_indices = use_32bit_indices
        self.use_32bit_offsets = use_32bit_offsets
        self.max_segment_length = 20

        # Random bag structure: 1..num_offsets bags of 0..max_segment_length
        # indices each.
        self.num_lengths = np.random.randint(1, num_offsets + 1)
        self.lengths = np.random.randint(0, self.max_segment_length + 1,
                                         size=self.num_lengths).astype(np.int32)
        self.num_indices = np.sum(self.lengths)

        self.offsets = lengths_to_offsets(self.lengths)
        # Indices are drawn exactly once. Regenerating them later as int64
        # would silently discard the int32 conversion below and make the
        # use_32bit_indices axis of the sweep a no-op.
        self.indices = torch.from_numpy(np.random.randint(
            low=0, high=num_embeddings, size=self.num_indices, dtype=np.int64))

        if self.use_32bit_indices:
            self.indices = self.indices.int()
        if self.use_32bit_offsets:
            self.offsets = self.offsets.int()

        if self.include_last_offset:
            # Build the trailing offset in the offsets' own dtype so the
            # concatenation cannot widen 32-bit offsets back to 64-bit.
            last_offset = torch.tensor(
                [self.indices.size(0)], dtype=self.offsets.dtype)
            self.offsets = torch.cat((self.offsets, last_offset), 0)

        # Float weights in [1.0, 2.0), quantized/prepacked before the lookup.
        self.weights = torch.from_numpy((np.random.random_sample((
            self.num_embeddings, self.embedding_dim)) + 1).astype(np.float32))
        self.prepack_func = torch.ops.quantized.embedding_bag_4bit_prepack
        self.prepacked_weights = self.prepack_func(self.weights)

        if self.enable_per_sample_weights:
            self.per_sample_weights = torch.from_numpy(np.random.uniform(
                low=0.01, high=0.5, size=[len(self.indices)]).astype(np.float32))
        else:
            self.per_sample_weights = None

        self.compressed_indices = None

        if self.is_pruned_weights:
            self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(self.prepacked_weights)

        # Key order must match the positional parameter order of forward().
        self.inputs = {
            "prepacked_weights": self.prepacked_weights,
            "indices": self.indices,
            "offsets": self.offsets,
            "mode": 0,
            "per_sample_weights": self.per_sample_weights,
            "include_last_offset": self.include_last_offset,
            "is_pruned_weights": self.is_pruned_weights,
            "compressed_indices": self.compressed_indices
        }

        self.op_func = op_func

    def forward(
        self,
        prepacked_weights,
        indices,
        offsets,
        mode: int,
        per_sample_weights: Optional[torch.Tensor],
        include_last_offset: bool,
        is_pruned_weights: bool,
        compressed_indices: Optional[torch.Tensor]
    ):
        """Run the benchmarked operator on the inputs prepared by ``init``.

        The arguments mirror the entries of ``self.inputs``; the result of
        ``self.op_func`` is returned unchanged.
        """
        return self.op_func(prepacked_weights, indices, offsets,
                            mode=mode,
                            per_sample_weights=per_sample_weights,
                            include_last_offset=include_last_offset,
                            pruned_weights=is_pruned_weights,
                            compressed_indices_mapping=compressed_indices)


class EmbedddingBagByteRowwiseOffsetsTest(op_bench.TorchBenchmarkBase):
    """Benchmark case for the byte row-wise quantized embedding-bag lookup.

    ``init`` builds random bag lengths, indices and offsets, prepacks a random
    float32 weight table with ``torch.ops.quantized.embedding_bag_byte_prepack``
    (optionally pruning it afterwards) and stores the operator arguments in
    ``self.inputs``. ``forward`` hands those arguments to ``op_func``.
    """

    def init(self,
             num_embeddings: int,
             embedding_dim: int,
             num_offsets: int,
             enable_per_sample_weights: bool,
             include_last_offset: bool,
             is_pruned_weights: bool,
             use_32bit_indices: bool,
             use_32bit_offsets: bool,
             op_func):
        """Create the randomized inputs for one benchmark configuration.

        Args:
            num_embeddings: Number of rows in the weight table; indices are
                drawn from ``[0, num_embeddings)``.
            embedding_dim: Number of columns in the float weight table.
            num_offsets: Upper bound (inclusive) on the number of bags; the
                actual count is drawn uniformly from ``[1, num_offsets]``.
            enable_per_sample_weights: When true, a float32 weight in
                ``[0.01, 0.5)`` is generated for every index; otherwise
                ``None`` is passed to the operator.
            include_last_offset: When true, the total number of indices is
                appended to the offsets tensor.
            is_pruned_weights: When true, the prepacked weights are replaced
                by the output of ``get_pruned_weights_and_mapping``.
            use_32bit_indices: When true, indices are converted to int32.
            use_32bit_offsets: When true, offsets are converted to int32.
            op_func: Operator callable invoked by ``forward``.
        """
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_offsets = num_offsets
        self.enable_per_sample_weights = enable_per_sample_weights
        self.include_last_offset = include_last_offset
        self.is_pruned_weights = is_pruned_weights
        self.use_32bit_indices = use_32bit_indices
        self.use_32bit_offsets = use_32bit_offsets
        self.max_segment_length = 20

        # Random bag structure: 1..num_offsets bags of 0..max_segment_length
        # indices each.
        self.num_lengths = np.random.randint(1, num_offsets + 1)
        self.lengths = np.random.randint(0, self.max_segment_length + 1,
                                         size=self.num_lengths).astype(np.int32)
        self.num_indices = np.sum(self.lengths)

        self.offsets = lengths_to_offsets(self.lengths)
        # Indices are drawn exactly once. Regenerating them later as int64
        # would silently discard the int32 conversion below and make the
        # use_32bit_indices axis of the sweep a no-op.
        self.indices = torch.from_numpy(np.random.randint(
            low=0, high=num_embeddings, size=self.num_indices, dtype=np.int64))

        if self.use_32bit_indices:
            self.indices = self.indices.int()
        if self.use_32bit_offsets:
            self.offsets = self.offsets.int()

        if self.include_last_offset:
            # Build the trailing offset in the offsets' own dtype so the
            # concatenation cannot widen 32-bit offsets back to 64-bit.
            last_offset = torch.tensor(
                [self.indices.size(0)], dtype=self.offsets.dtype)
            self.offsets = torch.cat((self.offsets, last_offset), 0)

        # Float weights in [1.0, 2.0), quantized/prepacked before the lookup.
        self.weights = torch.from_numpy((np.random.random_sample((
            self.num_embeddings, self.embedding_dim)) + 1).astype(np.float32))
        self.prepack_func = torch.ops.quantized.embedding_bag_byte_prepack
        self.prepacked_weights = self.prepack_func(self.weights)

        if self.enable_per_sample_weights:
            self.per_sample_weights = torch.from_numpy(np.random.uniform(
                low=0.01, high=0.5, size=[len(self.indices)]).astype(np.float32))
        else:
            self.per_sample_weights = None

        self.compressed_indices = None

        if self.is_pruned_weights:
            self.prepacked_weights, self.compressed_indices = get_pruned_weights_and_mapping(self.prepacked_weights)

        # Key order must match the positional parameter order of forward().
        self.inputs = {
            "prepacked_weights": self.prepacked_weights,
            "indices": self.indices,
            "offsets": self.offsets,
            "mode": 0,
            "per_sample_weights": self.per_sample_weights,
            "include_last_offset": self.include_last_offset,
            "is_pruned_weights": self.is_pruned_weights,
            "compressed_indices": self.compressed_indices
        }

        self.op_func = op_func

    def forward(
        self,
        prepacked_weights,
        indices,
        offsets,
        mode: int,
        per_sample_weights: Optional[torch.Tensor],
        include_last_offset: bool,
        is_pruned_weights: bool,
        compressed_indices: Optional[torch.Tensor]
    ):
        """Run the benchmarked operator on the inputs prepared by ``init``.

        The arguments mirror the entries of ``self.inputs`` and are forwarded
        as given (rather than re-read from ``self``), matching the 4-bit
        benchmark class; the result of ``self.op_func`` is returned unchanged.
        """
        return self.op_func(prepacked_weights, indices, offsets,
                            mode=mode,
                            per_sample_weights=per_sample_weights,
                            include_last_offset=include_last_offset,
                            pruned_weights=is_pruned_weights,
                            compressed_indices_mapping=compressed_indices)


# Register one benchmark family per quantized operator list, 4-bit first and
# byte second; both are swept over the combined short + long configurations.
for _op_list, _bench_class in (
    (four_bit_rowwise_ops, EmbedddingBag4BitRowwiseOffsetsTest),
    (byte_rowwise_ops, EmbedddingBagByteRowwiseOffsetsTest),
):
    op_bench.generate_pt_tests_from_op_list(_op_list, full_configs, _bench_class)


# Hand control to the operator-benchmark runner when executed as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()