"""Microbenchmarks for quantized EmbeddingBag prepack/unpack conversion operators."""
import operator_benchmark as op_bench
import torch
# Benchmark configurations for the conversion (prepack/unpack) operators.

# Small 2-D weight sweep, tagged "short" for quick runs.
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    tags=("short",),
)

# Larger sweep over table sizes and embedding widths, tagged "long".
embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=("long",),
)

# Batched 3-D weight tensors for the *ThreeDim* benchmark variants.
embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    batch_size=(10,),
    tags=("short",),
)
# Float -> packed (quantized) prepack ops at byte / 4-bit / 2-bit precision.
conversion_ops = op_bench.op_list(
    attr_names=("op_name", "op_func"),
    attrs=(
        ("qembeddingbag_byte_prepack", torch.ops.quantized.embedding_bag_byte_prepack),
        ("qembeddingbag_4bit_prepack", torch.ops.quantized.embedding_bag_4bit_prepack),
        ("qembeddingbag_2bit_prepack", torch.ops.quantized.embedding_bag_2bit_prepack),
    ),
)
# Packed -> float unpack ops, mirroring the prepack op list above.
unpack_ops = op_bench.op_list(
    attr_names=("op_name", "op_func"),
    attrs=(
        ("qembeddingbag_byte_unpack", torch.ops.quantized.embedding_bag_byte_unpack),
        ("qembeddingbag_4bit_unpack", torch.ops.quantized.embedding_bag_4bit_unpack),
        ("qembeddingbag_2bit_unpack", torch.ops.quantized.embedding_bag_2bit_unpack),
    ),
)
class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Benchmarks float-weight -> packed conversion ops on a 2-D weight matrix."""

    def init(self, num_embeddings, embedding_dim, op_func):
        # Shift uniform [0, 1) samples to [1, 2) so every weight is positive.
        weight = torch.rand(num_embeddings, embedding_dim, dtype=torch.float) + 1
        self.inputs = {"weight": weight}
        self.op_func = op_func

    def forward(self, weight):
        # Run the prepack op under benchmark on the float weight.
        return self.op_func(weight)
class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
    """Benchmarks float-weight -> packed conversion ops on batched 3-D weights."""

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        # Uniform samples shifted to [1, 2) so every weight is positive.
        weight = (
            torch.rand(batch_size, num_embeddings, embedding_dim, dtype=torch.float)
            + 1
        )
        self.inputs = {"weight": weight}
        self.op_func = op_func

    def forward(self, weight):
        # Run the prepack op under benchmark on the batched float weight.
        return self.op_func(weight)
class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Benchmarks packed (uint8) -> float unpack ops on a 2-D packed weight."""

    def init(self, num_embeddings, embedding_dim, op_func):
        # The extra 8 columns presumably reserve space for per-row
        # scale/zero-point metadata in the packed layout — TODO confirm.
        # NOTE(review): casting randn output to uint8 truncates/wraps negative
        # values; presumably only layout, not content, matters here.
        raw = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float)
        self.inputs = {"packed_weight": raw.to(torch.uint8)}
        self.op_func = op_func

    def forward(self, packed_weight):
        # Run the unpack op under benchmark on the packed byte tensor.
        return self.op_func(packed_weight)
class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
    """Benchmarks packed (uint8) -> float unpack ops on batched 3-D packed weights."""

    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        # The extra 8 columns presumably reserve space for per-row
        # scale/zero-point metadata in the packed layout — TODO confirm.
        raw = torch.randn(
            batch_size, num_embeddings, embedding_dim + 8, dtype=torch.float
        )
        self.inputs = {"packed_weight": raw.to(torch.uint8)}
        self.op_func = op_func

    def forward(self, packed_weight):
        # Run the unpack op under benchmark on the batched packed byte tensor.
        return self.op_func(packed_weight)
# Register benchmark cases: the 2-D variants run over the combined short+long
# configs; the batched 3-D variants run over the three-dim configs only.
two_dim_configs = (
    embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs
)
op_bench.generate_pt_tests_from_op_list(
    conversion_ops, two_dim_configs, EmbeddingBagFloatToFusedBase
)
op_bench.generate_pt_tests_from_op_list(
    unpack_ops, two_dim_configs, EmbeddingBagFusedToFloatBase
)
op_bench.generate_pt_tests_from_op_list(
    conversion_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFloatToFusedBase,
)
op_bench.generate_pt_tests_from_op_list(
    unpack_ops,
    embeddingbag_conversion_three_dim_configs,
    EmbeddingBagThreeDimFusedToFloatBase,
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()