File: SparseTransformer.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (206 lines) | stat: -rw-r--r-- 6,888 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package SparseTransformer
# Module caffe2.experiments.python.SparseTransformer




from caffe2.python import workspace
import scipy.sparse


class NetDefNode:
    """
    One node of the op graph built from a net definition.

    name: key under which this node is stored in its neighbours' dicts
    optype: operator type string of the node
    ops: dict mapping consumer node name -> consumer node (outgoing edges)
    prev: dict mapping input node name -> input node (incoming edges)
    visited: traversal flag, False on construction
    op: the operator object this node stands for (may be None)
    """

    def __init__(self, name, optype, p=None, op=None):
        self.name, self.optype = name, optype
        self.op = op
        self.visited = False
        # Both edge dicts must exist before any input is linked.
        self.ops = {}
        self.prev = {}
        self.insertInput(p)

    def insertInput(self, p):
        """
        Register p as input of this node and this node as output of p.
        p: a node or a list of node; anything else is silently ignored
        """
        if isinstance(p, list):
            parents = p
        elif isinstance(p, NetDefNode):
            parents = [p]
        else:
            parents = []
        for parent in parents:
            self.prev[parent.name] = parent
            parent.ops[self.name] = self

    def deleteInput(self, p):
        """
        Remove the edge p -> self from both endpoints.
        Does nothing when p is not a NetDefNode; raises KeyError when the
        edge is not present.
        """
        if not isinstance(p, NetDefNode):
            return
        del self.prev[p.name]
        del p.ops[self.name]


def maskNallocate(weight_name):
    """
    Store the CSR form of a weight blob in the workspace.

    Fetches the blob called weight_name, converts it with
    scipy.sparse.csr_matrix and feeds the CSR data, indptr and indices
    arrays back as blobs named weight_name + "wcsr" / "iw" / "jw".
    Presumably the fetched weights are already masked (pruned entries
    zero) -- no mask is applied here.

    Returns the three new blob names as a tuple (wcsr, iw, jw).
    """
    dense = workspace.FetchBlob(weight_name)
    csr = scipy.sparse.csr_matrix(dense)
    pieces = (
        ("wcsr", csr.data),
        ("iw", csr.indptr),
        ("jw", csr.indices),
    )
    blob_names = []
    for suffix, payload in pieces:
        blob_name = weight_name + suffix
        workspace.FeedBlob(blob_name, payload)
        blob_names.append(blob_name)
    return tuple(blob_names)


def transFCRelu(cur, id2node, name2id, ops, model):
    """
    Add trans before and after this FC_Prune->(Relu)->FC_Prune chain.

    Starting at cur, walk forward through consecutive FC_Prune / Relu
    nodes. For every node visited a replacement op is added through model
    (FC_Sparse fed by the blobs from maskNallocate, or an in-place Relu)
    and a matching NetDefNode is appended to a new chain. The new chain
    starts with a Transpose node attached to the original producer of cur
    and ends with a Transpose node that is inserted as an input of the
    first node after the chain. The original chain nodes are unlinked from
    their inputs as the walk proceeds.

    cur: head of the chain; assumed to be an FC_Prune with one input node
    id2node, name2id, ops: not used here, kept for the call signature
    model: object providing Transpose, FC_Sparse, Relu and net.Proto()

    Returns None. Raises StopIteration when cur has no input node or when
    the last node of the chain has no consumer.
    """
    # 1. add trans before the start of this chain
    # assuming that cur is a FC_Prune, and it has only one input
    # (next(iter(...)) instead of the Python-2-only itervalues().next())
    pre = next(iter(cur.prev.values()))
    # Create a node /op and insert it.
    # TODO(wyiming): check whether it is correct here
    current_blob = model.Transpose(cur.op.input[0], cur.op.input[0] + "_trans")
    # Presumably the op appended by the call above -- TODO confirm.
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(trans_op.output[0], "Transpose", pre, trans_op)
    trans_node.visited = True
    pre_new = trans_node

    # 2. use while loop to visit the chain
    while True:
        # breakup with the parent
        cur.deleteInput(pre)
        if cur.optype not in ("FC_Prune", "Relu"):
            print("Reaching the end of the chain")
            break
        if len(cur.ops) > 1:
            print("A FC/Relu giving more than 1 useful outputs")
        if cur.optype == "FC_Prune":
            op = cur.op
            # input[1] is used as the weight blob, input[3] as the bias.
            wcsr, iw, jw = maskNallocate(op.input[1])
            bias_name = op.input[3]
            # TODO(wyiming): create a new Op here
            current_blob = model.FC_Sparse(current_blob,
                                           op.output[0] + "_Sparse",
                                           wcsr, iw, jw, bias_name)
            sps_op = model.net.Proto().op[-1]
            sps_node = NetDefNode(op.output[0] + "_Sparse",
                                  "FC_Sparse",
                                  pre_new, sps_op)
            sps_node.visited = True
            pre_new = sps_node
        elif cur.optype == "Relu":
            # Relu is applied in place on the running blob.
            current_blob = model.Relu(current_blob, current_blob)
            rel_op = model.net.Proto().op[-1]
            rel_node = NetDefNode(str(current_blob), "Relu",
                                  pre_new, rel_op)
            rel_node.visited = True
            pre_new = rel_node

        cur.visited = True
        pre = cur
        # Pick the FC_Prune/Relu consumer to continue with; when several
        # qualify the last one in dict order wins.
        successor = None
        for child in cur.ops.values():
            if child.optype == "Relu" or child.optype == "FC_Prune":
                successor = child
        if successor is None:
            # assume that there is only 1 output that is not PrintOP
            cur = next(iter(cur.ops.values()))
            cur.deleteInput(pre)
            print("No FC/RElu children")
            print(cur.op.type)
            break
        cur = successor
    # 3. add trans after this chain like 1.
    # The transposed result is written under the chain's original output
    # name (pre is the last original chain node at this point).
    current_blob = model.Transpose(current_blob, pre.op.output[0])
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(str(current_blob), "Transpose", pre_new, trans_op)
    trans_node.visited = True
    cur.insertInput(trans_node)
    print(cur.prev)
    print(trans_node.ops)


def Prune2Sparse(cur, id2node, name2id, ops, model):
    """
    Depth-first walk from cur that rewrites every unvisited FC_Prune node.

    An unvisited FC_Prune node is handed to transFCRelu; afterwards cur is
    marked visited and every node in cur.ops is processed recursively.

    transFCRelu adds and removes entries of the ops dict of the rewritten
    node's producer, which may be the very dict being walked here. The
    consumers are therefore taken from a snapshot, and the dict is
    rescanned until no node shows up that has not been processed yet, so
    nodes attached during a rewrite are followed as well.

    cur: node to start from
    id2node, name2id, ops, model: passed through to transFCRelu unchanged
    Returns None.
    """
    # Assume that FC and Relu takes in only 1 input;
    # If not raise warning
    if not cur.visited and cur.optype == "FC_Prune":
        transFCRelu(cur, id2node, name2id, ops, model)

    cur.visited = True
    handled = set()
    while True:
        # list(...) snapshot: never iterate the live dict while recursing.
        pending = [n for n in list(cur.ops.values()) if n not in handled]
        if not pending:
            break
        for n in pending:
            handled.add(n)
            Prune2Sparse(n, id2node, name2id, ops, model)


def net2list(net_root):
    """
    List the ops of a net in breadth-first order starting below net_root.

    net_root itself contributes no entry; only nodes reachable through
    the ops dicts do. No visited set is kept, so a node reachable along
    several paths appears once per path.

    net_root: node whose ops dict holds the first level of nodes
    Returns a list with the op attribute of every dequeued node.
    """
    from collections import deque

    # deque gives O(1) pops from the front; slicing a list copied it each
    # time. values() replaces the Python-2-only iteritems().
    bfs_queue = deque(net_root.ops.values())
    op_list = []
    while bfs_queue:
        node = bfs_queue.popleft()
        op_list.append(node.op)
        bfs_queue.extend(node.ops.values())

    return op_list


def netbuilder(model):
    """
    Build a NetDefNode graph from the ops of model.net.Proto().

    Ops of type "Print" are skipped. Every other op becomes a node that
    is linked to the nodes which produced its inputs; inputs whose name
    was not produced by an earlier op are ignored, and an op with no such
    producer is attached to the artificial root node instead.

    Returns (root node, dict blob name -> id of the op that last output
    it, dict op id -> node).
    """
    print("Welcome to model checker")
    net_proto = model.net.Proto()
    producer_of = {}
    node_of = {}
    root = NetDefNode("net_root", "root", None)

    for idx, operator in enumerate(net_proto.op):
        if operator.type == "Print":
            continue

        if operator.name:
            label = '%s/%s (op#%d)' % (operator.name, operator.type, idx)
        else:
            label = '%s (op#%d)' % (operator.type, idx)
        node = NetDefNode(label, operator.type, op=operator)
        node_of[idx] = node

        # assume that un_occured name are non_layers
        # TODO: write a non-layer checker and log it
        layer_inputs = [blob for blob in operator.input
                        if blob in producer_of]
        for blob in layer_inputs:
            node.insertInput(node_of[producer_of[blob]])
        if not layer_inputs:
            node.insertInput(root)

        # Register outputs only after the inputs were resolved, so an
        # in-place op links to the earlier producer, not to itself.
        for blob in operator.output:
            producer_of[blob] = idx

    return root, producer_of, node_of