File: summarize.py

package info (click to toggle)
caffe 1.0.0%2Bgit20180821.99bd997-8
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 16,288 kB
  • sloc: cpp: 61,586; python: 5,783; makefile: 599; sh: 559
file content (140 lines) | stat: -rwxr-xr-x 4,880 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python

"""Net summarization tool.

This tool summarizes the structure of a net in a concise but comprehensive
tabular listing, taking a prototxt file as input.

Use this tool to check at a glance that the computation you've specified is the
computation you expect.
"""

from caffe.proto import caffe_pb2
from google import protobuf
import re
import argparse

# ANSI codes for coloring blobs (used cyclically)
# Each entry is the parameter part of an escape sequence: summarize_net wraps
# it as '\033[<code>m' in front of a blob name and resets with '\033[0m'.
# NOTE(review): '444' does not look like a standard SGR code ('44' would be the
# background color fitting this sequence) -- confirm whether it is a typo.
COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100',
          '444', '103;30', '107;30']
# Code applied to a bottom blob whose name has not been assigned a color, i.e.
# no top of that name has been seen when the bottom is formatted.
DISCONNECTED_COLOR = '41'

def read_net(filename):
    """Load a net definition from a prototxt (protobuf text format) file.

    filename: path of the prototxt file to read.

    Returns the caffe_pb2.NetParameter message filled in from the file.
    """
    # Import the submodule explicitly: `from google import protobuf` only
    # exposes `protobuf.text_format` if some other module happened to import
    # it already, which is not guaranteed.
    from google.protobuf import text_format

    net = caffe_pb2.NetParameter()
    with open(filename) as f:
        text_format.Parse(f.read(), net)
    return net

def format_param(param):
    """Describe a layer parameter spec as a short space-separated string.

    The result holds the parameter name (when non-empty), 'x<lr_mult>' when
    the learning-rate multiplier differs from 1 and 'Dx<decay_mult>' when the
    decay multiplier differs from 1. It is empty when none of these apply.
    """
    pieces = []
    if len(param.name) != 0:
        pieces.append(param.name)
    # Multipliers are only worth showing when they deviate from 1.
    for template, multiplier in (('x{}', param.lr_mult),
                                 ('Dx{}', param.decay_mult)):
        if multiplier != 1:
            pieces.append(template.format(multiplier))
    return ' '.join(pieces)

def printed_len(s):
    """Return the length of s, not counting ANSI color escape sequences."""
    # Splitting on the escape sequences leaves only the visible fragments.
    visible_parts = re.split(r'\033\[[\d;]+m', s)
    return sum(len(part) for part in visible_parts)

def print_table(table, max_width):
    """Print a simple nicely-aligned table.

    table must be a list of (equal-length) lists of strings. Columns are
    space-separated, and as narrow as possible, but no wider than max_width.
    Text may overflow columns; note that unlike string.format, this will not
    affect subsequent columns, if possible.

    Cell widths are measured with printed_len, so ANSI color escapes do not
    count towards the alignment. An empty table prints nothing.
    """
    # Nothing to lay out (e.g. a net without layers); table[0] would fail.
    if not table:
        return

    # Natural width of each column (widest cell plus one separating space),
    # capped at max_width.
    column_widths = []
    for j in range(len(table[0])):
        widest = max(printed_len(row[j]) + 1 for row in table)
        column_widths.append(min(widest, max_width))

    for row in table:
        row_str = ''
        right_col = 0
        for cell, width in zip(row, column_widths):
            # right_col is where this column is supposed to end; pad up to it
            # unless an overflowing cell has already pushed past that point.
            right_col += width
            row_str += cell + ' '
            row_str += ' ' * max(right_col - printed_len(row_str), 0)
        # Parenthesized single argument: valid on both Python 2 and Python 3.
        print(row_str)

def summarize_net(net):
    """Build the rows of the summary table for a net.

    net: a message whose `layer` entries provide name, type, param, bottom,
        top and loss_weight, plus convolution_param, pooling_param and
        python_param for the layer types that use them.

    Returns a list with one row per layer. Each row is the list of strings
    [name, type, params, bottoms, '->', tops, args].

    Blob names are wrapped in ANSI color escapes. Each distinct top name gets
    a code from COLORS, assigned cyclically on first appearance; a bottom
    reuses the code already assigned to a top of the same name and falls back
    to DISCONNECTED_COLOR otherwise. Tops that remain unconsumed (see below)
    are additionally prefixed with the '1;4' attributes.
    """
    # Walk the layers in order, adding each layer's tops and then removing its
    # bottoms; what remains are tops that were produced but not consumed.
    disconnected_tops = set()
    for lr in net.layer:
        disconnected_tops |= set(lr.top)
        disconnected_tops -= set(lr.bottom)

    table = []
    colors = {}
    for lr in net.layer:
        tops = []
        for ind, top in enumerate(lr.top):
            color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
            if top in disconnected_tops:
                top = '\033[1;4m' + top
            if len(lr.loss_weight) > 0:
                top = '{} * {}'.format(lr.loss_weight[ind], top)
            tops.append('\033[{}m{}\033[0m'.format(color, top))
        top_str = ', '.join(tops)

        bottoms = []
        for bottom in lr.bottom:
            color = colors.get(bottom, DISCONNECTED_COLOR)
            bottoms.append('\033[{}m{}\033[0m'.format(color, bottom))
        bottom_str = ', '.join(bottoms)

        # Python layers are identified by the class implementing them.
        if lr.type == 'Python':
            type_str = lr.python_param.module + '.' + lr.python_param.layer
        else:
            type_str = lr.type

        # Summarize conv/pool parameters as
        # 'kernel[/stride][+pad] num_output[/group]' (conv) or
        # 'kernel[/stride][+pad]' (pool); default values are left out.
        # TODO support rectangular/ND parameters
        conv_param = lr.convolution_param
        if (lr.type in ['Convolution', 'Deconvolution']
                and len(conv_param.kernel_size) == 1):
            arg_str = str(conv_param.kernel_size[0])
            if len(conv_param.stride) > 0 and conv_param.stride[0] != 1:
                arg_str += '/' + str(conv_param.stride[0])
            if len(conv_param.pad) > 0 and conv_param.pad[0] != 0:
                arg_str += '+' + str(conv_param.pad[0])
            arg_str += ' ' + str(conv_param.num_output)
            if conv_param.group != 1:
                arg_str += '/' + str(conv_param.group)
        elif lr.type == 'Pooling':
            arg_str = str(lr.pooling_param.kernel_size)
            if lr.pooling_param.stride != 1:
                arg_str += '/' + str(lr.pooling_param.stride)
            if lr.pooling_param.pad != 0:
                arg_str += '+' + str(lr.pooling_param.pad)
        else:
            arg_str = ''

        # Build a real list: on Python 3 map() returns a one-shot iterator, so
        # checking it for content would leave nothing to join afterwards.
        param_strs = [format_param(p) for p in lr.param]
        if any(param_strs):
            param_str = '({})'.format(', '.join(param_strs))
        else:
            # No param specs, or none with anything worth showing.
            param_str = ''

        table.append([lr.name, type_str, param_str, bottom_str, '->', top_str,
                      arg_str])
    return table

def main():
    """Parse the command line, then read, summarize and print the net."""
    arg_parser = argparse.ArgumentParser(
        description="Print a concise summary of net computation.")
    arg_parser.add_argument('filename', help='net prototxt file to summarize')
    arg_parser.add_argument('-w', '--max-width', type=int, default=30,
                            help='maximum field width')
    options = arg_parser.parse_args()

    # Pipeline: prototxt file -> net message -> table rows -> aligned output.
    rows = summarize_net(read_net(options.filename))
    print_table(rows, max_width=options.max_width)

# Run the command-line tool only when this file is executed as a script.
if __name__ == '__main__':
    main()