import sys
import time
import torch
import inspect
import itertools
from functorch import pointwise_operator
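
# Benchmark in a controlled setting: a single CPU thread for stable timings,
# and fusion-group inlining off so the scripted ops keep running their fused
# kernels rather than the unfused fallback path.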
torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)
def rand(*shape):
    # mul(16).add(1) maps uniform [0, 1) samples to [1, 17), which keeps
    # log() and div() away from zero.
    return torch.rand(*shape).mul(16).add(1)
# ------------------------------------------------------------------------------
# Shape test cases
# ------------------------------------------------------------------------------
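# Each case returns a tuple of fresh input tensors; together they cover
# contiguous, sliced, transposed, channels-last, and broadcast layouts.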
def scalar():
    return (rand(1), rand(1))

def small():
    return (rand(32), rand(32))

def small_2d():
    return (rand(1, 32), rand(1, 32))

def small_broadcast():
    return (rand(4, 32), rand(32))

def medium():
    return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))

def medium_sliced():
    return (rand(32, 12, 64, 64)[..., ::2],
            rand(32, 12, 64, 64)[..., ::2])

def medium_transpose():
    return (rand(32, 12, 64, 64).transpose(-1, -2),
            rand(32, 12, 64, 64).transpose(-1, -2))

def medium2():
    return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))

def medium3d():
    return (rand(16, 32, 64), rand(16, 32, 64))

def medium_channels_last():
    return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
            rand(32, 3, 224, 224).to(memory_format=torch.channels_last))

def medium_broadcast():
    return (rand(32, 12, 64, 64), rand(64))

def medium_broadcast_channels_last():
    return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last),
            rand(3, 1, 1))

def large():
    return (rand(8192, 8192), rand(8192, 8192))

def large_transpose():
    return (rand(8192, 8192).transpose(0, 1),
            rand(8192, 8192).transpose(0, 1))

def large_channels_last():
    return (rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
            rand(32, 32, 256, 256).to(memory_format=torch.channels_last))

def pathological_broadcast():
    return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))
# ------------------------------------------------------------------------------
# Operator test cases
# ------------------------------------------------------------------------------
def add(a, b):
    return a + b

def sub(a, b):
    return a - b

def mul(a, b):
    return a * b

def div(a, b):
    return a / b

def relu(a):
    return a.relu()

def sigmoid(a):
    return a.sigmoid()

def tanh(a):
    return a.tanh()

def log(a):
    return a.log()

def exp(a):
    return a.exp()

def square(a):
    return a ** 2

def fma(a, b):
    # Multiply-add pattern; note that b is reused as the addend.
    return a * b + b

def hardswish(a):
    return a * (a + 3.0).clamp(0.0, 6.0) / 6.0

def native_hardswish(a):
    return torch._C._nn.hardswish(a)

def softplus(a):
    # softplus(x) = log(1 + exp(beta * x)) / beta, written out with beta == 1.
    return (a * 1.0).exp().log1p() / 1.0

def mish(a):
    # mish(x) = x * tanh(softplus(x))
    return a * ((a * 1.0).exp().log1p() / 1.0).tanh()
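
# Note: softplus and mish are defined above but are not included in the
# default `operators` list below.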
# ------------------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------------------
def time_cpu(fn, args, iters):
    s = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    e = time.perf_counter()
    return e - s
def time_cuda(fn, args, iters):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() is in milliseconds; convert to seconds to match time_cpu.
    return start.elapsed_time(end) / 1e3
def benchmark_with_timer(fn, args, timer):
    timer(fn, args, 3)  # warmup; also lets the JIT profile and fuse
    calibration = timer(fn, args, 1)
    # Scale the iteration count so the timed loop runs for roughly a second;
    # clamp to at least one iteration so slow cases don't divide by zero.
    iters = max(int(1.0 / calibration), 1)
    return timer(fn, args, iters) / iters
def benchmark(fn, args):
    timer = time_cpu if args[0].device.type == "cpu" else time_cuda
    return benchmark_with_timer(fn, args, timer)
def micros(s):
    return f"{s * 1e6:.1f}"
shapes = [
    scalar,
    small,
    small_2d,
    small_broadcast,
    medium,
    medium2,
    medium3d,
    medium_sliced,
    medium_transpose,
    medium_channels_last,
    medium_broadcast,
    medium_broadcast_channels_last,
    large,
    large_transpose,
    large_channels_last,
    pathological_broadcast,
]
operators = [
    add,
    sub,
    mul,
    div,
    relu,
    sigmoid,
    tanh,
    log,
    exp,
    square,
    fma,
    hardswish,
    native_hardswish,
]
# Pre-screen: record the (operator, shape) pairs that pointwise_operator
# cannot handle, so the benchmark loop can skip them and report nan instead.
nope = set()
for shape, operator in itertools.product(shapes, operators):
    nargs = len(inspect.signature(operator).parameters)
    args = shape()[:nargs]

    try:
        if shape == medium_transpose:
            raise RuntimeError("pointwise_operator hangs on medium_transpose")
        pw_op = pointwise_operator(operator)
        torch.testing.assert_allclose(operator(*args), pw_op(*args))
    except Exception:
        print(f"pointwise_operator failed on {operator.__name__}, {shape.__name__}")
        nope.add((operator, shape))

    ts_op = torch.jit.script(operator)
    torch.testing.assert_allclose(operator(*args), ts_op(*args))
print("fuser,device,operator,shape,time")
results = []
for shape, operator in itertools.product(shapes, operators):
nargs = len(inspect.signature(operator).parameters)
args = shape()[:nargs]
result = benchmark(operator, args)
print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
try:
if shape == medium_transpose:
raise RuntimeError("pointwise_operator hangs on medium_transpose")
if (operator, shape) in nope:
raise RuntimeError("pointwise_operator fails on medium_transpose")
pw_op = pointwise_operator(operator)
result = benchmark(pw_op, args)
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
except Exception:
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))]))
ts_op = torch.jit.script(operator)
result = benchmark(ts_op, args)
print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
sys.stdout.flush()
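
# The script emits CSV on stdout (fuser,device,operator,shape,time, with time
# in microseconds). A hypothetical way to tabulate it, assuming pandas is
# available and the pre-screen failure lines above the header are dropped:
#
#   python pointwise_benchmark.py > raw.txt   # script name is a placeholder
#   >>> import io, pandas as pd
#   >>> lines = [l for l in open("raw.txt")
#   ...          if not l.startswith("pointwise_operator failed")]
#   >>> df = pd.read_csv(io.StringIO("".join(lines)))
#   >>> df.pivot_table(index=["operator", "shape"], columns="fuser", values="time")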