forked from ndwork/dworkLib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfista.m
89 lines (76 loc) · 2.91 KB
/
fista.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
function [xStar,objectiveValues] = fista( x, g, gGrad, proxth, varargin )
% [xStar,objectiveValues] = fista( x, g, gGrad, proxth [, ...
%   'h', h, 'N', N, 'printEvery', printEvery, 't', t, 'verbose', verbose ] )
%
% This function implements the FISTA optimization algorithm.
% FISTA finds the x that minimizes functions of the form g(x) + h(x) where
% g is differentiable and h has a simple proximal operator.
%
% Inputs:
% x - the starting point
% g - a function handle representing the g function; accepts a vector x
%   as input and returns a scalar.
% gGrad - a function handle representing the gradient function of g;
%   input: the point of evaluation, output: the gradient vector
% proxth - the proximal operator of the h function (with parameter t);
%   two inputs: the vector and the scalar value of the parameter t
%
% Optional Inputs:
% h - a handle to the h function.  This is needed to calculate the
%   objective values.
% N - the number of iterations that FISTA will perform (default is 100).
%   Passing an empty value selects the default.
% printEvery - when verbose is set, a line is printed on every iteration k
%   with mod(k,printEvery)==0 (default is 1)
% t - step size; must be greater than 0 (default is 1)
% verbose - if set then prints fista iteration
%
% Outputs:
% xStar - the optimal point (the output of the final proximal step)
% objectiveValues - (optional) an N x 1 vector where entry k+1 holds
%   g(y)+h(y) after iteration k.  Requesting this output requires that h
%   be supplied; otherwise an error is thrown.
%
% Written by Nicholas Dwork - Copyright 2017
%
% This software is offered under the GNU General Public License 3.0.  It
% is offered without any warranty expressed or implied, including the
% implied warranties of merchantability or fitness for a particular
% purpose.

  defaultN = 100;
  p = inputParser;
  p.addParameter( 'h', [] );
  % Test for emptiness first: || requires logical-scalar operands, so an
  % empty N must short-circuit before ispositive is evaluated on it.
  p.addParameter( 'N', defaultN, @(x) numel(x)==0 || ispositive(x) );
  p.addParameter( 'printEvery', 1, @ispositive );
  p.addParameter( 't', 1, @isnumeric );
  p.addParameter( 'verbose', 0, @(x) isnumeric(x) || islogical(x) );
  p.parse( varargin{:} );
  h = p.Results.h;
  N = p.Results.N;  % total number of iterations
  printEvery = p.Results.printEvery;  % display result printEvery iterations
  t = p.Results.t;  % t0 must be greater than 0
  verbose = p.Results.verbose;

  if numel( N ) == 0, N = defaultN; end
  if t <= 0, error('fista: t0 must be greater than 0'); end

  % Objective values are only computed when the caller asks for them,
  % and that requires a handle to h.
  calculateObjectiveValues = 0;
  if nargout > 1
    if numel(h) == 0
      error( 'fista.m - Cannot calculate objective values without h function handle' );
    else
      objectiveValues = zeros(N,1);
      calculateObjectiveValues = 1;
    end
  end

  % The iteration-number format is loop invariant, so build it once.
  % The digit count is clamped to at least 1; for N == 1, ceil(log10(N)) is
  % 0, and a zero-precision integer format can print nothing for k = 0.
  if verbose > 0
    nDigits = max( 1, ceil( log10( N ) ) );
    formatString = ['%', num2str(nDigits), '.', num2str(nDigits), 'i' ];
  end

  z = x;
  y = 0;  % placeholder; the momentum weight is 0 at k = 0, so it is unused
  for k=0:N-1
    % Gradient step on g from the extrapolated point z, then prox of h.
    x = z - t * gGrad( z );
    lastY = y;
    y = proxth( x, t );

    if calculateObjectiveValues > 0, objectiveValues(k+1) = g(y) + h(y); end

    if verbose>0 && mod( k, printEvery ) == 0
      verboseString = [ 'FISTA Iteration: ', num2str(k,formatString) ];
      if calculateObjectiveValues > 0
        verboseString = [ verboseString, ', objective: ', num2str( objectiveValues(k+1) ) ]; %#ok<AGROW>
      end
      disp( verboseString );
    end

    % Momentum / extrapolation step with weight k/(k+3).
    z = y + (k/(k+3)) * (y-lastY);
  end

  xStar = y;
end