1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
#!/usr/bin/env python
"""Net summarization tool.
This tool summarizes the structure of a net in a concise but comprehensive
tabular listing, taking a prototxt file as input.
Use this tool to check at a glance that the computation you've specified is the
computation you expect.
"""
from caffe.proto import caffe_pb2
from google import protobuf
import re
import argparse
# ANSI codes for coloring blobs (used cyclically)
COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100',
'444', '103;30', '107;30']
DISCONNECTED_COLOR = '41'
def read_net(filename):
net = caffe_pb2.NetParameter()
with open(filename) as f:
protobuf.text_format.Parse(f.read(), net)
return net
def format_param(param):
out = []
if len(param.name) > 0:
out.append(param.name)
if param.lr_mult != 1:
out.append('x{}'.format(param.lr_mult))
if param.decay_mult != 1:
out.append('Dx{}'.format(param.decay_mult))
return ' '.join(out)
def printed_len(s):
return len(re.sub(r'\033\[[\d;]+m', '', s))
def print_table(table, max_width):
"""Print a simple nicely-aligned table.
table must be a list of (equal-length) lists. Columns are space-separated,
and as narrow as possible, but no wider than max_width. Text may overflow
columns; note that unlike string.format, this will not affect subsequent
columns, if possible."""
max_widths = [max_width] * len(table[0])
column_widths = [max(printed_len(row[j]) + 1 for row in table)
for j in range(len(table[0]))]
column_widths = [min(w, max_w) for w, max_w in zip(column_widths, max_widths)]
for row in table:
row_str = ''
right_col = 0
for cell, width in zip(row, column_widths):
right_col += width
row_str += cell + ' '
row_str += ' ' * max(right_col - printed_len(row_str), 0)
print row_str
def summarize_net(net):
disconnected_tops = set()
for lr in net.layer:
disconnected_tops |= set(lr.top)
disconnected_tops -= set(lr.bottom)
table = []
colors = {}
for lr in net.layer:
tops = []
for ind, top in enumerate(lr.top):
color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
if top in disconnected_tops:
top = '\033[1;4m' + top
if len(lr.loss_weight) > 0:
top = '{} * {}'.format(lr.loss_weight[ind], top)
tops.append('\033[{}m{}\033[0m'.format(color, top))
top_str = ', '.join(tops)
bottoms = []
for bottom in lr.bottom:
color = colors.get(bottom, DISCONNECTED_COLOR)
bottoms.append('\033[{}m{}\033[0m'.format(color, bottom))
bottom_str = ', '.join(bottoms)
if lr.type == 'Python':
type_str = lr.python_param.module + '.' + lr.python_param.layer
else:
type_str = lr.type
# Summarize conv/pool parameters.
# TODO support rectangular/ND parameters
conv_param = lr.convolution_param
if (lr.type in ['Convolution', 'Deconvolution']
and len(conv_param.kernel_size) == 1):
arg_str = str(conv_param.kernel_size[0])
if len(conv_param.stride) > 0 and conv_param.stride[0] != 1:
arg_str += '/' + str(conv_param.stride[0])
if len(conv_param.pad) > 0 and conv_param.pad[0] != 0:
arg_str += '+' + str(conv_param.pad[0])
arg_str += ' ' + str(conv_param.num_output)
if conv_param.group != 1:
arg_str += '/' + str(conv_param.group)
elif lr.type == 'Pooling':
arg_str = str(lr.pooling_param.kernel_size)
if lr.pooling_param.stride != 1:
arg_str += '/' + str(lr.pooling_param.stride)
if lr.pooling_param.pad != 0:
arg_str += '+' + str(lr.pooling_param.pad)
else:
arg_str = ''
if len(lr.param) > 0:
param_strs = map(format_param, lr.param)
if max(map(len, param_strs)) > 0:
param_str = '({})'.format(', '.join(param_strs))
else:
param_str = ''
else:
param_str = ''
table.append([lr.name, type_str, param_str, bottom_str, '->', top_str,
arg_str])
return table
def main():
parser = argparse.ArgumentParser(description="Print a concise summary of net computation.")
parser.add_argument('filename', help='net prototxt file to summarize')
parser.add_argument('-w', '--max-width', help='maximum field width',
type=int, default=30)
args = parser.parse_args()
net = read_net(args.filename)
table = summarize_net(net)
print_table(table, max_width=args.max_width)
if __name__ == '__main__':
main()
|