#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
#include "blas.h"
#include "lapack.h"
#include "mex.h"
/*
 * Scalar constants handed to the Fortran BLAS/LAPACK routines via F77_CALL.
 * Fortran receives every argument by reference, so even literal values such
 * as a unit vector stride or the alpha = 1 / beta = 0 scaling factors must
 * live in addressable storage (see e.g. `&ione`, `&fone`, `&fzero` in the
 * dgemm/dgemv calls of mexFunction).  itwo, ithree, iseven, ftwo and fmone
 * are not referenced in the visible part of mexFunction; presumably they are
 * used by helper routines defined elsewhere.
 */
const int ione = 1, itwo = 2, ithree = 3, iseven = 7;
const double fone = 1.0, ftwo = 2.0, fzero = 0.0, fmone = -1.0;

/*
 * Verbosity flag.  mexFunction overwrites it from the "quiet" field of its
 * second input struct (prhs[1]) on every call.
 * NOTE(review): presumably non-zero suppresses progress printing in the
 * solver helpers -- confirm against their definitions.
 */
int quiet = 0;
/*
 * Forward declarations for helper routines whose definitions are not visible
 * here.  Only fmpcsolve is called from the visible part of mexFunction; the
 * per-routine notes below that go beyond the signatures are inferred from
 * names and that call site and are marked as assumptions -- confirm them
 * against the definitions.
 */

/* Print matrix A of m rows by n columns (assumed; inferred from the name). */
void printmat(double *A,
              int m,
              int n);

/*
 * Run the MPC solver for one time step.  mexFunction passes the dynamics
 * (A, B), their transposes (At, Bt), identity matrices (eyen, eyem), the
 * cost matrices (Q, R, Qf), the stacked box bounds (zmax, zmin), the current
 * state x and the warm-start trajectory z, then reads the first m entries of
 * z afterwards as the control to apply.
 * NOTE(review): presumably z is updated in place and niters/kappa bound the
 * iteration count and barrier weight -- confirm.
 */
void fmpcsolve(double *A, double *B,          /* dynamics                  */
               double *At, double *Bt,        /* transposes of A and B     */
               double *eyen, double *eyem,    /* n-by-n / m-by-m identity  */
               double *Q, double *R,          /* stage costs               */
               double *Qf,                    /* terminal cost             */
               double *zmax, double *zmin,    /* stacked bounds, length nz */
               double *x,                     /* current state, length n   */
               double *z,                     /* trajectory, length nz     */
               int T, int n, int m, int nz,   /* horizon and dimensions    */
               int niters,
               double kappa);

/*
 * Assumed (from the name and outputs): evaluates gradient terms gf/gp and
 * Hessian term hp for trajectory z under costs and bounds -- TODO confirm.
 */
void gfgphp(double *Q, double *R, double *Qf,
            double *zmax, double *zmin,
            double *z,
            int T, int n, int m, int nz,
            double *gf,
            double *gp,
            double *hp);

/*
 * Assumed (from the name and outputs): computes dual residual rd, primal
 * residual rp and the product Ctnu -- TODO confirm.
 */
void rdrp(double *A, double *B,
          double *Q, double *R, double *Qf,
          double *z, double *nu,
          double *gf, double *gp,
          double *b,
          int T, int n, int m, int nz,
          double kappa,
          double *rd,
          double *rp,
          double *Ctnu);

/*
 * Assumed (from the name and outputs): reduces the residual vectors rd/rp to
 * scalar measures resd, resp and res -- TODO confirm.
 */
void resdresp(double *rd, double *rp,
              int T, int n, int nz,
              double *resd,
              double *resp,
              double *res);

/*
 * Assumed (from the name and outputs): computes the search directions dnu
 * and dz from the residuals rd/rp and Hessian term hp -- TODO confirm.
 */
void dnudz(double *A, double *B,
           double *At, double *Bt,
           double *eyen, double *eyem,
           double *Q, double *R, double *Qf,
           double *hp,
           double *rd, double *rp,
           int T, int n, int m, int nz,
           double kappa,
           double *dnu,
           double *dz);
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
/* problem setup */
int i, j, m, n, nz, T, niters, nsteps, info;
double kappa;
double *dptr, *dptr1, *dptr2, *y1, *y2;
double *A, *B, *At, *Bt, *Q, *R, *Qf, *xmax, *xmin, *umax, *umin, *x, *u;
double *zmax, *zmin, *Xhist, *Uhist, *z, *eyen, *eyem, *x0, *w;
double *K, *Rtilde, *Rlower, *tempnm, *xterm, *uterm;
double *umaxp, *uminp, *xmaxp, *xminp, *zmaxp, *zminp;
double *X0, *U0;
double *telapsed;
clock_t t1, t2;
A = mxGetPr(mxGetField(prhs[0],0,"A"));
B = mxGetPr(mxGetField(prhs[0],0,"B"));
Q = mxGetPr(mxGetField(prhs[0],0,"Q"));
R = mxGetPr(mxGetField(prhs[0],0,"R"));
Qf = mxGetPr(mxGetField(prhs[1],0,"Qf"));
xmax = mxGetPr(mxGetField(prhs[0],0,"xmax"));
xmin = mxGetPr(mxGetField(prhs[0],0,"xmin"));
umax = mxGetPr(mxGetField(prhs[0],0,"umax"));
umin = mxGetPr(mxGetField(prhs[0],0,"umin"));
n = (int)mxGetScalar(mxGetField(prhs[0],0,"n"));
m = (int)mxGetScalar(mxGetField(prhs[0],0,"m"));
T = (int)mxGetScalar(mxGetField(prhs[1],0,"T"));
kappa = (double)mxGetScalar(mxGetField(prhs[1],0,"kappa"));
niters = (int)mxGetScalar(mxGetField(prhs[1],0,"niters"));
quiet = (int)mxGetScalar(mxGetField(prhs[1],0,"quiet"));
nsteps = (int)mxGetScalar(mxGetField(prhs[1],0,"nsteps"));
X0 = mxGetPr(prhs[2]);
U0 = mxGetPr(prhs[3]);
x0 = mxGetPr(prhs[4]);
w = mxGetPr(prhs[5]);
nz = T*(n+m);
/* outputs */
plhs[0] = mxCreateDoubleMatrix(n,nsteps,mxREAL);
plhs[1] = mxCreateDoubleMatrix(m,nsteps,mxREAL);
plhs[2] = mxCreateDoubleMatrix(1,1,mxREAL);
Xhist = mxGetPr(plhs[0]);
Uhist = mxGetPr(plhs[1]);
telapsed = mxGetPr(plhs[2]);
At = malloc(sizeof(double)*n*n);
Bt = malloc(sizeof(double)*n*m);
eyen = malloc(sizeof(double)*n*n);
eyem = malloc(sizeof(double)*m*m);
z = malloc(sizeof(double)*nz);
x = malloc(sizeof(double)*n);
u = malloc(sizeof(double)*m);
y1 = malloc(sizeof(double)*n);
y2 = malloc(sizeof(double)*n);
zmax = malloc(sizeof(double)*nz);
zmin = malloc(sizeof(double)*nz);
K = malloc(sizeof(double)*n*m);
Rtilde = malloc(sizeof(double)*m*m);
Rlower = malloc(sizeof(double)*m*m);
tempnm = malloc(sizeof(double)*n*m);
xterm = malloc(sizeof(double)*n);
uterm = malloc(sizeof(double)*m);
umaxp = malloc(sizeof(double)*m);
uminp = malloc(sizeof(double)*m);
xmaxp = malloc(sizeof(double)*n);
xminp = malloc(sizeof(double)*n);
zmaxp = malloc(sizeof(double)*nz);
zminp = malloc(sizeof(double)*nz);
/* eyen, eyem */
for (i = 0; i < n*n; i++) eyen[i] = 0;
dptr = eyen;
for (i = 0; i < n; i++)
{
*dptr = 1;
dptr = dptr+n+1;
}
for (i = 0; i < m*m; i++) eyem[i] = 0;
dptr = eyem;
for (i = 0; i < m; i++)
{
*(dptr+i*m+i) = 1;
}
for (i = 0; i < n; i++) x[i] = x0[i];
dptr = z;
for (i = 0; i < T; i++)
{
for (j = 0; j < m; j++)
{
*dptr = *(U0+i*m+j);
dptr++;
}
for (j = 0; j < n; j++)
{
*dptr = *(X0+i*n+j);
dptr++;
}
}
/* At, Bt */
F77_CALL(dgemm)("t","n",&n,&n,&n,&fone,A,&n,eyen,&n,&fzero,At,&n);
F77_CALL(dgemm)("n","t",&m,&n,&m,&fone,eyem,&m,B,&n,&fzero,Bt,&m);
/* zmax, zmin */
dptr1 = zmax;
dptr2 = zmin;
for (i = 0; i < T; i++)
{
for (j = 0; j < m; j++)
{
*dptr1 = *(umax+j);
*dptr2 = *(umin+j);
dptr1++; dptr2++;
}
for (j = 0; j < n; j++)
{
*dptr1 = *(xmax+j);
*dptr2 = *(xmin+j);
dptr1++; dptr2++;
}
}
/* zmaxp, zminp */
for (i = 0; i < nz; i++) zminp[i] = zmin[i] + 0.01*(zmax[i]-zmin[i]);
for (i = 0; i < nz; i++) zmaxp[i] = zmax[i] - 0.01*(zmax[i]-zmin[i]);
/* project z */
for (i = 0; i < nz; i++) z[i] = z[i] > zmaxp[i] ? zmaxp[i] : z[i];
for (i = 0; i < nz; i++) z[i] = z[i] < zminp[i] ? zminp[i] : z[i];
/* Rtilde */
for (i = 0; i < m*m; i++) Rtilde[i] = R[i];
F77_CALL(dgemm)("n","n",&n,&m,&n,&fone,Qf,&n,B,&n,&fzero,tempnm,&n);
F77_CALL(dgemm)("t","n",&m,&m,&n,&fone,B,&n,tempnm,&n,&fone,Rtilde,&m);
/* K */
F77_CALL(dgemm)("t","n",&m,&n,&n,&fone,B,&n,Qf,&n,&fzero,tempnm,&m);
F77_CALL(dgemm)("n","n",&m,&n,&n,&fone,tempnm,&m,A,&n,&fzero,K,&m);
for (i = 0; i < m*m; i++) Rlower[i] = Rtilde[i];
F77_CALL(dposv)("l",&m,&n,Rlower,&m,K,&m,&info);
for (i = 0; i < n*m; i++) K[i] = -K[i];
/* uminp, umaxp, xminp, xmaxp */
for (i = 0; i < m; i++) uminp[i] = umin[i] + 0.01*(umax[i]-umin[i]);
for (i = 0; i < m; i++) umaxp[i] = umax[i] - 0.01*(umax[i]-umin[i]);
for (i = 0; i < n; i++) xminp[i] = xmin[i] + 0.01*(xmax[i]-xmin[i]);
for (i = 0; i < n; i++) xmaxp[i] = xmax[i] - 0.01*(xmax[i]-xmin[i]);
t1 = clock();
for (i = 0; i < nsteps; i++)
{
fmpcsolve(A,B,At,Bt,eyen,eyem,Q,R,Qf,zmax,zmin,x,z,T,n,m,nz,niters,kappa);
/* save x and u to Xhist and Uhist */
dptr = Xhist+i*n; dptr1 = x;
for (j = 0; j < n; j++)
{
*dptr = *dptr1;
dptr++; dptr1++;
}
dptr = Uhist+i*m; dptr1 = u; dptr2 = z;
for (j = 0; j < m; j++)
{
*dptr = *dptr2;
*dptr1 = *dptr2;
dptr++; dptr1++; dptr2++;
}
/* compute x = A*x + B*u + w */
F77_CALL(dgemv)("n",&n,&n,&fone,A,&n,x,&ione,&fzero,y1,&ione);
F77_CALL(dgemv)("n",&n,&m,&fone,B,&n,u,&ione,&fzero,y2,&ione);
F77_CALL(daxpy)(&n,&fone,y2,&ione,y1,&ione);
F77_CALL(daxpy)(&n,&fone,w+n*i,&ione,y1,&ione);
dptr = x; dptr1 = y1;
for (j = 0; j < n; j++)
{
*dptr = *dptr1;
dptr++; dptr1++;
}
/* shift z for warm start and compute terminal controls*/
dptr = z;
for (j = 0; j < nz-n-m; j++)
{
*dptr = *(dptr+n+m);
dptr++;
}
F77_CALL(dgemv)("n",&m,&n,&fone,K,&m,z+nz-2*n-m,&ione,&fzero,uterm,&ione);
F77_CALL(dgemv)("n",&n,&n,&fone,A,&n,z+nz-2*n-m,&ione,&fzero,xterm,&ione);
F77_CALL(dgemv)("n",&n,&m,&fone,B,&n,u,&ione,&fzero,y2,&ione);
F77_CALL(daxpy)(&n,&fone,y2,&ione,xterm,&ione);
for (j = 0; j < m; j++) uterm[j] = uterm[j] > umaxp[j] ? umaxp[j] : uterm[j];
for (j = 0; j < m; j++) uterm[j] = uterm[j] < uminp[j] ? uminp[j] : uterm[j];
for (j = 0; j < n;