import operator_benchmark as op_bench
import torch
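
# Microbenchmarks for quantized EmbeddingBag weight conversion: packing float
# weights into the fused byte/4-bit/2-bit formats and unpacking them back.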

# Short sweep: a small table crossed with a few embedding dimensions.
embeddingbag_conversion_short_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    tags=('short',)
)

# Long sweep: larger tables and a wider range of embedding dimensions.
embeddingbag_conversion_long_configs = op_bench.cross_product_configs(
    num_embeddings=(100, 120, 1000),
    embedding_dim=(16, 64, 128, 256, 512, 1024, 2048),
    tags=('long',)
)

# Batched (3-D) variant of the short sweep.
embeddingbag_conversion_three_dim_configs = op_bench.cross_product_configs(
    num_embeddings=(80,),
    embedding_dim=(128, 256, 512),
    batch_size=(10,),
    tags=('short',)
)
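
# Note: cross_product_configs takes the Cartesian product of its argument
# lists, so the short sweep expands to 1 x 3 = 3 configs and the long sweep
# to 3 x 7 = 21.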

# Prepack ops: convert a float weight into the fused quantized format.
conversion_ops = op_bench.op_list(
    attrs=(
        ('qembeddingbag_byte_prepack', torch.ops.quantized.embedding_bag_byte_prepack),
        ('qembeddingbag_4bit_prepack', torch.ops.quantized.embedding_bag_4bit_prepack),
        ('qembeddingbag_2bit_prepack', torch.ops.quantized.embedding_bag_2bit_prepack),
    ),
    attr_names=('op_name', 'op_func'),
)

# Unpack ops: recover a float weight from the fused quantized format.
unpack_ops = op_bench.op_list(
    attrs=(
        ('qembeddingbag_byte_unpack', torch.ops.quantized.embedding_bag_byte_unpack),
        ('qembeddingbag_4bit_unpack', torch.ops.quantized.embedding_bag_4bit_unpack),
        ('qembeddingbag_2bit_unpack', torch.ops.quantized.embedding_bag_2bit_unpack),
    ),
    attr_names=('op_name', 'op_func'),
)
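
# For reference, a byte-format round trip outside the harness looks roughly
# like this (a sketch; the extra 8 columns are the per-row fp32 scale and
# zero_point that the byte prepack appends):
#
#   w = torch.rand(80, 128) + 1
#   packed = torch.ops.quantized.embedding_bag_byte_prepack(w)      # (80, 136) uint8
#   w_back = torch.ops.quantized.embedding_bag_byte_unpack(packed)  # (80, 128) float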


class EmbeddingBagFloatToFusedBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, op_func):
        # Shift the weights into [1, 2) so they are strictly positive.
        self.inputs = {
            "weight": torch.rand(num_embeddings, embedding_dim, dtype=torch.float) + 1
        }
        self.op_func = op_func

    def forward(self, weight):
        return self.op_func(weight)


class EmbeddingBagThreeDimFloatToFusedBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        self.inputs = {
            "weight": torch.rand(batch_size, num_embeddings, embedding_dim, dtype=torch.float) + 1
        }
        self.op_func = op_func

    def forward(self, weight):
        return self.op_func(weight)


class EmbeddingBagFusedToFloatBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, op_func):
        # embedding_dim + 8 mirrors the per-row quantization parameters that
        # prepack appends; the uint8 cast just produces a synthetic packed
        # buffer to time unpack against.
        weight = torch.randn(num_embeddings, embedding_dim + 8, dtype=torch.float)
        self.inputs = {
            "packed_weight": weight.to(torch.uint8)
        }
        self.op_func = op_func

    def forward(self, packed_weight):
        return self.op_func(packed_weight)


class EmbeddingBagThreeDimFusedToFloatBase(op_bench.TorchBenchmarkBase):
    def init(self, num_embeddings, embedding_dim, batch_size, op_func):
        weight = torch.randn(batch_size, num_embeddings, embedding_dim + 8, dtype=torch.float)
        self.inputs = {
            "packed_weight": weight.to(torch.uint8)
        }
        self.op_func = op_func

    def forward(self, packed_weight):
        return self.op_func(packed_weight)
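
# The harness calls each benchmark's forward with self.inputs unpacked as
# keyword arguments, so the parameter names in forward must match the dict
# keys above ("weight" / "packed_weight").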

# Register one benchmark per (op, config) pair.
op_bench.generate_pt_tests_from_op_list(conversion_ops,
                                        embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
                                        EmbeddingBagFloatToFusedBase)
op_bench.generate_pt_tests_from_op_list(unpack_ops,
                                        embeddingbag_conversion_short_configs + embeddingbag_conversion_long_configs,
                                        EmbeddingBagFusedToFloatBase)
op_bench.generate_pt_tests_from_op_list(conversion_ops,
                                        embeddingbag_conversion_three_dim_configs,
                                        EmbeddingBagThreeDimFloatToFusedBase)
op_bench.generate_pt_tests_from_op_list(unpack_ops,
                                        embeddingbag_conversion_three_dim_configs,
                                        EmbeddingBagThreeDimFusedToFloatBase)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
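
# Example invocation, assuming this file is saved as pt/qembedding_pack_test.py
# in the operator_benchmark tree (module path and flag spelling are
# assumptions; check your checkout):
#
#   python -m pt.qembedding_pack_test --tag_filter short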