File: train.lua

----------------------------------------------------------------------
-- This script demonstrates how to define a training procedure,
-- irrespective of the model/loss functions chosen.
--
-- It shows how to:
--   + construct mini-batches on the fly
--   + define a closure to estimate (a noisy) loss
--     function, as well as its derivatives wrt the parameters of the
--     model to be trained
--   + optimize the function, according to several optimization
--     methods (selected via opt.search): SGD, Adam.
--
-- Clement Farabet
----------------------------------------------------------------------

require 'torch'   -- torch
require 'xlua'    -- xlua provides useful tools, like progress bars
require 'optim'   -- an optimization package, for online and batch methods
require 'deepboof'

if opt.type == 'cuda' then
   require 'cunn'
end
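
----------------------------------------------------------------------
-- Note: this script expects a global `opt` table to have been defined by the
-- calling script before this file is loaded. The exact fields and values
-- depend on that caller; a minimal sketch covering the options used below
-- (illustrative values only) might look like:
--
--[[
opt = {
   type      = 'float',    -- 'float' or 'cuda'
   model     = 'model',    -- name of the module returning {model=..., loss=...}
   save      = 'results',  -- directory for logs and confusion matrices
   batchSize = 128,        -- mini-batch size
   size      = 'small',    -- dataset size tag, only written to the parameter file
   search    = 'sgd',      -- optimizer name, used to index into the optim package
   learningRate         = 1e-2,
   sgdMomentum          = 0.9,
   sgdWeightDecay       = 1e-4,
   sgdLearningRateDecay = 1e-7,
   plot = false
}
-- with search = 'adam', opt.adamBeta1 and opt.adamBeta2 are read instead of
-- the sgd* fields.
--]]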

----------------------------------------------------------------------
-- Model + Loss:

local t = require(opt.model)
local model = t.model
local loss = t.loss

local d = require 'data'
local classes = d.classes
local trainData = d.trainData

----------------------------------------------------------------------
print(sys.COLORS.red ..  '==> defining some tools')

-- This matrix records the current confusion across classes
local confusion = optim.ConfusionMatrix(classes)

-- Log results to files
local trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))

----------------------------------------------------------------------
print(sys.COLORS.red ..  '==> flattening model parameters')

-- Retrieve parameters and gradients:
-- this extracts and flattens all the trainable parameters of the model
-- into a 1-dim vector
local w,dE_dw = model:getParameters()
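-- (the returned tensors share storage with the model's own weight and gradient
-- tensors, so in-place updates made to w by the optimizer are immediately
-- visible to the model)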

----------------------------------------------------------------------
print(sys.COLORS.red ..  '==> configuring optimizer')

local optimState = {}

----------------------------------------------------------------------
print(sys.COLORS.red ..  '==> allocating minibatch memory')
local x = torch.Tensor(opt.batchSize,trainData.data:size(2),
         trainData.data:size(3), trainData.data:size(4))
local yt = torch.Tensor(opt.batchSize)

if opt.type == 'cuda' then
   x = x:cuda()
   yt = yt:cuda()
end

----------------------------------------------------------------------
print(sys.COLORS.red ..  '==> defining training procedure')

local epoch

local function reset()
   epoch = 0

   if opt.search == 'sgd' then
      optimState = {
         learningRate = opt.learningRate,
         momentum = opt.sgdMomentum,
         weightDecay = opt.sgdWeightDecay,
         learningRateDecay = opt.sgdLearningRateDecay
      }
   elseif opt.search == 'adam' then
      optimState = {
         learningRate = opt.learningRate,
         beta1 = opt.adamBeta1,
         beta2 = opt.adamBeta2,
         epsilon = 1e-8
      }
   end
   -- reset weights
   model:reset()

   if trainLogger ~= nil then
      trainLogger.file:close()
   end
   trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
end

local function train(trainData)

   local file_param = io.open("results/training_parameters.txt", "w")

   if opt.search == 'sgd' then
      file_param:write('learningRate '..optimState.learningRate..'\n')
      file_param:write('momentum '..optimState.momentum..'\n')
      file_param:write('weightDecay '..optimState.weightDecay..'\n')
      file_param:write('learningRateDecay '..optimState.learningRateDecay..'\n')
   elseif opt.search == 'adam' then
      file_param:write('learningRate '..optimState.learningRate..'\n')
      file_param:write('beta1 '..optimState.beta1..'\n')
      file_param:write('beta2 '..optimState.beta2..'\n')
   end
   file_param:write('size '.. opt.size ..'\n')
   file_param:write('model '.. opt.model ..'\n')
   file_param:write('search '.. opt.search ..'\n')
   file_param:close()

   -- epoch tracker
   epoch = epoch or 1

   -- local vars
   local time = sys.clock()

   -- shuffle at each epoch
   local shuffle = torch.randperm(trainData:size())

   -- Let it know that it's in training mode
   model:training()

   -- do one epoch
   print(sys.COLORS.green .. '==> doing epoch on training data:') 
   print("==> online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']')
   for t = 1,trainData:size(),opt.batchSize do
      -- disp progress
      xlua.progress(t, trainData:size())
      collectgarbage()

      -- batch fits?
      if (t + opt.batchSize - 1) > trainData:size() then
         break
      end

      -- create mini batch
      local idx = 1
      for i = t,t+opt.batchSize-1 do
         x[idx] = trainData.data[shuffle[i]]
         yt[idx] = trainData.labels[shuffle[i]]
         idx = idx + 1
      end

      -- create closure to evaluate f(X) and df/dX
      local eval_E = function(w)
         -- reset gradients
         dE_dw:zero()

         -- evaluate function for complete mini batch
         local y = model:forward(x)
         local E = loss:forward(y,yt)

         -- estimate df/dW
         local dE_dy = loss:backward(y,yt)   
         model:backward(x,dE_dy)

         -- update confusion
         for i = 1,opt.batchSize do
            confusion:add(y[i],yt[i])
         end

         -- print("E ",E," dE_dw ",dE_dw:sum()," w ",w:sum())
         -- return f and df/dX
         return E,dE_dw
      end

      -- optimize on current mini-batch
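      -- optim.sgd / optim.adam take (opfunc, x, config) and update the flat
      -- parameter vector w in place, returning the new parameters and a table
      -- containing the loss value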
      optim[opt.search](eval_E, w, optimState)
   end

   -- time taken
   time = sys.clock() - time
   time = time / trainData:size()
   print("\n==> time to learn 1 sample = " .. (time*1000) .. 'ms')

   local file_confusion = io.open(paths.concat(opt.save , "confusion_human_train.txt"), "w")
   file_confusion:write(tostring(confusion))
   file_confusion:close()

   file_confusion = io.open(paths.concat(opt.save , "confusion_train.txt"), "w")
   file_confusion:write(deepboof.confusionToString(confusion))
   file_confusion:close()

   -- print confusion matrix
   print(confusion)

   -- update logger/plot
   trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100}
   if opt.plot then
      trainLogger:style{['% mean class accuracy (train set)'] = '-'}
      trainLogger:plot()
   end


   -- next epoch
   local average_accuracy = confusion.totalValid
   confusion:zero()
   epoch = epoch + 1

   return average_accuracy,model:clone()
end

-- Export:
return {train=train,reset=reset}
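
----------------------------------------------------------------------
-- Example (hypothetical) driver loop, kept in a comment so the module body is
-- unchanged. Names such as `numEpochs` are placeholders; `trainData` is the
-- table exported by the 'data' module:
--
--[[
local trainer = require 'train'
local d = require 'data'

trainer.reset()
local bestAccuracy, bestModel = 0, nil
for e = 1, numEpochs do
   local accuracy, snapshot = trainer.train(d.trainData)
   if accuracy > bestAccuracy then
      bestAccuracy, bestModel = accuracy, snapshot
   end
end
--]]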