File: test_solver.m

package info (click to toggle)
caffe 1.0.0%2Bgit20180821.99bd997-8
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 16,288 kB
  • sloc: cpp: 61,586; python: 5,783; makefile: 599; sh: 559
file content (45 lines) | stat: -rw-r--r-- 1,422 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
classdef test_solver < matlab.unittest.TestCase
  % TEST_SOLVER  Unit test for the MATLAB caffe.Solver wrapper.
  %
  %   The constructor writes a temporary solver configuration that points
  %   at a small net definition produced by
  %   caffe.test.test_net.simple_net_file, builds a caffe.Solver from it,
  %   and fills the 'label' blobs of the train net and of the first test
  %   net with random labels. test_solve then checks the iteration counter
  %   after stepping and after solving.
  
  properties
    num_output  % size argument handed to simple_net_file; also bounds the labels
    solver      % caffe.Solver instance under test
  end
  
  methods
    function self = test_solver()
      % Build the solver under test from freshly written temporary files.
      %
      % Both temporary files are removed when the constructor exits,
      % whether it returns normally or an error is thrown part-way.
      self.num_output = 13;
      model_file = caffe.test.test_net.simple_net_file(self.num_output);
      remove_model = onCleanup(@() delete(model_file)); %#ok<NASGU>
      solver_file = tempname();
      
      fid = fopen(solver_file, 'w');
      if fid == -1
        error('test_solver:fopenFailed', ...
          'Could not open temporary solver file "%s" for writing.', ...
          solver_file);
      end
      remove_solver = onCleanup(@() delete(solver_file)); %#ok<NASGU>
      close_solver = onCleanup(@() fclose(fid));
      % The model path is passed through %s instead of being spliced into
      % the format string, so '%' or '\' characters in the path are
      % written literally rather than interpreted by fprintf.
      fprintf(fid, 'net: "%s"\n', model_file);
      fprintf(fid, [ ...
        'test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9\n' ...
        'weight_decay: 0.0005 lr_policy: "inv" gamma: 0.0001 power: 0.75\n' ...
        'display: 100 max_iter: 100 snapshot_after_train: false\n' ]);
      % Clearing the guard runs fclose now, so the file is complete on
      % disk before the solver reads it.
      clear close_solver;
      
      self.solver = caffe.Solver(solver_file);
      % also make sure get_solver runs
      caffe.get_solver(solver_file);
      caffe.set_mode_cpu();
      % fill in valid labels: randi draws integers from 1..num_output-1
      train_labels = self.solver.net.blobs('label');
      train_labels.set_data(randi(self.num_output - 1, train_labels.shape));
      test_labels = self.solver.test_nets(1).blobs('label');
      test_labels.set_data(randi(self.num_output - 1, test_labels.shape));
    end
  end
  methods (Test)
    function test_solve(self)
      % Check the iteration counter: 0 initially, 30 after step(30), and
      % 100 (the max_iter written to the solver file) after solve().
      self.verifyEqual(self.solver.iter(), 0);
      self.solver.step(30);
      self.verifyEqual(self.solver.iter(), 30);
      self.solver.solve();
      self.verifyEqual(self.solver.iter(), 100);
    end
  end
end