# Microbenchmarks for quantized 2D pooling operators (max, average and adaptive average).
import torch
import operator_benchmark as op_bench
# 2D pooling takes an input tensor of rank 3 (C, H, W) or rank 4 (N, C, H, W).
# Exhaustive ("long") sweep shared by the quantized max/avg pooling benchmarks.
qpool2d_long_configs = op_bench.config_list(
    attrs=(
        # Columns: C, H, W, k, s, p
        # Small synthetic shapes.
        (1, 3, 3, (3, 3), (1, 1), (0, 0)),
        (3, 64, 64, (3, 3), (2, 2), (1, 1)),
        # Pooling layers of VGG16 for an original input of shape (-1, 3, 224, 224).
        (64, 224, 224, (2, 2), (2, 2), (0, 0)),  # MaxPool2d-4
        (256, 56, 56, (2, 2), (2, 2), (0, 0)),  # MaxPool2d-16
    ),
    attr_names=(
        'C', 'H', 'W',  # input layout
        'k', 's', 'p',  # pooling parameters
    ),
    # Each attrs row is combined with every value listed below.
    cross_product_configs=dict(
        N=(1, 4),
        contig=(False, True),
        dtype=(torch.quint8,),
    ),
    tags=('long',),
)
# Quick ("short") sweep: one tiny synthetic shape, run for each quantized dtype.
qpool2d_short_configs = op_bench.config_list(
    attrs=(
        # Columns: C, H, W, k, s, p
        (1, 3, 3, (3, 3), (1, 1), (0, 0)),
    ),
    attr_names=(
        'C', 'H', 'W',  # input layout
        'k', 's', 'p',  # pooling parameters
    ),
    cross_product_configs=dict(
        N=(2,),
        contig=(True,),
        dtype=(torch.qint32, torch.qint8, torch.quint8),
    ),
    tags=('short',),
)
# Exhaustive ("long") sweep for adaptive average pooling: the full cross product
# of every keyword below.
qadaptive_avgpool2d_long_configs = op_bench.cross_product_configs(
    input_size=(
        # Spatial size at VGG16's MaxPool2d-9 (original input: (-1, 3, 224, 224)).
        (112, 112),
    ),
    output_size=(
        # Larger than the only input size above.
        (448, 448),
        # Spatial sizes at the VGG16 pooling layers (original input: (-1, 3, 224, 224)).
        (224, 224),  # MaxPool2d-4
        (112, 112),  # MaxPool2d-9
        (56, 56),  # MaxPool2d-16
        (14, 14),  # MaxPool2d-30
    ),
    N=(1, 4),
    C=(1, 3, 64, 128),
    contig=(False, True),
    dtype=(torch.quint8,),
    tags=('long',),
)
# Quick ("short") sweep for adaptive average pooling: a single 224x224 -> 112x112
# case, run for each quantized dtype.
qadaptive_avgpool2d_short_configs = op_bench.config_list(
    attrs=(
        # Columns: N, C, input_size, output_size, contig
        (4, 3, (224, 224), (112, 112), True),
    ),
    attr_names=('N', 'C', 'input_size', 'output_size', 'contig'),
    cross_product_configs=dict(
        dtype=(torch.qint32, torch.qint8, torch.quint8),
    ),
    tags=('short',),
)
class _QPool2dBenchmarkBase(op_bench.TorchBenchmarkBase):
    """Shared input construction and forward pass for quantized 2D pooling benchmarks.

    ``forward`` calls ``self.pool_op``, which this class never assigns; it has
    to be set on the instance before ``forward`` runs.
    """

    def setup(self, N, C, H, W, dtype, contig):
        """Build the quantized input tensor and publish it through ``self.inputs``.

        Args:
            N: Batch size; ``0`` selects an unbatched rank-3 ``(C, H, W)`` input.
            C: Number of channels.
            H: Input height.
            W: Input width.
            dtype: Quantized dtype handed to ``torch.quantize_per_tensor``.
            contig: When false, the data is stored channels-last in memory and
                then viewed back in channels-first order.
        """
        is_batched = N != 0
        shape = (N, C, H, W) if is_batched else (C, H, W)

        # Uniform random floats in [-128, 128).
        float_input = (torch.rand(*shape) - 0.5) * 256

        # Quantize with a fixed scale of 1.0 and a zero point of 0.
        q_input = torch.quantize_per_tensor(
            float_input, scale=1.0, zero_point=0, dtype=dtype
        )

        if not contig:
            # Move channels last, materialize that layout, then permute back so
            # the logical shape is unchanged but the strides are channels-last.
            if is_batched:
                to_channels_last, to_channels_first = (0, 2, 3, 1), (0, 3, 1, 2)
            else:
                to_channels_last, to_channels_first = (1, 2, 0), (2, 0, 1)
            q_input = q_input.permute(*to_channels_last).contiguous()
            q_input = q_input.permute(*to_channels_first)

        self.q_input = q_input
        self.inputs = {"q_input": self.q_input}

    def forward(self, q_input):
        """Apply the configured pooling module to ``q_input`` and return the result."""
        return self.pool_op(q_input)
class QMaxPool2dBenchmark(_QPool2dBenchmarkBase):
    """Benchmark of ``torch.nn.MaxPool2d`` on a quantized input."""

    def init(self, N, C, H, W, k, s, p, contig, dtype):
        """Create the max-pooling module, then build the quantized input.

        ``k``, ``s`` and ``p`` are forwarded as kernel size, stride and padding;
        the remaining arguments go to ``setup``.
        """
        self.pool_op = torch.nn.MaxPool2d(
            kernel_size=k,
            stride=s,
            padding=p,
            dilation=(1, 1),
            ceil_mode=False,
            return_indices=False,
        )
        super().setup(N, C, H, W, dtype=dtype, contig=contig)
class QAvgPool2dBenchmark(_QPool2dBenchmarkBase):
    """Benchmark of ``torch.nn.AvgPool2d`` on a quantized input."""

    def init(self, N, C, H, W, k, s, p, contig, dtype):
        """Create the average-pooling module, then build the quantized input.

        ``k``, ``s`` and ``p`` are forwarded as kernel size, stride and padding;
        the remaining arguments go to ``setup``.
        """
        self.pool_op = torch.nn.AvgPool2d(
            kernel_size=k,
            stride=s,
            padding=p,
            ceil_mode=False,
        )
        super().setup(N, C, H, W, dtype=dtype, contig=contig)
class QAdaptiveAvgPool2dBenchmark(_QPool2dBenchmarkBase):
    """Benchmark of ``torch.nn.AdaptiveAvgPool2d`` on a quantized input."""

    def init(self, N, C, input_size, output_size, contig, dtype):
        """Create the adaptive average-pooling module, then build the quantized input.

        ``input_size`` is unpacked into the positional height/width arguments
        of ``setup``; ``output_size`` is passed straight to the module.
        """
        self.pool_op = torch.nn.AdaptiveAvgPool2d(output_size=output_size)
        super().setup(N, C, *input_size, dtype=dtype, contig=contig)
# Register each benchmark class against its short + long configs, in this order:
# adaptive average pooling, average pooling, max pooling.
for _bench_configs, _bench_cls in (
    (
        qadaptive_avgpool2d_short_configs + qadaptive_avgpool2d_long_configs,
        QAdaptiveAvgPool2dBenchmark,
    ),
    (qpool2d_short_configs + qpool2d_long_configs, QAvgPool2dBenchmark),
    (qpool2d_short_configs + qpool2d_long_configs, QMaxPool2dBenchmark),
):
    op_bench.generate_pt_test(_bench_configs, _bench_cls)


if __name__ == "__main__":
    # Hand control to the benchmark harness when run as a script.
    op_bench.benchmark_runner.main()
|