## @package app
# Module caffe2.python.mint.app
import argparse
import flask
import glob
import numpy as np
import nvd3
import os
import sys
# pyre-fixme[21]: Could not find module `tornado.httpserver`.
import tornado.httpserver
# pyre-fixme[21]: Could not find a module corresponding to import `tornado.wsgi`
import tornado.wsgi
# Absolute path of the directory holding this module; the template and static
# directories below are resolved relative to it.
__folder__ = os.path.abspath(os.path.dirname(__file__))

# The Flask application whose routes are registered further down in this file.
app = flask.Flask(
    __name__,
    static_folder=os.path.join(__folder__, "static"),
    template_folder=os.path.join(__folder__, "templates"),
)

# Parsed command-line options. Stays None until main() stores the result of
# parse_args() here; the request handlers read it as a module global.
args = None
def jsonify_nvd3(chart):
    """Build an nvd3 chart and return its markup and script as a JSON response.

    Note (after Yangqing's original remark): python-nvd3 does not appear to
    keep the generated HTML and the generated JavaScript apart. The HTML part
    seems to be just a <div>, available as ``chart.container``; the script
    part is recovered here by slicing ``chart.htmlcontent`` between the first
    '<script>' and the first '</script>'.

    Args:
        chart: an nvd3 chart object providing ``buildcontent()``,
            ``htmlcontent`` and ``container``.

    Returns:
        The ``flask.jsonify`` response with two keys: ``result`` (the chart
        container markup) and ``script`` (the stripped script text).
    """
    chart.buildcontent()
    html = chart.htmlcontent
    open_tag = '<script>'
    # Start right after the opening tag so only the script body is kept.
    body_begin = html.find(open_tag) + len(open_tag)
    body_end = html.find('</script>')
    script_body = html[body_begin:body_end].strip()
    return flask.jsonify(result=chart.container, script=script_body)
def visualize_summary(filename):
    """Render a summary file as an nvd3 line chart.

    The file is read with ``np.loadtxt`` and must provide at least four
    columns, which are plotted under the series names used below: column 0 as
    'min', column 1 as 'max', column 2 as 'mean', and column 2 plus/minus
    column 3 as 'm+std' / 'm-std'. Rows are subsampled according to
    ``args.sample`` before plotting.

    Args:
        filename: path of the summary file to load.

    Returns:
        The JSON response produced by ``jsonify_nvd3`` on success, or a plain
        error string when the file cannot be loaded.
    """
    try:
        # ndmin=2 keeps a one-row file as shape (1, n) instead of letting
        # loadtxt squeeze it to (n,), which would break the column indexing.
        data = np.loadtxt(filename, ndmin=2)
    except Exception as e:
        return 'Cannot load file {}: {}'.format(filename, str(e))
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_summary_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    num_rows = data.shape[0]
    if args.sample < 0:
        # A negative value is the total number of points wanted. Floor
        # division keeps the step an integer: with true division the step is
        # a float, np.arange returns floats, and indexing `data` with them
        # raises IndexError.
        step = max(num_rows // -args.sample, 1)
    else:
        # A step of 0 is not a valid arange step; fall back to every row.
        step = max(args.sample, 1)
    xdata = np.arange(0, num_rows, step)
    # Columns are read as: 0 = min, 1 = max, 2 = mean, 3 = std.
    mean = data[xdata, 2]
    std = data[xdata, 3]
    chart.add_serie(x=xdata, y=data[xdata, 0], name='min')
    chart.add_serie(x=xdata, y=data[xdata, 1], name='max')
    chart.add_serie(x=xdata, y=mean, name='mean')
    chart.add_serie(x=xdata, y=mean + std, name='m+std')
    chart.add_serie(x=xdata, y=mean - std, name='m-std')
    return jsonify_nvd3(chart)
def visualize_print_log(filename):
    """Render a log file as an nvd3 line chart.

    The file is read with ``np.loadtxt``; a one-dimensional result is treated
    as a single column. With exactly one column, the curve is plotted together
    with the running minimum and maximum of each sampling window. With several
    columns, up to ``args.max_curves`` of them are plotted as separate curves.
    Rows are subsampled according to ``args.sample``.

    Args:
        filename: path of the log file to load.

    Returns:
        The JSON response produced by ``jsonify_nvd3`` on success, or a plain
        error string when the file cannot be loaded.
    """
    try:
        data = np.loadtxt(filename)
        if data.ndim == 1:
            data = data[:, np.newaxis]
    except Exception as e:
        return 'Cannot load file {}: {}'.format(filename, str(e))
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_log_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    num_rows = data.shape[0]
    if args.sample < 0:
        # A negative value is the total number of points wanted. Floor
        # division keeps the step an integer: with true division the step is
        # a float, and the slicing, reshape and indexing below all fail.
        step = max(num_rows // -args.sample, 1)
    else:
        # A step of 0 is not a valid arange step; fall back to every row.
        step = max(args.sample, 1)
    xdata = np.arange(0, num_rows, step)
    # if there is only one curve, we also show the running min and max
    if data.shape[1] == 1:
        # Number of complete windows of `step` rows; a trailing partial
        # window is left out of the running statistics.
        trunc_size = num_rows // step
        # One row per window, so min/max along axis 1 are per-window values.
        running_mat = data[:trunc_size * step].reshape((trunc_size, step))
        chart.add_serie(
            x=xdata[:trunc_size],
            y=running_mat.min(axis=1),
            name='running_min'
        )
        chart.add_serie(
            x=xdata[:trunc_size],
            y=running_mat.max(axis=1),
            name='running_max'
        )
        chart.add_serie(x=xdata, y=data[xdata, 0], name=chart_name)
    else:
        for i in range(0, min(data.shape[1], args.max_curves)):
            # One curve per column, capped at args.max_curves columns.
            chart.add_serie(
                x=xdata,
                y=data[xdata, i],
                name='{}[{}]'.format(chart_name, i)
            )
    return jsonify_nvd3(chart)
def visualize_file(filename):
    """Pick the renderer for ``filename`` from its suffix and run it.

    Names ending in 'summary' go to ``visualize_summary``; names ending in
    'log' go to ``visualize_print_log``. Any other name yields a JSON response
    carrying an explanatory message and an empty script.

    Args:
        filename: file name relative to ``args.root``.

    Returns:
        Whatever the selected renderer returns, or the fallback JSON response.
    """
    fullname = os.path.join(args.root, filename)
    # Checked in order; the first matching suffix wins.
    renderers = (
        ('summary', visualize_summary),
        ('log', visualize_print_log),
    )
    for suffix, render in renderers:
        if filename.endswith(suffix):
            return render(fullname)
    return flask.jsonify(
        result='Unsupport file: {}'.format(filename),
        script=''
    )
@app.route('/')
def index():
    """Serve the landing page listing the files found under ``args.root``.

    Only entries matching the glob pattern ``*.*`` directly inside the root
    are listed, sorted by path. The same list of base names is handed to the
    template as both ``names`` and ``debug_messages``.
    """
    pattern = os.path.join(args.root, "*.*")
    names = [os.path.basename(path) for path in sorted(glob.glob(pattern))]
    return flask.render_template(
        'index.html',
        root=args.root,
        names=names,
        debug_messages=names
    )
@app.route('/visualization/<string:name>')
def visualization(name):
    """Route handler: return the visualization of the file called ``name``.

    Delegates entirely to ``visualize_file`` and passes its result through.
    """
    return visualize_file(name)
def main(argv):
    """Parse the command line and run the visualizer under Tornado.

    The parsed options are stored in the module-level ``args`` so that the
    request handlers can read them. The Flask app is wrapped in a Tornado
    WSGI container, bound to ``args.port``, and the IOLoop is started; the
    call to ``start()`` is the last statement of this function.

    Args:
        argv: list of command-line arguments, without the program name.
    """
    global args
    # Imported here explicitly: the file-level imports only name
    # tornado.httpserver and tornado.wsgi, so tornado.ioloop would otherwise
    # be reachable only if one of those happens to import it.
    import tornado.ioloop

    # The text must be passed as `description`; the first positional
    # parameter of ArgumentParser is `prog`, which would put this sentence
    # into the usage line as the program name.
    parser = argparse.ArgumentParser(description="The mint visualizer.")
    parser.add_argument(
        '-p',
        '--port',
        type=int,
        default=5000,
        help="The flask port to use."
    )
    parser.add_argument(
        '-r',
        '--root',
        type=str,
        default='.',
        help="The root folder to read files for visualization."
    )
    parser.add_argument(
        '--max_curves',
        type=int,
        default=5,
        help="The max number of curves to show in a dump tensor."
    )
    parser.add_argument(
        '--chart_height',
        type=int,
        default=300,
        help="The chart height for nvd3."
    )
    parser.add_argument(
        '-s',
        '--sample',
        type=int,
        default=-200,
        help="Sample every given number of data points. A negative "
        "number means the total points we will sample on the "
        "whole curve. Default 200 points."
    )
    args = parser.parse_args(argv)
    server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app))
    server.listen(args.port)
    print("Tornado server starting on port {}.".format(args.port))
    tornado.ioloop.IOLoop.instance().start()
# Script entry point: pass everything after the program name to main().
if __name__ == "__main__":
    main(sys.argv[1:])