1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
|
--[[ A Confusion Matrix class
Example:
conf = optim.ConfusionMatrix( {'cat','dog','person'} ) -- new matrix
conf:zero() -- reset matrix
for i = 1,N do
conf:add( neuralnet:forward(sample), label ) -- accumulate one (prediction, label) count
end
print(conf) -- print matrix
image.display(conf:render()) -- render matrix
]]
-- Declare the torch class 'optim.ConfusionMatrix'; the methods defined below
-- (function ConfusionMatrix:...) are attached to this table.
local ConfusionMatrix = torch.class('optim.ConfusionMatrix')
--- Create an empty confusion matrix.
-- Can be called as ConfusionMatrix(nclasses[, classes]) or as
-- ConfusionMatrix(classes), in which case nclasses = #classes.
-- @param nclasses number of classes, or a table of class names
-- @param classes optional table of class names (used when printing/rendering)
function ConfusionMatrix:__init(nclasses, classes)
   if type(nclasses) == 'table' then
      classes = nclasses
      nclasses = #classes
   end
   -- counts: mat[target][prediction] (see ConfusionMatrix:_add)
   self.mat = torch.LongTensor(nclasses, nclasses):zero()
   -- per-class statistics filled in by updateValids()
   self.valids = torch.FloatTensor(nclasses):zero()
   self.unionvalids = torch.FloatTensor(nclasses):zero()
   self.nclasses = nclasses
   -- aggregate statistics filled in by updateValids()
   self.totalValid = 0
   self.averageValid = 0
   -- initialised here too so the field exists before updateValids() runs
   self.averageUnionValid = 0
   self.classes = classes or {}
   -- buffers reused across add()/batchAdd() calls to avoid reallocations
   self._mat_flat = self.mat:view(-1)
   self._target = torch.FloatTensor()
   self._prediction = torch.FloatTensor()
   self._max = torch.FloatTensor()
   self._pred_idx = torch.LongTensor()
   self._targ_idx = torch.LongTensor()
end
--- Record one sample given as scalar class indices.
-- Increments mat[t][p]. Targets that are not strictly positive (including NaN)
-- are treated as missing labels and ignored.
-- @param p predicted class index (number)
-- @param t target class index (number)
function ConfusionMatrix:_add(p, t)
   assert(p and type(p) == 'number')
   assert(t and type(t) == 'number')
   -- written as not (t > 0) so that NaN targets are skipped, like t > 0 would
   if not (t > 0) then
      return
   end
   local row = self.mat[t]
   row[p] = row[p] + 1
end
--- Accumulate a single sample.
-- @param prediction a class index (number), or a 1-D tensor of class scores
--   whose argmax is taken as the predicted class
-- @param target a class index (number), or a 1-D tensor whose argmax is taken
--   as the target class
function ConfusionMatrix:add(prediction, target)
   if type(prediction) == 'number' then
      -- both arguments are already class indices
      self:_add(prediction, target)
      return
   end

   local scores = self._prediction
   scores:resize(prediction:size()):copy(prediction)
   assert(prediction:dim() == 1)

   local predIdx = self._pred_idx
   if type(target) ~= 'number' then
      -- target is a vector as well: reduce it to its argmax first
      assert(target:dim() == 1)
      local targetScores = self._target
      targetScores:resize(target:size()):copy(target)
      local targIdx = self._targ_idx
      self._max:max(targIdx, targetScores, 1)
      self._max:max(predIdx, scores, 1)
      self:_add(predIdx[1], targIdx[1])
   else
      -- target is a plain class index
      self._max:max(predIdx, scores, 1)
      self:_add(predIdx[1], target)
   end
end
-- Copy `src` into buffer `buf` and reduce it to a 1-D vector of class indices:
-- a 1-D input is used as-is, an Nx1 input is flattened, and an NxK input is
-- reduced to its per-row argmax (written into `idxBuf`). `what` names the
-- argument in the error raised for any other rank.
local function _toClassIndices(src, buf, maxBuf, idxBuf, what)
   buf:resize(src:size()):copy(src)
   local ndim = src:dim()
   if ndim == 1 then
      -- already a vector of classes
      return buf
   elseif ndim == 2 then
      if src:size(2) == 1 then
         -- a column of classes that just needs flattening
         return buf:select(2, 1)
      end
      -- a matrix of per-class scores / one-hot rows
      maxBuf:max(idxBuf, buf, 2)
      return idxBuf:select(2, 1)
   end
   error(what .. " has invalid number of dimensions")
end

--- Accumulate a batch of samples.
-- @param predictions 1-D tensor of class indices, Nx1 tensor of class indices,
--   or NxK tensor of class scores (argmax per row is used)
-- @param targets same accepted forms as predictions (NxK is typically one-hot)
-- Samples whose target index is < 1 are treated as missing and ignored.
function ConfusionMatrix:batchAdd(predictions, targets)
   local preds = _toClassIndices(predictions, self._prediction, self._max,
                                 self._pred_idx, 'predictions')
   local targs = _toClassIndices(targets, self._target, self._max,
                                 self._targ_idx, 'targets')
   assert(preds:nElement() == targs:nElement(),
          'predictions and targets must contain the same number of samples')

   -- non-positive values are considered missing and therefore ignored
   local mask = targs:ge(1)
   targs = targs[mask]
   preds = preds[mask]
   if targs:nElement() == 0 then
      -- every target was missing: nothing to count (and size(1)/expand below
      -- would fail on an empty tensor)
      return
   end

   self._mat_flat = self._mat_flat or self.mat:view(-1) -- for backward compatibility
   preds = preds:typeAs(targs)
   -- the flat-index arithmetic below relies on a contiguous row-major matrix
   assert(self.mat:isContiguous() and self.mat:stride(2) == 1)
   local indices = ((targs - 1) * self.mat:stride(1) + preds):typeAs(self.mat)
   local ones = torch.ones(1):typeAs(self.mat):expand(indices:size(1))
   -- indexAdd accumulates repeated indices, so duplicate (target, prediction)
   -- pairs within the batch are all counted
   self._mat_flat:indexAdd(1, indices, ones)
end
--- Reset all counts and every derived accuracy statistic to zero.
function ConfusionMatrix:zero()
   self.mat:zero()
   self.valids:zero()
   self.unionvalids:zero()
   self.totalValid = 0
   self.averageValid = 0
   -- reset the row-union-column average too so it cannot report a stale value
   self.averageUnionValid = 0
end
--- Return true when `x` is NaN, the only value that is not equal to itself.
local function isNaN(x)
   return not (x == x)
end
--- Recompute accuracy statistics from the current counts (rows = targets).
-- Per class c:
--   valids[c]      = mat[c][c] / (row c total)           -- NaN when row c is empty
--   unionvalids[c] = mat[c][c] / (row c + col c - mat[c][c])
-- Aggregates:
--   totalValid        = trace / total count
--   averageValid      = mean of the non-NaN valids
--   averageUnionValid = mean of unionvalids over classes where both valids
--                       and unionvalids are non-NaN (the "VOC measure")
function ConfusionMatrix:updateValids()
   local mat = self.mat
   local n = self.nclasses

   local trace = 0
   for c = 1, n do
      local hits = mat[c][c]
      local rowTotal = mat:select(1, c):sum()
      local colTotal = mat:select(2, c):sum()
      self.valids[c] = hits / rowTotal
      self.unionvalids[c] = hits / (rowTotal + colTotal - hits)
      trace = trace + hits
   end
   self.totalValid = trace / mat:sum()

   -- averages are taken over the values as stored in the (float) tensors
   local validSum, validCount = 0, 0
   local unionSum, unionCount = 0, 0
   for c = 1, n do
      local v = self.valids[c]
      if not isNaN(v) then
         validSum = validSum + v
         validCount = validCount + 1
         local u = self.unionvalids[c]
         if not isNaN(u) then
            unionSum = unionSum + u
            unionCount = unionCount + 1
         end
      end
   end
   self.averageValid = validSum / validCount
   self.averageUnionValid = unionSum / unionCount
end
--- Compute false-rejection (FRR) and false-acceptance (FAR) rates per class.
-- For each class i whose row total (sample count) is non-zero:
--   FRR_i = 1 - mat[i][i] / (row i total)
--   FAR_i = mean, over the other classes j with a non-zero row total,
--           of mat[j][i] / (row j total)
-- FAR_i is NaN (0/0) when no other class has samples.
-- @return classFrrs, classFars: tensors of length nclasses holding the rates,
--   with -1 for classes that have no samples
-- @return frrs, fars: the rates of only the classes that have samples, in
--   class order (empty tensors when the matrix holds no samples at all)
-- The returned tensors are (views of) buffers cached on self and are
-- overwritten by the next call.
function ConfusionMatrix:farFrr()
   local cmat = self.mat
   local noOfClasses = cmat:size(1)
   self._frrs = self._frrs or torch.zeros(noOfClasses)
   self._frrs:zero()
   self._classFrrs = self._classFrrs or torch.zeros(noOfClasses)
   self._classFrrs:zero()
   self._classFrrs:add(-1)
   self._fars = self._fars or torch.zeros(noOfClasses)
   self._fars:zero()
   self._classFars = self._classFars or torch.zeros(noOfClasses)
   self._classFars:zero()
   self._classFars:add(-1)

   local classSamplesCount = cmat:sum(2)
   -- number of classes with at least one sample so far (compact index)
   local count = 0
   for i = 1, noOfClasses do
      local ni = classSamplesCount[i][1]
      if ni ~= 0 then
         count = count + 1
         self._frrs[count] = 1 - cmat[i][i] / ni
         self._classFrrs[i] = self._frrs[count]

         -- FAR: average rate at which other classes are predicted as i
         local farNumerator, farDenominator = 0, 0
         for j = 1, noOfClasses do
            local nj = classSamplesCount[j][1]
            if j ~= i and nj ~= 0 then
               farNumerator = farNumerator + cmat[j][i] / nj
               farDenominator = farDenominator + 1
            end
         end
         self._fars[count] = farNumerator / farDenominator
         self._classFars[i] = self._fars[count]
      end
   end

   if count == 0 then
      -- an empty matrix: slicing {1, 0} would raise an out-of-range error
      return self._classFrrs, self._classFars, self._frrs.new(), self._fars.new()
   end
   local returnFrrs = self._frrs[{{1, count}}]
   local returnFars = self._fars[{{1, count}}]
   return self._classFrrs, self._classFars, returnFrrs, returnFars
end
--- Base-10 logarithm of `x`.
-- Uses math.log10 when the running Lua provides it (checked on every call),
-- otherwise falls back to the change-of-base formula.
local function log10(x)
   local native = math.log10
   if not native then
      return math.log(x) / math.log(10)
   end
   return native(x)
end
--- Text rendering of the matrix, used by print(conf).
-- One line per target class: the counts for each predicted class, the
-- per-class accuracy (valids, in %) and, when class names were given, the
-- class name. Aggregate accuracies follow. Calls updateValids() first.
function ConfusionMatrix:__tostring__()
   self:updateValids()
   local str = {'ConfusionMatrix:\n'}
   local nclasses = self.nclasses
   table.insert(str, '[')
   -- Cell width = digits of the largest count + 1 separating space (minimum 8).
   -- A width of 1 + ceil(log10(max)) left no separator when the maximum was an
   -- exact power of ten (e.g. 10000000), so columns ran together.
   local maxCnt = self.mat:max()
   local nDigits = math.max(8, #string.format('%d', maxCnt) + 1)
   local cellFormat = '%' .. nDigits .. 'd'
   local named = self.classes and self.classes[1]
   for t = 1, nclasses do
      local pclass = string.format('%2.3f', self.valids[t] * 100)
      table.insert(str, t == 1 and '[' or ' [')
      for p = 1, nclasses do
         table.insert(str, string.format(cellFormat, self.mat[t][p]))
      end
      -- the last row also closes the outer bracket
      local close = (t == nclasses) and ']] ' or '] '
      if named then
         table.insert(str, close .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n')
      else
         table.insert(str, close .. pclass .. '% \n')
      end
   end
   table.insert(str, ' + average row correct: ' .. (self.averageValid*100) .. '% \n')
   table.insert(str, ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n')
   table.insert(str, ' + global correct: ' .. (self.totalValid*100) .. '%')
   return table.concat(str)
end
--- Render the confusion matrix as an image tensor.
-- Requires the `image`, `qtwidget` and `qttorch` packages (loaded on demand).
-- Each row is divided by its total (or by 1 if empty), so a cell shows the
-- fraction of a target class assigned to each prediction. Rows and columns
-- are reordered in descending order of the chosen sort key.
-- @param sortmode 'score' (default; sort by per-class accuracy) or
--   'occurrence' (sort by class frequency); anything else raises an error
-- @param display if true, also show the result via image.display with a legend
-- @param block size in pixels of one matrix cell (default 25)
-- @param legendwidth width in pixels of the legend column (default 200)
-- @return the rendered window contents as a tensor
-- NOTE(review): the matrix size is taken from #self.classes, so this assumes
-- a name was supplied for every class — TODO confirm for unnamed matrices.
function ConfusionMatrix:render(sortmode, display, block, legendwidth)
   -- args
   sortmode = sortmode or 'score' -- 'score' or 'occurrence'
   display = display or false
   block = block or 25
   legendwidth = legendwidth or 200
   local classes = self.classes
   local nclasses = #classes
   -- legends
   local legend = {
      ['score'] = 'Confusion matrix [sorted by scores, global accuracy = %0.3f%%, per-class accuracy = %0.3f%%]',
      ['occurrence'] = 'Confusion matrix [sorted by occurrences, accuracy = %0.3f%%, per-class accuracy = %0.3f%%]'
   }
   -- parse matrix / normalize / count scores
   local unconf = self.mat:double()   -- raw counts
   local confusion = unconf:clone()   -- row-normalised below
   local diag = torch.FloatTensor(nclasses)
   local freqs = torch.FloatTensor(nclasses)
   local corrects = 0
   local total = 0
   for target = 1, nclasses do
      freqs[target] = confusion[target]:sum()
      corrects = corrects + confusion[target][target]
      total = total + freqs[target]
      confusion[target]:div(math.max(confusion[target]:sum(), 1))
      diag[target] = confusion[target][target]
   end
   -- accuracies
   local accuracy = corrects / total * 100
   local perclass = 0
   local nonEmpty = 0
   for target = 1, nclasses do
      if confusion[target]:sum() > 0 then
         perclass = perclass + diag[target]
         nonEmpty = nonEmpty + 1
      end
   end
   perclass = perclass / nonEmpty * 100
   freqs:div(unconf:sum())
   -- sort matrix (descending); `_` and `order` used to be written as globals
   local sortKey
   if sortmode == 'score' then
      sortKey = diag
   elseif sortmode == 'occurrence' then
      sortKey = freqs
   else
      error('sort mode must be one of: score | occurrence')
   end
   local _, order = torch.sort(sortKey, 1, true)
   -- render matrix
   local render = torch.zeros(nclasses*block, nclasses*block)
   for target = 1, nclasses do
      for prediction = 1, nclasses do
         render[{ { (target-1)*block+1, target*block }, { (prediction-1)*block+1, prediction*block } }] = confusion[order[target]][order[prediction]]
      end
   end
   -- add grid
   for target = 1, nclasses do
      render[{ {target*block}, {} }] = 0.1
      render[{ {}, {target*block} }] = 0.1
   end
   -- create rendering
   require 'image'
   require 'qtwidget'
   require 'qttorch'
   local width, height = (#render)[2], (#render)[1]
   local win1 = qtwidget.newimage(width + legendwidth, height)
   image.display{image=render, win=win1}
   -- add legend
   -- NOTE(review): `fontsize` is not defined in this file, so the calls below
   -- read a global (normally nil, leaving the size to qt's default) — confirm.
   for i in ipairs(classes) do
      -- background cell
      win1:setcolor{r=0,g=0,b=0}
      win1:rectangle(width, (i-1)*block, legendwidth, block)
      win1:fill()
      -- % of labels belonging to this class
      win1:setfont(qt.QFont{serif=false, size=fontsize})
      local freqShade = freqs[order[i]]/freqs:max()*0.9+0.1
      win1:setcolor{r=freqShade*0.5+0.2, g=freqShade*0.5+0.2, b=freqShade*0.8+0.2}
      win1:moveto(width+10, i*block-block/3)
      win1:show(string.format('[%2.2f%% labels]', math.floor(freqs[order[i]]*10000+0.5)/100))
      -- class name, brighter for better per-class accuracy
      win1:setfont(qt.QFont{serif=false, size=fontsize})
      local nameShade = diag[order[i]]*0.8+0.2
      win1:setcolor{r=nameShade, g=nameShade, b=nameShade}
      win1:moveto(120+width+10, i*block-block/3)
      win1:show(classes[order[i]])
      for j in ipairs(classes) do
         -- cell score (row j, column i), printed as a percentage
         local score = confusion[order[j]][order[i]]
         local scoreShade = (1-score)*(score*0.8+0.2)
         win1:setcolor{r=scoreShade, g=scoreShade, b=scoreShade}
         win1:moveto((i-1)*block+block/5, (j-1)*block+block*2/3)
         win1:show(string.format('%02.0f', math.floor(score*100+0.5)))
      end
   end
   -- generate tensor
   local t = win1:image():toTensor()
   -- display
   if display then
      image.display{image=t, legend=string.format(legend[sortmode], accuracy, perclass)}
   end
   -- return rendering
   return t
end
|