'''Caffe2 net optimizer: rewrites Mul/Add pairs into SpatialBN ops, folds
Conv+SpatialBN pairs into the Conv, optionally applies the IDEEP Conv+Relu
fusion, and verifies the optimized net against the original on random inputs.
'''
import argparse
import copy
import json
import logging

import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, utils, workspace
import caffe2.python._import_c_extension as C

log = logging.getLogger(__name__)
def pairwise(iterable):
from itertools import tee
a, b = tee(iterable)
next(b, None)
return zip(a, b)
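
# Illustration: pairwise yields overlapping neighbours, e.g.
#   list(pairwise([1, 2, 3])) == [(1, 2), (2, 3)]
# which is how the fusion passes below walk adjacent (op[i], op[i + 1]) pairs.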
def last_producer(ops, blob):
for (i, op) in reversed(list(enumerate(ops))):
if blob in op.output:
return i
raise ValueError("Failed to find last producer of blob, %s", blob)
def blob_uses(net, blob):
u = []
for i, op in enumerate(net.op):
if blob in op.input or blob in op.control_input:
u.append(i)
return u
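
# Example: for a net whose ops are [Conv(X, W) -> Y, Relu(Y) -> Z],
# blob_uses(net, "Y") == [1] -- Y has exactly one consumer, so the op
# producing it is a safe fusion candidate.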
def GetArgumentParser():
parser = argparse.ArgumentParser(description="Caffe2 optimization")
parser.add_argument("--init_net",
type=argparse.FileType('rb'),
help="init net")
parser.add_argument("--pred_net",
type=argparse.FileType('rb'),
help="predict net")
parser.add_argument("--verify_input",
type=argparse.FileType('r'),
help="input dims for verification")
parser.add_argument("--fuse_bn", default=False, action='store_true')
parser.add_argument("--fuse_mul_add", default=False, action='store_true')
parser.add_argument("--fuse_conv_relu", default=False, action='store_true')
return parser
def fuse_first_bn(net, params, removed_tensors):
net = copy.deepcopy(net)
params = copy.deepcopy(params)
for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
if next_.input[0] != current.output[0]:
continue
if current.type not in ("Conv", "ConvTranspose") \
or next_.type != "SpatialBN":
continue
if len(blob_uses(net, current.output[0])) != 1:
# Can't fuse if more than one user
continue
# else, can fuse
conv = current
bn = next_
fused_conv = copy.deepcopy(conv)
fused_conv.output[0] = bn.output[0]
# Fix fused_conv to ensure we have a bias passed.
if len(fused_conv.input) != 3:
bias_name = "{}_bias".format(conv.input[1])
net.external_input.extend([bias_name])
fused_conv.input.extend([bias_name])
for arg in fused_conv.arg:
if arg.name == "no_bias":
arg.i = 0
conv_weight = params[conv.input[1]]
        conv_bias = params[conv.input[2]] if len(conv.input) == 3 \
            else np.zeros(conv_weight.shape[0], dtype=np.float32)
bn_scale = params[bn.input[1]]
bn_bias = params[bn.input[2]]
bn_running_mean = params[bn.input[3]]
bn_running_var = params[bn.input[4]]
# First, BN computation can be phrased as follows:
# (X - running_mean) * (1.0 / sqrt(running_var + eps)) *
# bn_scale + bias
# Thus, we can rewrite bn_scale as:
# X * bn_scale * 1.0 / (sqrt(running_var + eps)) + (bias -
# running_mean * (1.0 / sqrt(running_var + eps)) * bn_scale)
# Thus, can just have the affine transform
# X * A + B
# where
# A = bn_scale * 1.0 / (sqrt(running_var + eps))
# B = (bias - running_mean * (1.0 / sqrt(running_var + eps))
# * bn_scale)
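        # Quick numeric check of the algebra (illustrative values): with
        # bn_scale = 2, running_var = 3, eps = 1, running_mean = 4, bias = 5,
        # A = 2 / sqrt(3 + 1) = 1 and B = 5 - 4 * 1 = 1, so
        # BN(x) = (x - 4) / 2 * 2 + 5 = x + 1 = x * A + B, as expected.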
eps = 1.0e-5
for arg in bn.arg:
if arg.name == "epsilon":
eps = arg.f
A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
B = bn_bias - bn_running_mean * A
        # This identity should hold if we have fused correctly:
# np.testing.assert_array_equal(
# params[conv.output[0]] * A + B,
# params[bn.output[0]])
# Now, we have that the computation made is the following:
# ((X `conv` W) + b) * A + B
# Then, we can simply fuse this as follows:
# (X `conv` (W * A)) + b * A + B
# which is simply
# (X `conv` Q) + C
# where
# Q = W * A
# C = b * A + B
        # For ConvTranspose, from the view of convolutions as a
        # Toeplitz multiplication, we have W_ = W^T, so the weights
        # are laid out as (R, S, K, K) (vs. (S, R, K, K) for a Conv)
        # and the scale broadcasts slightly differently. Remember, our
        # BN scale 'A' is of size (S,).
A_ = A.reshape(-1, 1, 1, 1) if conv.type == "Conv" else \
A.reshape(1, -1, 1, 1)
C = conv_bias * A + B
Q = conv_weight * A_
params[fused_conv.input[1]] = Q
params[fused_conv.input[2]] = C
new_ops = net.op[:i] + [fused_conv] + net.op[j + 1:]
del net.op[:]
removed_tensors.append(bn.input[1])
removed_tensors.append(bn.input[2])
removed_tensors.append(bn.input[3])
removed_tensors.append(bn.input[4])
del params[bn.input[1]]
del params[bn.input[2]]
del params[bn.input[3]]
del params[bn.input[4]]
net.op.extend(new_ops)
break
return net, params, removed_tensors
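
# fuse_first_bn rewrites at most one Conv->SpatialBN pair per call;
# fuse_bn below simply drives it to a fixed point. A minimal usage sketch
# (assuming `net` and `params` were loaded as in Optimize below):
#
#   net, params, removed = fuse_bn(net, params, ignore_failure=True)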
def fuse_bn(net, params, ignore_failure):
# Run until we hit a fixed point
removed_tensors = []
while True:
(next_net, next_params, removed_tensors) = \
fuse_first_bn(net, params, removed_tensors)
if len(next_net.op) == len(net.op):
if (
any(op.type == "SpatialBN" for op in next_net.op) and
not ignore_failure
):
                raise Exception(
                    "Model contains a SpatialBN op after fusion:\n{}".format(
                        next_net))
return (next_net, next_params, removed_tensors)
net, params, removed_tensors = (next_net, next_params, removed_tensors)
def fuse_first_mul_add(net, params, removed_tensors):
net = copy.deepcopy(net)
params = copy.deepcopy(params)
for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
if current.type != "Mul" or next_.type != "Add":
continue
        if next_.input[0] != current.output[0]:
            raise Exception(
                "Failure to fuse: the Add does not consume the Mul's output")
        if len(blob_uses(net, current.output[0])) != 1:
            raise Exception(
                "Failure to fuse: the Mul's output has multiple consumers")
        log.info("Fusing Mul/Add at index %s", i)
mul_ = current
add_ = next_
batch_norm = copy.deepcopy(mul_)
batch_norm.type = "SpatialBN"
batch_norm.arg.extend([utils.MakeArgument("is_test", 1)])
batch_norm.arg.extend([utils.MakeArgument("epsilon", float(1e-9))])
def s(x):
return "{}{}".format(add_.output[0], x)
fake_mean = s("_mean")
fake_var = s("_var")
del batch_norm.input[:]
batch_norm.input.extend([mul_.input[0],
mul_.input[1],
add_.input[1],
fake_mean,
fake_var])
params[fake_mean] = np.zeros_like(params[mul_.input[1]])
params[fake_var] = np.ones_like(params[mul_.input[1]])
net.external_input.extend([fake_mean, fake_var])
batch_norm.output[0] = add_.output[0]
new_ops = net.op[:i] + [batch_norm] + net.op[j + 1:]
del net.op[:]
net.op.extend(new_ops)
break
return net, params, removed_tensors
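
# Note the trick above: a Mul/Add pair (e.g. a batch norm that a converter
# decomposed into an affine transform) is re-expressed as an is_test
# SpatialBN with zero mean and unit variance, so that a later fuse_bn pass
# can fold it into the preceding Conv.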
def fuse_mul_add(net, params):
# Run until we hit a fixed point
removed_tensors = []
while True:
(next_net, next_params, removed_tensors) = \
fuse_first_mul_add(net, params, removed_tensors)
if len(next_net.op) == len(net.op):
return (next_net, next_params, removed_tensors)
net, params, removed_tensors = (next_net, next_params, removed_tensors)
def add_tensor(net, name, blob):
    '''Append a constant-fill operator to `net` that recreates the tensor
    `blob` under the name `name` when the net is run.
    uint8 tensors are serialized as a single string to save storage.
    '''
kTypeNameMapper = {
np.dtype('float32'): "GivenTensorFill",
np.dtype('int32'): "GivenTensorIntFill",
np.dtype('int64'): "GivenTensorInt64Fill",
np.dtype('uint8'): "GivenTensorStringFill",
}
shape = blob.shape
values = blob
    # Pass a uint8 array as one raw string: storing uint8 element-by-element
    # currently has a large overhead.
    if blob.dtype == np.dtype('uint8'):
        shape = [1]
        values = [blob.tobytes()]
op = core.CreateOperator(
kTypeNameMapper[blob.dtype],
[], [name],
arg=[
utils.MakeArgument("shape", shape),
utils.MakeArgument("values", values),
]
)
net.op.extend([op])
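
# A minimal sketch of what add_tensor emits (names are illustrative):
#
#   net = caffe2_pb2.NetDef()
#   add_tensor(net, "W", np.ones((2, 2), dtype=np.float32))
#   # net.op[0] is now a GivenTensorFill producing a 2x2 tensor of ones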
def gen_init_net_from_blobs(blobs):
''' Generate an initialization net based on a blob dict '''
ret = caffe2_pb2.NetDef()
for name, blob in blobs.items():
add_tensor(ret, name, blob)
return ret
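
# Typical round trip (see Optimize below): run the original init net, fetch
# every blob from the workspace into a dict, fuse, then rebuild the init net
# from the updated dict with gen_init_net_from_blobs.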
def fuse_conv_relu(net):
net = copy.deepcopy(net)
device_option = core.DeviceOption(caffe2_pb2.IDEEP)
for op in net.op:
op.device_option.CopyFrom(device_option)
new_net = caffe2_pb2.NetDef()
new_net.ParseFromString(C.transform_optimizeForMKLDNN(net.SerializeToString()))
return new_net
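
# transform_optimizeForMKLDNN rewrites the net for the IDEEP/MKL-DNN engine,
# fusing Conv followed by Relu among other rewrites; every op is pinned to
# the IDEEP device first so the transform can match them.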
def Optimize(args):
init_net = caffe2_pb2.NetDef()
predict_net = caffe2_pb2.NetDef()
init_net.ParseFromString(args.init_net.read())
predict_net.ParseFromString(args.pred_net.read())
workspace.ResetWorkspace()
workspace.RunNetOnce(init_net)
param_dict = {p: workspace.FetchBlob(p) for p in workspace.Blobs()}
external_inputs = {}
external_outputs = {}
if args.verify_input:
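        # value_info maps blob name -> [type, shape]; keep just the shape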
value_info = json.load(args.verify_input)
        input_shapes = {k: v[-1] for (k, v) in value_info.items()}
print("input info: {}".format(input_shapes))
for k, v in input_shapes.items():
external_inputs[k] = np.random.randn(*v).astype(np.float32)
workspace.FeedBlob(k, external_inputs[k])
workspace.RunNetOnce(predict_net)
for o in predict_net.external_output:
external_outputs[o] = workspace.FetchBlob(o)
if args.fuse_mul_add:
predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
if args.fuse_bn:
predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, False)
if args.fuse_conv_relu:
predict_net = fuse_conv_relu(predict_net)
external_outputs_opt = {}
if args.verify_input:
workspace.ResetWorkspace()
        device_option = (core.DeviceOption(caffe2_pb2.IDEEP)
                         if args.fuse_conv_relu
                         else core.DeviceOption(caffe2_pb2.CPU))
with core.DeviceScope(device_option):
for k, v in param_dict.items():
workspace.FeedBlob(k, v, device_option)
for k, v in external_inputs.items():
workspace.FeedBlob(k, v, device_option)
workspace.RunNetOnce(predict_net)
for o in predict_net.external_output:
external_outputs_opt[o] = workspace.FetchBlob(o)
assert np.allclose(external_outputs[o],
external_outputs_opt[o],
atol=1e-3,
rtol=1e-3)
for i, o in enumerate(predict_net.op):
print("op[{}]: {}".format(i, o.type))
init_net = gen_init_net_from_blobs(param_dict)
with open('init_net.pb', 'wb') as f:
f.write(init_net.SerializeToString())
with open('predict_net.pb', 'wb') as f:
f.write(predict_net.SerializeToString())
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = GetArgumentParser().parse_args()
    Optimize(args)