1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
|
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
## @package SparseTransformer
# Module caffe2.experiments.python.SparseTransformer
from caffe2.python import workspace
import scipy.sparse
class NetDefNode(object):
    """
    A node of the operator graph built from a net definition.

    Attributes (all assigned in __init__):
        name:    key under which this node is stored in its neighbours' dicts
        optype:  operator type string, e.g. "FC_Prune", "Relu" or "root"
        ops:     dict name -> NetDefNode of the nodes that consume this node
        prev:    dict name -> NetDefNode of the nodes that feed this node
        visited: traversal flag, False on construction
        op:      the operator object wrapped by this node (None if not given)
    """

    def __init__(self, name, optype, p=None, op=None):
        """
        name: node name (used as dictionary key by neighbouring nodes)
        optype: operator type string
        p: optional input node, or list/tuple of input nodes
        op: optional operator object to attach to this node
        """
        self.name = name
        self.optype = optype
        self.ops = {}
        self.prev = {}
        self.insertInput(p)
        self.visited = False
        self.op = op

    def __repr__(self):
        # Nodes are printed inside dicts while debugging the graph rewrite;
        # show something more useful than the default object address.
        return "NetDefNode(name=%r, optype=%r)" % (self.name, self.optype)

    def insertInput(self, p):
        """
        Insert input of this op
        also maintain the output of previous op
        p: a node, or a list/tuple of nodes; any other value (e.g. None)
           is ignored
        """
        if isinstance(p, NetDefNode):
            p = (p,)
        elif not isinstance(p, (list, tuple)):
            return
        for node in p:
            # Keep both directions of the edge in sync.
            self.prev[node.name] = node
            node.ops[self.name] = self

    def deleteInput(self, p):
        """
        Remove the edge p -> self in both directions.
        p: a node; values that are not NetDefNode instances are ignored
        Raises KeyError if p is a node that is not currently an input
        of this node.
        """
        if isinstance(p, NetDefNode):
            del self.prev[p.name]
            del p.ops[self.name]
def maskNallocate(weight_name):
    """
    Publish a CSR representation of a weight blob to the workspace.

    The blob called `weight_name` is fetched, converted with
    scipy.sparse.csr_matrix, and its three CSR arrays are fed back as new
    blobs named by appending a suffix to `weight_name`:
        "wcsr" -> the stored values      (csr .data)
        "iw"   -> the row pointer array  (csr .indptr)
        "jw"   -> the column indices     (csr .indices)

    NOTE(review): no mask blob is read here; this presumably relies on the
    fetched weights already having pruned entries set to zero -- confirm
    against the FC_Prune operator.

    Returns the three new blob names as a tuple (wcsr, iw, jw).
    """
    dense_weight = workspace.FetchBlob(weight_name)
    sparse_weight = scipy.sparse.csr_matrix(dense_weight)
    csr_parts = (
        ("wcsr", sparse_weight.data),
        ("iw", sparse_weight.indptr),
        ("jw", sparse_weight.indices),
    )
    fed_names = []
    for suffix, values in csr_parts:
        blob_name = weight_name + suffix
        workspace.FeedBlob(blob_name, values)
        fed_names.append(blob_name)
    return tuple(fed_names)
def transFCRelu(cur, id2node, name2id, ops, model):
    """
    Add trans before and after this FC_Prune->(Relu)->FC_Prune chain.

    Starting at `cur`, the chain of FC_Prune / Relu nodes is walked. For
    every chain node a replacement operator is appended to `model`
    (FC_Sparse for FC_Prune, Relu for Relu) and a matching NetDefNode is
    linked behind the previously created one. A Transpose is added in front
    of the new chain and another one behind it; the first consumer found
    after the chain is re-attached to the trailing Transpose. The old chain
    nodes are unlinked from each other and marked visited.

    cur: first node of the chain; assumed to be an FC_Prune node with
         exactly one input node
    id2node, name2id, ops: not used by this function (kept so the
         signature matches Prune2Sparse)
    model: object providing Transpose / FC_Sparse / Relu builders and
         model.net.Proto().op, the list the new operators are read from

    Raises StopIteration if `cur` has no input node, or if the last chain
    node has no consumer at all.
    """
    chain_types = ("FC_Prune", "Relu")

    # 1. add trans before the start of this chain
    # assuming that cur is a FC_Prune, and it has only one input
    pre = next(iter(cur.prev.values()))
    # Create a node /op and insert it.
    # TODO(wyiming): check whether it is correct here
    current_blob = model.Transpose(cur.op.input[0],
                                   cur.op.input[0] + "_trans")
    # The builder call above appended the new operator to the net.
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(trans_op.output[0], "Transpose", pre, trans_op)
    trans_node.visited = True
    pre_new = trans_node

    # 2. use while loop to visit the chain
    while True:
        # breakup with the parent
        cur.deleteInput(pre)
        if cur.optype not in chain_types:
            print("Reaching the end of the chain")
            break
        if len(cur.ops) > 1:
            print("A FC/Relu giving more than 1 useful outputs")
        if cur.optype == "FC_Prune":
            op = cur.op
            # input[1] is used as the weight blob, input[3] as the bias.
            wcsr, iw, jw = maskNallocate(op.input[1])
            bias_name = op.input[3]
            # TODO(wyiming): create a new Op here
            current_blob = model.FC_Sparse(current_blob,
                                           cur.op.output[0] + "_Sparse",
                                           wcsr, iw, jw, bias_name)
            sps_op = model.net.Proto().op[-1]
            sps_node = NetDefNode(cur.op.output[0] + "_Sparse",
                                  "FC_Sparse",
                                  pre_new, sps_op)
            sps_node.visited = True
            pre_new = sps_node
        elif cur.optype == "Relu":
            # Input and output blob are the same one here.
            current_blob = model.Relu(current_blob, current_blob)
            rel_op = model.net.Proto().op[-1]
            rel_node = NetDefNode(str(current_blob), "Relu",
                                  pre_new, rel_op)
            rel_node.visited = True
            pre_new = rel_node

        cur.visited = True
        pre = cur
        flag = False
        # Snapshot the children first: `cur` is rebound inside the loop.
        # If several children belong to the chain, the last one wins.
        for temp in list(cur.ops.values()):
            if temp.optype in chain_types:
                flag = True
                cur = temp
        if not flag:
            # assume that there is only 1 output that is not PrintOP
            cur = next(iter(cur.ops.values()))
            cur.deleteInput(pre)
            print("No FC/RElu children")
            print(cur.op.type)
            break

    # 3. add trans after this chain like 1.
    # The output keeps the name of the old chain's last output blob.
    current_blob = model.Transpose(current_blob, pre.op.output[0])
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(str(current_blob), "Transpose", pre_new, trans_op)
    trans_node.visited = True
    cur.insertInput(trans_node)
    print(cur.prev)
    print(trans_node.ops)
def Prune2Sparse(cur, id2node, name2id, ops, model):
    """
    Walk the graph below `cur` depth-first and convert FC_Prune chains.

    Every node reached is marked visited. When an unvisited FC_Prune node
    is met, transFCRelu is called on it before its consumers are walked.

    cur: node to start from
    id2node, name2id, ops, model: passed through unchanged to transFCRelu
    """
    # Assume that FC and Relu takes in only 1 input;
    # If not raise warning
    if not cur.visited and cur.optype == "FC_Prune":
        transFCRelu(cur, id2node, name2id, ops, model)
    cur.visited = True

    # Converting a child rewires this node's `ops` dict (the old chain head
    # is removed and a Transpose node is inserted), so the dict must not be
    # iterated live. Walk a snapshot, then rescan for consumers that were
    # added meanwhile so the rewritten part of the graph is still followed.
    handled = {}  # id(node) -> node; keeping the node alive keeps ids unique
    while True:
        pending = [n for n in cur.ops.values() if id(n) not in handled]
        if not pending:
            break
        for n in pending:
            handled[id(n)] = n
            Prune2Sparse(n, id2node, name2id, ops, model)
def net2list(net_root):
    """
    Use topological order(BFS) to print the op of a net in a list

    net_root: root node; its own `op` is not included in the result
    Returns the list of `op` attributes of all nodes reachable from
    `net_root` through `ops`. Each node appears exactly once and only
    after every reachable node that feeds it. For tree-shaped graphs this
    is plain breadth-first order.
    """
    from collections import deque

    # Pass 1: count, for every reachable node, how many reachable nodes
    # feed it. Nodes are keyed by identity because names need not be
    # unique across the whole graph.
    indegree = {}
    seen = {id(net_root)}
    stack = [net_root]
    while stack:
        node = stack.pop()
        for child in node.ops.values():
            indegree[id(child)] = indegree.get(id(child), 0) + 1
            if id(child) not in seen:
                seen.add(id(child))
                stack.append(child)

    # Pass 2: Kahn's algorithm with a FIFO queue. A node is released only
    # once all of its producers have been emitted.
    op_list = []
    bfs_queue = deque([net_root])
    while bfs_queue:
        node = bfs_queue.popleft()
        if node is not net_root:
            op_list.append(node.op)
        for child in node.ops.values():
            indegree[id(child)] -= 1
            if indegree[id(child)] == 0:
                bfs_queue.append(child)
    return op_list
def netbuilder(model):
    """
    Build a NetDefNode graph from the operators of model.net.Proto().

    Operators of type "Print" are skipped. Every other operator becomes a
    node that is linked to the nodes which produced its input blobs. An
    operator none of whose inputs was produced by an earlier operator is
    attached to an artificial root node instead.

    Returns a tuple (root node,
                     dict blob name -> index of the last op that wrote it,
                     dict op index -> node).
    """
    print("Welcome to model checker")
    net_def = model.net.Proto()
    producer_of = {}  # blob name -> index of the op that produced it
    node_of = {}      # op index -> graph node
    root = NetDefNode("net_root", "root", None)

    for idx, op in enumerate(net_def.op):
        if op.type == "Print":
            continue

        if op.name:
            label = '%s/%s (op#%d)' % (op.name, op.type, idx)
        else:
            label = '%s (op#%d)' % (op.type, idx)
        node = NetDefNode(label, op.type, op=op)
        node_of[idx] = node

        # Inputs with no known producer are assumed to be non-layer blobs.
        # TODO: write a non-layer checker and log it
        producers = [node_of[producer_of[blob]]
                     for blob in op.input if blob in producer_of]
        if producers:
            for producer in producers:
                node.insertInput(producer)
        else:
            node.insertInput(root)

        # Register outputs only after the inputs were resolved, so an op
        # that overwrites its own input still links to the earlier producer.
        for blob in op.output:
            producer_of[blob] = idx

    return root, producer_of, node_of
|