1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
|
--[[ Logger: a simple class to log symbols during training,
and automate plot generation
Example:
logger = optim.Logger('somefile.log') -- file to save stuff
for i = 1,N do -- log some symbols during
train_error = ... -- training/testing
test_error = ...
logger:add{['training error'] = train_error,
['test error'] = test_error}
end
logger:style{['training error'] = '-', -- define styles for plots
['test error'] = '-'}
logger:plot() -- and plot
---- OR ---
logger = optim.Logger('somefile.log') -- file to save stuff
logger:setNames{'training error', 'test error'}
for i = 1,N do -- log some symbols during
train_error = ... -- training/testing
test_error = ...
logger:add{train_error, test_error}
end
logger:style{'-', '-'} -- define styles for plots
logger:plot() -- and plot
-----------
logger:setlogscale(true) -- enable logscale on Y-axis
logger:plot() -- and plot
]]
require 'xlua'
local Logger = torch.class('optim.Logger')
function Logger:__init(filename, timestamp)
if filename then
self.name = filename
os.execute('mkdir ' .. (sys.uname() ~= 'windows' and '-p ' or '') .. ' "' .. paths.dirname(filename) .. '"')
if timestamp then
-- append timestamp to create unique log file
filename = filename .. '-'..os.date("%Y_%m_%d_%X")
end
self.file = io.open(filename,'w')
self.epsfile = self.name .. '.eps'
else
self.file = io.stdout
self.name = 'stdout'
print('<Logger> warning: no path provided, logging to std out')
end
self.empty = true
self.symbols = {}
self.styles = {}
self.names = {}
self.idx = {}
self.figure = nil
self.showPlot = true
self.plotRawCmd = nil
self.defaultStyle = '+'
self.logscale = false
end
function Logger:setNames(names)
self.names = names
self.empty = false
self.nsymbols = #names
for k,key in pairs(names) do
self.file:write(key .. '\t')
self.symbols[k] = {}
self.styles[k] = {self.defaultStyle}
self.idx[key] = k
end
self.file:write('\n')
self.file:flush()
return self
end
function Logger:add(symbols)
-- (1) first time ? print symbols' names on first row
if self.empty then
self.empty = false
self.nsymbols = #symbols
for k,val in pairs(symbols) do
self.file:write(k .. '\t')
self.symbols[k] = {}
self.styles[k] = {self.defaultStyle}
self.names[k] = k
end
self.idx = self.names
self.file:write('\n')
end
-- (2) print all symbols on one row
for k,val in pairs(symbols) do
if type(val) == 'number' then
self.file:write(string.format('%11.4e',val) .. '\t')
elseif type(val) == 'string' then
self.file:write(val .. '\t')
else
xlua.error('can only log numbers and strings', 'Logger')
end
end
self.file:write('\n')
self.file:flush()
-- (3) save symbols in internal table
for k,val in pairs(symbols) do
table.insert(self.symbols[k], val)
end
end
function Logger:style(symbols)
for name,style in pairs(symbols) do
if type(style) == 'string' then
self.styles[name] = {style}
elseif type(style) == 'table' then
self.styles[name] = style
else
xlua.error('style should be a string or a table of strings','Logger')
end
end
return self
end
function Logger:setlogscale(state)
self.logscale = state
end
function Logger:display(state)
self.showPlot = state
end
function Logger:plot(...)
if not xlua.require('gnuplot') then
if not self.warned then
print('<Logger> warning: cannot plot with this version of Torch')
self.warned = true
end
return
end
local plotit = false
local plots = {}
local plotsymbol =
function(name,list)
if #list > 1 then
local nelts = #list
local plot_y = torch.Tensor(nelts)
for i = 1,nelts do
plot_y[i] = list[i]
end
for _,style in ipairs(self.styles[name]) do
table.insert(plots, {self.names[name], plot_y, style})
end
plotit = true
end
end
local args = {...}
if not args[1] then -- plot all symbols
for name,list in pairs(self.symbols) do
plotsymbol(name,list)
end
else -- plot given symbols
for _,name in ipairs(args) do
plotsymbol(self.idx[name], self.symbols[self.idx[name]])
end
end
if plotit then
if self.showPlot then
self.figure = gnuplot.figure(self.figure)
if self.logscale then gnuplot.logscale('on') end
gnuplot.plot(plots)
if self.plotRawCmd then gnuplot.raw(self.plotRawCmd) end
gnuplot.grid('on')
gnuplot.title('<Logger::' .. self.name .. '>')
end
if self.epsfile then
os.execute('rm -f "' .. self.epsfile .. '"')
local epsfig = gnuplot.epsfigure(self.epsfile)
if self.logscale then gnuplot.logscale('on') end
gnuplot.plot(plots)
if self.plotRawCmd then gnuplot.raw(self.plotRawCmd) end
gnuplot.grid('on')
gnuplot.title('<Logger::' .. self.name .. '>')
gnuplot.plotflush()
gnuplot.close(epsfig)
end
end
end
|