/*
* cgf.c: conjugate-gradient routine for finding the optimal coefficients s (fast).
*/
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "mex.h"
/* Sign of x as an int: 1, -1 or 0 (0 also for NaN). The argument is
   parenthesised so that expressions such as sgn(a & b) expand correctly;
   note that x is still evaluated more than once. */
#define sgn(x) ((x) > 0 ? 1 : ((x) < 0 ? -1 : 0))
extern void cgf(double *Sout, double *nits, double *nf, double *ng, double *Sin, double *X, int npats, double tol, int maxiter, int numflag);
/* Input & Output Arguments */
#define A_IN prhs[0] /* basis matrix */
#define X_IN prhs[1] /* data vectors */
#define S_IN prhs[2] /* initial guess for S */
#define SPARSITY_IN prhs[3] /* sparsity penalty selector (presumably one of SP_LOG / SP_HUBER_L1 / SP_EPS_L1) - TODO confirm in mexFunction */
#define LAMBDA_IN prhs[4] /* precision */
#define BETA_IN prhs[5] /* prior steepness */
#define SIGMA_IN prhs[6] /* scaling parameter for prior */
#define TOL_IN prhs[7] /* tolerance */
#define MAXITER_IN prhs[8] /* maximum iterations for frprmn */
#define OUTFLAG_IN prhs[9] /* output flag */
#define NUMFLAG_IN prhs[10] /* pattern number output flag */
#define EPSILON_IN prhs[11] /* huber function epsilon */
#define S_OUT plhs[0] /* basis coeffs for each data vector */
#define NITS_OUT plhs[1] /* total iterations done by cg */
#define NF_OUT plhs[2] /* total P(s|x,A) calcs */
#define NG_OUT plhs[3] /* total d/ds P(s|x,A) calcs */
/* Define indexing macros for matrices (all stored column-major) */
/* L = dimension of input vectors
* M = number of basis functions
*/
#define A_(i,j) A[(i) + (j)*L] /* A is L x M */
#define X_(i,n) X[(i) + (n)*L] /* X is L x npats */
#define Sout_(i,n) Sout[(i) + (n)*M] /* S is M x npats */
#define Sin_(i,n) Sin[(i) + (n)*M] /* S is M x npats */
#define AtA_(i,j) AtA[(i) + (j)*M] /* AtA is M x M */
/* Globals shared with the frprmn callbacks (init_f1dim, f1dim, dfunc) */
static double *A; /* basis matrix, L x M column-major */
static int L; /* data dimension */
static int M; /* number of basis vectors */
static double lambda; /* 1/noise_var */
static double beta; /* prior steepness */
static double sigma; /* prior scaling */
static double k1, k2, k3; /* line-search coefficients: set by init_f1dim, read by f1dim */
static double *x; /* current data vector being fitted */
static double *s0; /* init coefficient vector (1:M) */
static double *d; /* search dir. coefficient vector (1:M) */
static int outflag; /* print search progress */
static double *AtA; /* Only compute A'*A once (1:M,1:M) */
static double *Atx; /* A'*x for the current data vector (1:M) */
static int fcount, gcount; /* objective / gradient evaluation counts, reset per data vector */
/* Selectors for the sparsity penalty used by sparse() and sparse_prime() */
#define SP_LOG 0
#define SP_HUBER_L1 1
#define SP_EPS_L1 2
static int g_sparsity_func; /* one of SP_LOG, SP_HUBER_L1, SP_EPS_L1 */
static double g_epsilon; /* epsilon for the Huber (SP_HUBER_L1) and smoothed-L1 (SP_EPS_L1) penalties */
/* Allocate n doubles; abort with a message (as sparse() does) rather than
   handing a NULL work array to the optimiser. */
static double *cgf_alloc_doubles(size_t n)
{
    double *buf = malloc(n * sizeof *buf);

    if (buf == NULL && n > 0) {
        fprintf(stderr, "Error: cgf could not allocate work arrays!\n");
        exit(-3);
    }
    return buf;
}

/*
 * Allocate the work arrays x (L), s0, d, Atx (M) and AtA (M x M), and
 * precompute the Gram matrix AtA = A'*A of the basis A (L x M, column-major).
 * Must be paired with free_global_arrays().
 */
static void init_global_arrays()
{
    int i, j, k;
    const double *Ai, *Aj;
    double sum;

    x = cgf_alloc_doubles((size_t) L);
    s0 = cgf_alloc_doubles((size_t) M);
    d = cgf_alloc_doubles((size_t) M);
    /* widen before multiplying so a large M cannot overflow int */
    AtA = cgf_alloc_doubles((size_t) M * (size_t) M);
    Atx = cgf_alloc_doubles((size_t) M);

    /* A'*A is symmetric: compute the upper triangle and mirror it. Each
       product Ai[k]*Aj[k] is commutative and summed in the same k order, so
       the mirrored entries are bit-identical to computing them separately. */
    for (i = 0; i < M; i++) {
        Ai = A + (size_t) i * L;
        for (j = i; j < M; j++) {
            Aj = A + (size_t) j * L;
            sum = 0.0;
            for (k = 0; k < L; k++)
                sum += Ai[k] * Aj[k];
            AtA_(i, j) = sum;
            AtA_(j, i) = sum;
        }
    }
}
/*
 * Release the work arrays obtained by init_global_arrays(), in reverse
 * order of allocation. The global pointers are left dangling, so
 * init_global_arrays() must run again before anything touches them.
 */
static void free_global_arrays()
{
    free(Atx);
    free(AtA);
    free(d);
    free(s0);
    free(x);
}
float init_f1dim(s1, d1)
float *s1, *d1;
{
register int i, j;
register double As, Ag, sum;
register float fval;
extern double sparse();
for (i = 0; i < M; i++) {
s0[i] = s1[i + 1];
d[i] = d1[i + 1];
}
k1 = k2 = k3 = 0;
for (i = 0; i < L; i++) {
As = Ag = 0;
for (j = 0; j < M; j++) {
As += A_(i, j) * s0[j];
Ag += A_(i, j) * d[j];
}
k1 += As * (As - 2 * x[i]);
k2 += Ag * (As - x[i]);
k3 += Ag * Ag;
}
k1 *= 0.5 * lambda;
k2 *= lambda;
k3 *= 0.5 * lambda;
fval = k1;
sum = 0;
for (i = 0; i < M; i++)
sum += sparse(s0[i] / sigma);
fval += beta * sum;
fcount++;
return (fval);
}
/*
 * Objective at s0 + alpha*d along the line set up by init_f1dim(): the
 * cached quadratic data term plus beta times the sparsity prior.
 * Increments fcount.
 */
float f1dim(alpha)
float alpha;
{
    double penalty = 0;
    float value;
    int idx = 0;
    extern double sparse();

    value = k1 + (k2 + k3 * alpha) * alpha; /* quadratic in Horner form */
    while (idx < M) {
        penalty += sparse((s0[idx] + alpha * d[idx]) / sigma);
        idx++;
    }
    value += beta * penalty;
    fcount++;
    return value;
}
/*
* Gradient evaluation used by conj grad descent
*/
/*
 * Gradient of the objective at p (NR 1-based, p[1..M]), written to
 * grad[1..M]:
 *     grad_i = lambda * ((A'A s)_i - (A'x)_i) + (beta/sigma) * sparse_prime(s_i/sigma)
 * Uses the precomputed AtA and Atx. Increments gcount.
 */
void dfunc(p, grad)
float *p, *grad;
{
    const double prior_scale = beta / sigma;
    const float *s = p + 1; /* 0-based view of p[1..M] */
    double acc;
    int row, col;
    extern double sparse_prime();

    for (row = 0; row < M; row++) {
        const double *gram = &AtA[row * M];

        acc = 0;
        for (col = 0; col < M; col++)
            acc += s[col] * gram[col];
        grad[row + 1] = lambda * (acc - Atx[row]) + prior_scale * sparse_prime((double) s[row] / sigma);
    }
    gcount++;
}
/*
 * Sparsity penalty S(x), applied to each scaled coefficient s_i / sigma.
 * The form is chosen by g_sparsity_func (eps = g_epsilon):
 *   SP_LOG      : log(1 + x^2)
 *   SP_HUBER_L1 : Huber-smoothed |x|
 *                   x^2 / (2*eps)    if |x| <  eps
 *                   |x| - eps/2      otherwise
 *   SP_EPS_L1   : sqrt(x^2 + eps)
 * Any other selector is a fatal configuration error (prints and exits).
 */
double sparse(x)
double x;
{
    double ax;

    if (g_sparsity_func == SP_LOG) {
        return (log(1.0 + x * x));
    } else if (g_sparsity_func == SP_HUBER_L1) {
        /* fabs, not abs: abs() takes an int and truncated x, which broke
           the penalty and its consistency with sparse_prime()'s sgn(x). */
        ax = fabs(x);
        if (ax < g_epsilon)
            return x * x / (2.0 * g_epsilon);
        else
            return (2.0 * ax - g_epsilon) / 2.0;
    } else if (g_sparsity_func == SP_EPS_L1) {
        return (sqrt(x * x + g_epsilon));
    }
    fprintf(stderr, "Error: sparsity function is not properly specified!\n");
    exit(-1);
}
/*
 * Derivative dS/dx of the sparsity penalty computed by sparse(), for the
 * form selected by g_sparsity_func (eps = g_epsilon):
 *   SP_LOG      : 2x / (1 + x^2)
 *   SP_HUBER_L1 : x / eps inside |x| < eps, sign(x) outside
 *   SP_EPS_L1   : x / sqrt(x^2 + eps)
 * Any other selector is a fatal configuration error (prints and exits).
 */
double sparse_prime(x)
double x;
{
    switch (g_sparsity_func) {
    case SP_LOG:
        return 2 * x / (1.0 + x * x);
    case SP_HUBER_L1:
        if (fabs(x) < g_epsilon)
            return x / g_epsilon;
        return sgn(x);
    case SP_EPS_L1:
        return x / sqrt(x * x + g_epsilon);
    default:
        break;
    }
    fprintf(stderr, "Error: sparsity function is not properly specified!\n");
    exit(-2);
}
/*
 * Per-iteration hook with an empty body; it only provides the external
 * symbol iter_do. Presumably invoked by the optimiser each iteration -
 * TODO confirm against frprmn.
 */
void iter_do()
{
    return;
}
#include <nrutil.h>
extern int ITMAX;
/*
 * Fit coefficients for every data vector with conjugate gradients.
 *
 * For pattern n (column n of X) this minimises over s (length M), up to an
 * s-independent constant,
 *     lambda/2 * |x - A s|^2 + beta * sum_i sparse(s_i / sigma)
 * starting from column n of Sin, and stores the result in column n of Sout.
 * The search is run by frprmn() with init_f1dim/f1dim (line evaluations)
 * and dfunc (gradient). cgf reads, but does not set, the globals A, L, M,
 * lambda, beta, sigma, outflag, g_sparsity_func and g_epsilon.
 *
 *   Sout    out: M x npats coefficients (column-major)
 *   nits    out: total frprmn iterations over all patterns
 *   nf, ng  out: total objective / gradient evaluations
 *   Sin     M x npats initial guess (column-major)
 *   X       L x npats data (column-major)
 *   tol     tolerance handed to frprmn
 *   maxiter per-pattern cap stored in ITMAX (presumably frprmn's iteration
 *           limit); a value <= 0 falls back to the former fixed cap of 10
 *   numflag non-zero: echo the running pattern number on stderr
 */
void cgf(double *Sout, double *nits, double *nf, double *ng, double *Sin, double *X, int npats, double tol, int maxiter, int numflag)
{
    double sum;
    float fret;
    int niter, l, m, n;
    float *p;

    *nits = *nf = *ng = 0.0;
    /* Honour the caller's cap; previously ITMAX was hard-coded to 10 and
       maxiter was silently ignored. */
    ITMAX = (maxiter > 0) ? maxiter : 10;
    init_global_arrays();
    p = vector(1, M); /* NR 1-based vector p[1..M] */
    for (n = 0; n < npats; n++) {
        if (numflag) {
            fprintf(stderr, "\r%d", n + 1);
            fflush(stderr);
        }
        for (l = 0; l < L; l++) {
            x[l] = X_(l, n);
        }
        for (m = 0; m < M; m++) {
            /* precompute Atx = A'*x for this pattern (used by dfunc) */
            sum = 0.0;
            for (l = 0; l < L; l++) {
                sum += A_(l, m) * x[l];
            }
            Atx[m] = sum;
            /* copy initial guess */
            p[m + 1] = Sin_(m, n);
        }
        fcount = gcount = 0;
        frprmn(p, M, (float) tol, &niter, &fret, init_f1dim, f1dim, dfunc);
        *nits += (double) niter;
        *nf += (double) fcount;
        *ng += (double) gcount;
        if (outflag) {
            fprintf(stdout, "\nfret=%f niters=%d fcount=%d gcount=%d\n", fret, niter, fcount, gcount);
            fflush(stdout);
        }
        /* copy back solution */
        for (m = 0; m < M; m++) {
            Sout_(m, n) = p[m + 1];
        }
    }
    free_global_arrays();
    /* release with the same index range used to allocate (1..M), not the
       loop counter n */
    free_vector(p, 1, M);
}
void mexFunction(int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[])
{
double *Sout, nits = 0, nf = 0, ng = 0, *Sin;
double *X, tol;
int maxiter, npats, numflag, i;
/* Check for proper number of arguments */
if (nrhs < 7) {
mexErrMsgTxt("cgf requires 6 input arguments.");
} else if (nlhs < 1) {
mexErrMsgTxt("cgf requires 1 output argument.");
}
/* Assign pointers to the various p
- 1
- 2
- 3
前往页