# Update 11/06/17: FISTA with backtracking is tested with lasso, lasso_weighted, and Elastic net.
# A simple implementation of [FISTA](https://github.com/tiepvupsu/FISTA)
**A MATLAB FISTA implementation based on the paper:**
A. Beck and M. Teboulle, "A fast iterative shrinkage-thresholding algorithm for linear inverse problems", *SIAM Journal on Imaging Sciences*,
vol. 2, no. 1, pp. 183–202, 2009. [View the paper](http://people.rennes.inria.fr/Cedric.Herzet/Cedric.Herzet/Sparse_Seminar/Entrees/2012/11/12_A_Fast_Iterative_Shrinkage-Thresholding_Algorithmfor_Linear_Inverse_Problems_(A._Beck,_M._Teboulle)_files/Breck_2009.pdf).
**[Tiep Vu](http://www.personal.psu.edu/thv102/), Penn State, Sep 2016**
If you find any issues, please let me know via [this link](https://github.com/tiepvupsu/FISTA/issues). I would really appreciate it. Thank you.
***Note:*** Results in this repo are compared with those obtained by the [*SPAMS*](http://spams-devel.gforge.inria.fr/) toolbox. You need to install SPAMS and put the generated 'build' folder under the 'spams' folder of this repo.
# Table of contents
<!-- MarkdownTOC -->
- [General Optimization problem](#general-optimization-problem)
- [Algorithms](#algorithms)
    - [If `L(f)` is easy to calculate,](#if-lf-is-easy-to-calculate)
    - [In case `L(f)` is hard to find,](#in-case-lf-is-hard-to-find)
- [Usage](#usage)
    - [`fista_general.m`](#fistageneralm)
    - [`fista_backtracking`](#fistabacktracking)
- [Examples](#examples)
    - [Lasso \(and weighted\) problems](#lasso-and-weighted-problems)
    - [Elastic net problems](#elastic-net-problems)
    - [Row sparsity problems](#row-sparsity-problems)
    - [Group sparsity problems \(implement later\)](#group-sparsity-problems-implement-later)
<!-- /MarkdownTOC -->
<a name="general-optimization-problem"></a>
## General Optimization problem
<img src = "latex/fista1.png" height = "30"/>
where:
- `g: R^n -> R`: a continuous convex function, possibly _nonsmooth_.
- `f: R^n -> R`: a smooth convex function of type `C^{1,1}`, i.e., continuously differentiable with a Lipschitz continuous gradient, where `L(f)` is the Lipschitz constant:
`||grad_f(x) - grad_f(y)|| <= L(f)||x - y||` for every `x, y \in R^n`.

***Note***: this implementation also works on nonnegativity-constrained problems.
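For the least-squares term used in the examples, `f(X) = 1/2*||Y - D*X||_F^2`, the gradient is `D'*(D*X - Y)` and `L(f)` is the largest eigenvalue of `D'*D`. The short MATLAB script below is a sketch (the matrices `D`, `Y`, `X1`, `X2` are made up for illustration and are not part of the repo) that computes this constant and checks the Lipschitz inequality on one pair of points.

```matlab
% Illustrative data (assumptions, not taken from the repo).
D  = [1 2; 3 4; 5 6];
Y  = [1 0; 0 1; 1 1];
X1 = [1 -1; 0 2];
X2 = [0 1; 1 0];

% Gradient of f(X) = 1/2*||Y - D*X||_F^2.
grad_f = @(X) D'*(D*X - Y);

% Lipschitz constant of grad_f: largest eigenvalue of D'*D
% (equal to the squared spectral norm of D).
Lf = max(eig(D'*D));

% Check ||grad_f(X1) - grad_f(X2)||_F <= L(f)*||X1 - X2||_F for this pair.
lhs = norm(grad_f(X1) - grad_f(X2), 'fro');
rhs = Lf*norm(X1 - X2, 'fro');
fprintf('lhs = %.4f <= rhs = %.4f\n', lhs, rhs);
```

The same formula, `L = max(eig(DtD))`, is what `lasso_fista` uses further below.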
<a name="algorithms"></a>
## Algorithms
<a name="if-lf-is-easy-to-calculate"></a>
### If `L(f)` is easy to calculate,
We use the following algorithm:
![FISTA with constant step](https://raw.githubusercontent.com/tiepvupsu/FISTA/master/figs/FISTA_L.png)
where `pL(y)` is a proximal function defined as:
![pL(y)](https://raw.githubusercontent.com/tiepvupsu/FISTA/master/figs/ply.png)
For a new problem, we only need to implement two functions, `grad_f(x)` and `pL(y)`, which are often much simpler than the original optimization problem (1).
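As a concrete instance, take `f(x) = 1/2*||y - D*x||^2` and `g(x) = lambda*||x||_1` (the lasso). Then `grad_f(x) = D'*(D*x - y)`, and `pL(y)` reduces to elementwise soft-thresholding of `y - grad_f(y)/L` at level `lambda/L`. The self-contained MATLAB script below is a minimal sketch of the constant-step loop under these assumptions (small made-up data, a fixed iteration count instead of a stopping tolerance); it is not the repo's `fista_general.m`.

```matlab
% Sketch: FISTA with constant step 1/L for the lasso problem
%   min_x 0.5*||y - D*x||^2 + lambda*||x||_1
% D, y, lambda and max_iter are illustrative assumptions.
D = [1 0.5 0; 0.2 1 0.3; 0.1 0.4 1; 0.3 0 0.2];
y = [1; 2; 0.5; 0.3];
lambda   = 0.1;
max_iter = 2000;

grad = @(x) D'*(D*x - y);                            % gradient of f
L    = max(eig(D'*D));                               % Lipschitz constant of grad
soft = @(u, tau) sign(u).*max(abs(u) - tau, 0);      % prox of tau*||.||_1
pL   = @(v) soft(v - grad(v)/L, lambda/L);           % pL(v) for g = lambda*||.||_1
F    = @(x) 0.5*norm(y - D*x)^2 + lambda*norm(x, 1); % full cost F = f + g

x_old = zeros(size(D, 2), 1);  % x_0
v     = x_old;                 % y_1 = x_0 (named v to avoid clashing with the data y)
t     = 1;                     % t_1
for k = 1:max_iter
    x_new = pL(v);                                   % x_k = pL(y_k)
    t_new = (1 + sqrt(1 + 4*t^2))/2;                 % t_{k+1}
    v     = x_new + ((t - 1)/t_new)*(x_new - x_old); % y_{k+1}
    x_old = x_new;
    t     = t_new;
end
x = x_old;
fprintf('F(x) = %.6f\n', F(x));
```

Only `grad` and `pL` depend on the problem; the loop itself is generic, which is why a new problem only requires those two functions.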
<a name="in-case-lf-is-hard-to-find"></a>
### In case `L(f)` is hard to find,
We can alternatively use the following algorithm:
![FISTA with backtracking](https://raw.githubusercontent.com/tiepvupsu/FISTA/master/figs/FISTA_noL.png)
where `QL(x, y)` is defined as:
![QL(x, y)](https://raw.githubusercontent.com/tiepvupsu/FISTA/master/figs/qlxy.png)
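With `QL(x, y)` as in Beck and Teboulle (the linearization of `f` at `y`, plus `L/2*||x - y||^2`, plus `g(x)`), the term `g(pL(y))` appears on both sides of the test `F(pL(y)) <= QL(pL(y), y)`. The backtracking condition therefore only involves `f`: it checks `f(x) <= f(y) + <x - y, grad_f(y)> + L/2*||x - y||^2` at `x = pL(y)` and multiplies `L` by `eta` until it holds. The MATLAB script below sketches this for a small lasso problem; `L0` and `eta` mirror `opts.L0` and `opts.eta`, while the data, values, and fixed iteration count are illustrative assumptions, not the repo's `fista_backtracking.m`.

```matlab
% Sketch: FISTA with backtracking for min_x 0.5*||y - D*x||^2 + lambda*||x||_1.
% L0 and eta mirror opts.L0 and opts.eta; all values are illustrative assumptions.
D = [1 0.5 0; 0.2 1 0.3; 0.1 0.4 1; 0.3 0 0.2];
y = [1; 2; 0.5; 0.3];
lambda   = 0.1;
L0       = 0.1;    % deliberately small initial guess for L
eta      = 2;      % growth factor for L, must be > 1
max_iter = 2000;

f    = @(x) 0.5*norm(y - D*x)^2;                 % smooth part f
grad = @(x) D'*(D*x - y);                         % gradient of f
soft = @(u, tau) sign(u).*max(abs(u) - tau, 0);   % prox of tau*||.||_1
pL   = @(v, Lk) soft(v - grad(v)/Lk, lambda/Lk);  % pL(v) for a given L
% QL(x, v) without the g(x) term, which cancels in F(x) <= QL(x, v)
Qf   = @(x, v, Lk) f(v) + (x - v)'*grad(v) + Lk/2*norm(x - v)^2;

x_old = zeros(size(D, 2), 1);
v  = x_old;
t  = 1;
Lk = L0;
for k = 1:max_iter
    x_new = pL(v, Lk);
    % Backtracking: start from the previous L and multiply by eta
    % until the quadratic model Qf upper-bounds f at x_new.
    while f(x_new) > Qf(x_new, v, Lk)
        Lk    = eta*Lk;
        x_new = pL(v, Lk);
    end
    t_new = (1 + sqrt(1 + 4*t^2))/2;
    v     = x_new + ((t - 1)/t_new)*(x_new - x_old);
    x_old = x_new;
    t     = t_new;
end
x = x_old;
fprintf('final L = %g, F(x) = %.6f\n', Lk, f(x) + lambda*norm(x, 1));
```

Because `Lk` is never reset, each search starts from the previous estimate, so `L` only grows and stays within a factor `eta` of the true Lipschitz constant.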
<a name="usage"></a>
## Usage
<a name="fistageneralm"></a>
### `fista_general.m`
`[X, iter, min_cost] = fista_general(grad, proj, Xinit, L, opts, calc_F)`
See [`fista_general.m`](https://github.com/tiepvupsu/FISTA/blob/master/fista_general.m).
where:
```matlab
INPUT:
    grad   : a function calculating the gradient of f(X) given X.
    proj   : a function calculating pL(x) -- projection.
    Xinit  : a matrix -- initial guess.
    L      : a scalar, the Lipschitz constant of the gradient of f(X).
    opts   : a struct
        opts.lambda  : a regularization parameter, can be either a scalar or
                       a weighted matrix.
        opts.max_iter: maximum number of iterations of the algorithm.
                       Default 300.
        opts.tol     : a tolerance; the algorithm stops if the difference
                       between two successive X is smaller than this value.
                       Default 1e-8.
        opts.verbose : whether to show F(X) after each iteration.
                       Default false.
    calc_F : optional, a function calculating the value of F at X
             via feval(calc_F, X).
OUTPUT:
    X        : solution
    iter     : number of iterations run
    min_cost : the achieved cost
```
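For the lasso examples, `proj` has to compute `pL`, which for an l1 penalty is elementwise soft-thresholding; with the nonnegativity constraint mentioned above, it becomes a one-sided threshold. The anonymous functions below are a hypothetical sketch (`soft_l1`, `soft_l1_pos`, and their two-argument form are assumptions, not the signature of the repo's `proj_l1`). Here `tau` stands for the threshold `lambda/L`; because every operation is elementwise, it can also be a weight matrix of the same size as `U`.

```matlab
% Hypothetical helpers, not the repo's proj_l1.
% Prox of tau*||X||_1: shrink every entry toward zero by tau.
soft_l1     = @(U, tau) sign(U).*max(abs(U) - tau, 0);
% Prox of tau*||X||_1 restricted to X >= 0: one-sided threshold.
soft_l1_pos = @(U, tau) max(U - tau, 0);

U   = [0.5 -0.2; -1.5 2.0];
tau = 0.3;                    % threshold, e.g. lambda/L
X1  = soft_l1(U, tau)         % -> [0.2 0; -1.2 1.7]
X2  = soft_l1_pos(U, tau)     % -> [0.2 0; 0 1.7]
```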
<a name="fistabacktracking"></a>
### `fista_backtracking`
`function X = fista_backtracking(calc_f, grad, Xinit, opts, calc_F)`
See [`fista_backtracking.m`](https://github.com/tiepvupsu/FISTA/blob/master/fista_backtracking.m)
where:
```matlab
INPUT:
    calc_f : a function calculating f(x) in F(x) = f(x) + g(x).
    grad   : a function calculating the gradient of f(X) given X.
    Xinit  : a matrix -- initial guess.
    opts   : a struct
        opts.lambda  : a regularization parameter, can be either a scalar or
                       a weighted matrix.
        opts.max_iter: maximum number of iterations of the algorithm.
                       Default 300.
        opts.tol     : a tolerance; the algorithm stops if the difference
                       between two successive X is smaller than this value.
                       Default 1e-8.
        opts.verbose : whether to show F(X) after each iteration.
                       Default false.
        opts.L0      : a positive scalar, the initial estimate L0 of the
                       Lipschitz constant in the algorithm.
        opts.eta     : (must be > 1) eta in the algorithm (page 194).
    calc_F : optional, a function calculating the value of F at X
             via feval(calc_F, X).
OUTPUT:
    X : solution
```
<a name="examples"></a>
## Examples
<a name="lasso-and-weighted-problems"></a>
### Lasso (and weighted) problems
***Optimization problem:***
This function solves the l1 Lasso problem:
<img src = "latex/fista_lasso1.png" height = "40"/>
if `lambda` is a scalar, or:
<img src = "latex/fista_lasso2.png" height = "40"/>
if `lambda` is a matrix. If `lambda` is a vector, it is converted to a matrix whose columns all equal `lambda` and whose number of columns equals the number of columns of `X`.
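A minimal sketch of that conversion and of the weighted penalty `||lambda .* X||_1`, assuming `lambda` holds one weight per row of `X` (the data and the use of `repmat` are illustrative, not necessarily how the repo performs the conversion):

```matlab
X      = [1 -2 0; 0.5 0 3];   % 2 x 3 coefficient matrix
lambda = [0.1; 0.2];          % one weight per row of X

% Repeat the column vector so that every column of Lambda equals lambda
% and Lambda has as many columns as X.
Lambda = repmat(lambda, 1, size(X, 2));

% Weighted l1 penalty: sum of |Lambda(i,j) * X(i,j)|.
weighted_l1 = sum(sum(abs(Lambda .* X)))   % -> 1
```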
***MATLAB function:***
```matlab
function X = lasso_fista(Y, D, Xinit, opts)
    % Solve min_X 0.5*||Y - D*X||_F^2 + ||lambda .* X||_1 with fista_general.
    opts = initOpts(opts);            % fill in default options
    lambda = opts.lambda;
    if numel(Xinit) == 0
        Xinit = zeros(size(D,2), size(Y,2));
    end
    %% cost f (smooth part)
    function cost = calc_f(X)
        cost = 1/2 *normF2(Y - D*X);
    end
    %% cost function F = f + g
    function cost = calc_F(X)
        if numel(lambda) == 1 % scalar
            cost = calc_f(X) + lambda*norm1(X);
        elseif numel(lambda) == numel(X) % weight matrix
            cost = calc_f(X) + norm1(lambda.*X);
        else
            error('lambda must be a scalar or have as many elements as X');
        end
    end
    %% gradient of f: D'*(D*X - Y), with D'*D and D'*Y precomputed
    DtD = D'*D;
    DtY = D'*Y;
    function res = grad(X)
        res = DtD*X - DtY;
    end
    %% Checking gradient (optional)
    if isfield(opts, 'check_grad') && opts.check_grad
        check_grad(@calc_f, @grad, Xinit);
    end
    %% Lipschitz constant of grad: largest eigenvalue of D'*D
    L = max(eig(DtD));
    %% Use fista
    [X, ~, ~] = fista_general(@grad, @proj_l1, Xinit, L, opts, @calc_F);
end
```
(See [])
***Example:***
**1. L1 minimization** (`lambda` is a scalar)
See [`demo_lasso.m`](https://github.com/tiepvupsu/FISTA/blob/master/demo_lasso.m)
```matlab
function test_lasso()
    clc
    d      = 300;   % data dimension
    N      = 70;    % number of samples
    k      = 100;   % dictionary size
    lambda = 0.01;
    Y = normc(rand(d, N));
    D = normc(rand(d, k));
    %% cost function
    function c = calc_F(X)
        c = 0.5*normF2(Y - D*X) + lambda*norm1(X);
    end
    %% fista solution
    opts.pos    = true;   % change to false for unconstrained problems
    opts.lambda = lambda;
    X_fista = lasso_fista(Y, D, [], opts);
    %% spams solution
    param.lambda     = lambda;
    param.lambda2    = 0;
    param.numThreads = 1;
    param.mode       = 2;
    param.pos        = opts.pos;
    % ... (the remainder, which runs SPAMS and compares the results, is in demo_lasso.m)
end
```