File: qembedding_pack_test.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (98 lines) | stat: -rw-r--r-- 3,837 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98

import operator_benchmark as op_bench
import torch

# Table-shape grid shared by the short 2-D sweep and the 3-D sweep below.
# When unpacked, its keys come first and the extra keywords (batch_size, tags)
# follow, matching the explicit keyword order; that order is presumably what
# cross_product_configs uses to enumerate configs, so it is kept as-is.
_SHORT_EMBEDDING_GRID = dict(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
)

# Quick sweep: a single table height across three row widths.
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    **_SHORT_EMBEDDING_GRID,
    tags=('short',),
)

# Wider sweep over several table heights and a broad range of row widths.
embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=('long',),
)

# The short grid again, with an additional batch_size of 10 for the 3-D
# benchmark classes.
embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
    **_SHORT_EMBEDDING_GRID,
    batch_size=(10,),
    tags=('short',),
)

# Each op-list entry is exposed to the benchmark classes under these keys.
_OP_ATTR_NAMES = ('op_name', 'op_func')

# Float -> fused-quantized prepack ops at byte, 4-bit and 2-bit precision.
conversion_ops = op_bench.op_list(
    attrs=[
        ['qembeddingbag_byte_prepack', torch.ops.quantized.embedding_bag_byte_prepack],
        ['qembeddingbag_4bit_prepack', torch.ops.quantized.embedding_bag_4bit_prepack],
        ['qembeddingbag_2bit_prepack', torch.ops.quantized.embedding_bag_2bit_prepack],
    ],
    attr_names=_OP_ATTR_NAMES,
)

# Fused-quantized -> float unpack ops, mirroring the prepack list above.
unpack_ops = op_bench.op_list(
    attrs=[
        ['qembeddingbag_byte_unpack', torch.ops.quantized.embedding_bag_byte_unpack],
        ['qembeddingbag_4bit_unpack', torch.ops.quantized.embedding_bag_4bit_unpack],
        ['qembeddingbag_2bit_unpack', torch.ops.quantized.embedding_bag_2bit_unpack],
    ],
    attr_names=_OP_ATTR_NAMES,
)

class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Benchmark a float -> fused prepack op on a 2-D float32 weight.

    The weight has shape (num_embeddings, embedding_dim) and holds
    ``torch.rand`` samples shifted up by 1.
    """

    def init(self, num_embeddings, embedding_dim, op_func):
        shape = (num_embeddings, embedding_dim)
        # Shift the uniform samples by +1 so no entry sits near zero.
        weight = torch.rand(*shape, dtype=torch.float) + 1
        self.op_func = op_func
        self.inputs = {"weight": weight}

    def forward(self, weight):
        prepack = self.op_func
        return prepack(weight)

class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Benchmark a float -> fused prepack op on a 3-D float32 weight.

    The weight has shape (batch_size, num_embeddings, embedding_dim) and
    holds ``torch.rand`` samples shifted up by 1.
    """

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        shape = (batch_size, num_embeddings, embedding_dim)
        # Shift the uniform samples by +1 so no entry sits near zero.
        weight = torch.rand(*shape, dtype=torch.float) + 1
        self.op_func = op_func
        self.inputs = {"weight": weight}

    def forward(self, weight):
        prepack = self.op_func
        return prepack(weight)

class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Benchmark a fused -> float unpack op on a 2-D packed uint8 weight.

    The packed weight has shape (num_embeddings, embedding_dim + 8) and is
    filled with uniformly random bytes. The 8 extra columns are presumably
    room for the per-row quantization parameters the unpack ops expect at
    the end of each row — NOTE(review): confirm against the op definitions.
    """

    def init(self, num_embeddings, embedding_dim, op_func):
        # Sample the bytes directly as uint8 instead of casting float samples:
        # converting a negative float to an unsigned integer type is undefined
        # behaviour in C++, so ``randn(...).to(torch.uint8)`` yields
        # platform-dependent benchmark inputs.
        packed_weight = torch.randint(
            0, 256, (num_embeddings, embedding_dim + 8), dtype=torch.uint8
        )
        self.inputs = {
            "packed_weight": packed_weight
        }
        self.op_func = op_func

    def forward(self, packed_weight):
        return self.op_func(packed_weight)

class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Benchmark a fused -> float unpack op on a 3-D packed uint8 weight.

    The packed weight has shape (batch_size, num_embeddings, embedding_dim + 8)
    and is filled with uniformly random bytes. The 8 extra columns are
    presumably room for the per-row quantization parameters the unpack ops
    expect at the end of each row — NOTE(review): confirm against the op
    definitions.
    """

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        # Sample the bytes directly as uint8 instead of casting float samples:
        # converting a negative float to an unsigned integer type is undefined
        # behaviour in C++, so ``randn(...).to(torch.uint8)`` yields
        # platform-dependent benchmark inputs.
        packed_weight = torch.randint(
            0, 256, (batch_size, num_embeddings, embedding_dim + 8), dtype=torch.uint8
        )
        self.inputs = {
            "packed_weight": packed_weight
        }
        self.op_func = op_func

    def forward(self, packed_weight):
        return self.op_func(packed_weight)

# Register each benchmark family as an (op list, configs, benchmark class)
# triple, in the same order as before. Every configs expression is evaluated
# separately, so each 2-D family gets its own concatenated config list.
for _op_list, _configs, _bench_cls in (
    (conversion_ops,
     embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
     EmbeddingBagFloatToFusedBase),
    (unpack_ops,
     embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
     EmbeddingBagFusedToFloatBase),
    (conversion_ops,
     embeddingbag_conversion_three_dim_configs,
     EmbeddingBagThreeDimFloatToFusedBase),
    (unpack_ops,
     embeddingbag_conversion_three_dim_configs,
     EmbeddingBagThreeDimFusedToFloatBase),
):
    op_bench.generate_pt_tests_from_op_list(_op_list, _configs, _bench_cls)
# Drop the loop temporaries so the module namespace matches the explicit form.
del _op_list, _configs, _bench_cls

if __name__ == "__main__":
    # Executed directly: hand control to op_bench's benchmark runner.
    op_bench.benchmark_runner.main()