File: sparsecoding.lua

package info (click to toggle)
lua-torch-optim 0~20171127-ga5ceed7-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 548 kB
  • sloc: makefile: 8
file content (127 lines) | stat: -rw-r--r-- 3,907 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
require 'kex'

-- L1 FISTA solution.
-- Solves min_x ||D x - b||^2 + lambda * ||x||_1 for a linear dictionary D,
-- delegating the iterations to optim.FistaLS.
--
-- D      : dictionary; each column is a dictionary element, so D:size(1) is
--          the input dimension and D:size(2) is the code size.
-- params : (optional) table of FISTA settings (L, Lstep, maxiter, maxline,
--          errthres). Missing entries are filled with defaults in place, and
--          the same table is passed to optim.FistaLS on every run (which may
--          also keep temporary allocations in it); check optim.FistaLS for
--          details.
--
-- Returns fista : a table with the following entries
-- fista.run(x, lambda, codeinit) : run the L1 sparse coding algorithm on input
--                                  x with penalty lambda. x is either a vector
--                                  of size D:size(1) or a matrix with one input
--                                  per row. codeinit (optional) is copied as
--                                  the starting code; otherwise zeros are used.
--                                  Returns the result of optim.FistaLS.
-- The following entries are allocated once and reused by each call to
-- fista.run:
-- fista.reconstruction : reconstruction D x of the input; note that an
--                        evaluation of fista.f with mode 'dx' overwrites it
--                        with the residual gradient 2 (D x - input).
-- fista.gradf          : gradient of the L2 part of the problem wrt x
-- fista.gradg          : subgradient buffer for the L1 part
-- fista.code           : the solution of the L1 problem
-- The following entries just point to data passed to fista.run:
-- fista.input          : the tensor x used in the last fista.run(x, lambda)
-- fista.lambda         : the lambda value used in the last fista.run(x, lambda)
function optim.FistaL1(D, params)

   -- FISTA parameters; defaults are written into the caller's table on purpose
   -- because the very same table is forwarded to optim.FistaLS.
   params = params or {}
   -- buffers and closures exposed to the caller
   local fista = {}

   -- related to FISTA
   params.L = params.L or 0.1
   params.Lstep = params.Lstep or 1.5
   params.maxiter = params.maxiter or 50
   params.maxline = params.maxline or 20
   params.errthres = params.errthres or 1e-4

   -- temporary storage reused across calls
   fista.reconstruction = torch.Tensor()
   fista.gradf = torch.Tensor()
   fista.gradg = torch.Tensor()
   fista.code = torch.Tensor()

   -- these are assigned in run(x, lambda)
   fista.input = nil
   fista.lambda = nil

   -- CREATE FUNCTION CLOSURES

   -- Smooth part of the objective: f(x) = ||D x - input||^2, where x is the
   -- code (a vector, or a matrix with one code per row).
   -- With mode containing 'dx', also returns the gradient 2 D^T (D x - input)
   -- in fista.gradf. The gradient is computed by overwriting
   -- fista.reconstruction in place with 2 (D x - input), so the third return
   -- value of a 'dx' call is that residual, not the reconstruction.
   fista.f = function (x, mode)
      local reconstruction = fista.reconstruction
      local input = fista.input

      -- function evaluation: reconstruction = D x (row-wise for a batch)
      if x:dim() == 1 then
         reconstruction:resize(D:size(1))
         reconstruction:addmv(0, 1, D, x)
      elseif x:dim() == 2 then
         reconstruction:resize(x:size(1), D:size(1))
         reconstruction:addmm(0, 1, x, D:t())
      else
         error(' I do not know how to handle ' .. x:dim() .. ' dimensional input')
      end
      local fval = input:dist(reconstruction)^2

      -- derivative calculation
      if mode and mode:match('dx') then
         local gradf = fista.gradf
         -- in place: reconstruction <- 2 (D x - input)
         reconstruction:add(-1, input):mul(2)
         gradf:resizeAs(x)
         if input:dim() == 1 then
            gradf:addmv(0, 1, D:t(), reconstruction)
         else
            gradf:addmm(0, 1, reconstruction, D)
         end
         -- return function value and derivative
         return fval, gradf, reconstruction
      end

      -- return function value
      return fval, reconstruction
   end

   -- Non-smooth part of the objective: g(x) = lambda * ||x||_1.
   -- With mode containing 'dx' (optional), also returns the subgradient
   -- lambda * sign(x), stored in fista.gradg.
   fista.g = function (x, mode)
      local fval = fista.lambda * x:norm(1)

      if mode and mode:match('dx') then
         local gradg = fista.gradg
         -- copy x first: sign() works in place on the buffer
         gradg:resizeAs(x):copy(x):sign():mul(fista.lambda)
         return fval, gradg
      end
      return fval
   end

   -- argmin_x Q(x,y), just shrinkage (soft thresholding) for L1; modifies x
   fista.pl = function (x, L)
      x:shrinkage(fista.lambda / L)
   end

   fista.run = function (x, lam, codeinit)
      local code = fista.code
      fista.input = x
      fista.lambda = lam

      -- resize code (maybe a different number of dimensions) and set the
      -- initial point: a copy of codeinit if given, zeros otherwise
      if codeinit then
         code:resizeAs(codeinit)
         code:copy(codeinit)
      else
         if x:dim() == 1 then
            code:resize(D:size(2))
         elseif x:dim() == 2 then
            code:resize(x:size(1), D:size(2))
         else
            error(' I do not know how to handle ' .. x:dim() .. ' dimensional input')
         end
         code:fill(0)
      end
      -- return the result of the optim.FistaLS call
      return optim.FistaLS(fista.f, fista.g, fista.pl, code, params)
   end

   return fista
end