File: qembedding_pack_test.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (111 lines) | stat: -rw-r--r-- 3,597 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import operator_benchmark as op_bench

import torch


# Small 2-D sweep: one table height, three embedding widths, tagged "short".
# The keyword names match the ``init`` parameters of the 2-D benchmark classes
# defined below.
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,), embedding_dim=(128, 256, 512), tags=("short",)
)

# Larger 2-D sweep (3 table heights x 7 embedding widths), tagged "long".
embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=("long",),
)

# 3-D sweep: same sizes as the short sweep plus a ``batch_size`` entry, which
# the three-dim benchmark classes use as the leading tensor dimension.
embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    batch_size=(10,),
    tags=("short",),
)

# Float -> packed ops under test. Each entry pairs a benchmark name with the
# callable; ``attr_names`` labels the two tuple positions as ``op_name`` and
# ``op_func`` (``op_func`` is the name the benchmark ``init`` methods accept).
conversion_ops = op_bench.op_list(
    attrs=(
        ("qembeddingbag_byte_prepack", torch.ops.quantized.embedding_bag_byte_prepack),
        ("qembeddingbag_4bit_prepack", torch.ops.quantized.embedding_bag_4bit_prepack),
        ("qembeddingbag_2bit_prepack", torch.ops.quantized.embedding_bag_2bit_prepack),
    ),
    attr_names=("op_name", "op_func"),
)

# Packed -> float ops under test, laid out the same way as ``conversion_ops``.
unpack_ops = op_bench.op_list(
    attrs=(
        ("qembeddingbag_byte_unpack", torch.ops.quantized.embedding_bag_byte_unpack),
        ("qembeddingbag_4bit_unpack", torch.ops.quantized.embedding_bag_4bit_unpack),
        ("qembeddingbag_2bit_unpack", torch.ops.quantized.embedding_bag_2bit_unpack),
    ),
    attr_names=("op_name", "op_func"),
)


class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Benchmark case that hands a 2-D float32 weight matrix to ``op_func``.

    The weight has shape ``(num_embeddings, embedding_dim)`` and is filled
    with uniform random values shifted into ``[1, 2)``.
    """

    def init(self, num_embeddings, embedding_dim, op_func):
        """Build the random input tensor and remember the op to run."""
        shape = (num_embeddings, embedding_dim)
        float_weight = torch.rand(shape, dtype=torch.float32) + 1
        # Keyed by the name of the ``forward`` parameter it is meant for.
        self.inputs = {"weight": float_weight}
        self.op_func = op_func

    def forward(self, weight):
        """Apply the configured op to ``weight`` and return its result."""
        result = self.op_func(weight)
        return result


class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Benchmark case that hands a 3-D float32 weight tensor to ``op_func``.

    The weight has shape ``(batch_size, num_embeddings, embedding_dim)`` and
    is filled with uniform random values shifted into ``[1, 2)``.
    """

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        """Build the random batched input tensor and remember the op to run."""
        shape = (batch_size, num_embeddings, embedding_dim)
        float_weight = torch.rand(shape, dtype=torch.float32) + 1
        # Keyed by the name of the ``forward`` parameter it is meant for.
        self.inputs = {"weight": float_weight}
        self.op_func = op_func

    def forward(self, weight):
        """Apply the configured op to ``weight`` and return its result."""
        result = self.op_func(weight)
        return result


class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Benchmark case that hands a 2-D uint8 tensor to an unpack ``op_func``.

    The input is a standard-normal float32 matrix of shape
    ``(num_embeddings, embedding_dim + 8)`` cast to ``torch.uint8``.
    """

    def init(self, num_embeddings, embedding_dim, op_func):
        """Build the uint8 stand-in for a packed weight and store the op."""
        # NOTE(review): the 8 extra columns presumably stand in for per-row
        # quantization parameters of the byte-packed layout; the same width is
        # used for every op in the list -- confirm this is intended.
        packed_shape = (num_embeddings, embedding_dim + 8)
        float_values = torch.randn(packed_shape, dtype=torch.float32)
        as_bytes = float_values.to(dtype=torch.uint8)
        # Keyed by the name of the ``forward`` parameter it is meant for.
        self.inputs = {"packed_weight": as_bytes}
        self.op_func = op_func

    def forward(self, packed_weight):
        """Apply the configured op to ``packed_weight`` and return its result."""
        result = self.op_func(packed_weight)
        return result


class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Benchmark case that hands a 3-D uint8 tensor to an unpack ``op_func``.

    The input is a standard-normal float32 tensor of shape
    ``(batch_size, num_embeddings, embedding_dim + 8)`` cast to
    ``torch.uint8``.
    """

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        """Build the batched uint8 stand-in for packed weights and store the op."""
        # NOTE(review): the 8 extra columns presumably stand in for per-row
        # quantization parameters of the byte-packed layout; the same width is
        # used for every op in the list -- confirm this is intended.
        packed_shape = (batch_size, num_embeddings, embedding_dim + 8)
        float_values = torch.randn(packed_shape, dtype=torch.float32)
        as_bytes = float_values.to(dtype=torch.uint8)
        # Keyed by the name of the ``forward`` parameter it is meant for.
        self.inputs = {"packed_weight": as_bytes}
        self.op_func = op_func

    def forward(self, packed_weight):
        """Apply the configured op to ``packed_weight`` and return its result."""
        result = self.op_func(packed_weight)
        return result


# Register the benchmarks: each call pairs an op list with a config sweep and
# the benchmark class whose ``init`` accepts that sweep's keys plus ``op_func``.

# 2-D prepack ops over the combined short + long sweeps.
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFloatToFusedBase,
)
# 2-D unpack ops over the same combined sweeps.
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
    EmbeddingBagFusedToFloatBase,
)
# 3-D (batched) prepack ops over the three-dim sweep.
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFloatToFusedBase,
)
# 3-D (batched) unpack ops over the three-dim sweep.
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFusedToFloatBase,
)

# Hand control to the benchmark runner only when executed as a script.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()