#include "CLConv.h"
//void CLConv::juanji(int x[N1][M1],int y[N2][M2],int z[N1+N2-1][M1+M2-1])
//{
// int i,j;
// int n,m;
// for(i=0;i<N1+N2-1;i++)
// for(j=0;j<M1+M2-1;j++)
// {
// int temp=0;
// for(m=0;m<N1;m++)
// for(n=0;n<M1;n++)
// if((i-m)>=0&&(i-m)<N2&&(j-n)>=0&&(j-n)<M2)
// temp+=x[m][n]*y[i-m][j-n];
// z[i][j]=temp;
// }
//}
bool CL2DMatrix::Conv2(CL2DMatrix &_srcA,CL2DMatrix &_srcB)
{
if(ReSize(_srcA.m_rows+_srcB.m_rows-1,_srcA.m_cols+_srcB.m_cols-1) == false )
return false;
for(int i=0;i<m_rows;i++)
{
for(int j=0;j<m_cols;j++)
{
int temp=0;
for(int m=0;m<_srcA.m_rows;m++)
{
for(int n=0;n<_srcA.m_cols;n++)
{
if((i-m)>=0&&(i-m)<_srcB.m_rows && (j-n)>=0&&(j-n)<_srcB.m_cols)
{
//temp+=x[m][n]*y[i-m][j-n];
temp+= _srcA.GetData(m,n) * _srcB.GetData(i-m,j-n);
}
}
}
//z[i][j]=temp;
m_Data[i*m_cols+j] = temp;
}
}
return true;
}
void CLConv::Print(int* _ArrData,int n,int m)
{
int i;
for(i=0;i<n;i++)
{
if(i%m==0)
printf("\n");
printf("%4d",_ArrData[i]);
}
printf("\n");
}
//////////////////////////////////////////
CL2DMatrix::CL2DMatrix()
{
m_Data = NULL;
}
CL2DMatrix::~CL2DMatrix()
{
if(m_Data !=NULL)
delete[] m_Data;
}
bool CL2DMatrix::Assing(int *_Data,const int _nRows,const int _nClos)
{
if(ReSize(_nRows,_nClos) == false )
return false;
int nSize = m_rows*m_cols;
for(int i=0;i<nSize;i++)
m_Data[i] = _Data[i];
return true;
}
bool CL2DMatrix::Step1(CL2DMatrix &_srcM,const int _nRows,const int _nClos)
{
if(ReSize(_nRows,_nClos) == false )
return false;
for(int i=0;i<_srcM.m_rows;i++)
{
for(int j=0;j<_srcM.m_cols;j++)
{
m_Data[i*m_cols+j] = _srcM.GetData(i,j);
}
}
return true;
}
bool CL2DMatrix::Step2(CL2DMatrix &_srcM)
{
if(ReSize(_srcM.m_rows*_srcM.m_cols,1) == false )
return false;
int k=0;
for(int i=0;i<_srcM.m_rows;i++)
{
for(int j=0;j<_srcM.m_cols;j++)
{
m_Data[k] = _srcM.GetData(i,j);
k++;
}
}
return true;
}
bool CL2DMatrix::Step3(CL2DMatrix &_srcM,const int _rows)
{
if(ReSize(_srcM.m_cols,_srcM.m_cols) == false )
return false;
for(int i=0;i<_srcM.m_cols;i++)
{
for(int j=0;j<_srcM.m_cols;j++)
{
int k= (_srcM.m_cols - j + i )%_srcM.m_cols;
m_Data[i*m_cols + k] = _srcM.GetData(_rows,j);
}
}
return true;
}
bool CL2DMatrix::Step4(CL2DMatrix *_pMatrix,const int _nNum)
{
if(_nNum < 1 || ReSize(_pMatrix[0].m_cols*_nNum,_pMatrix[0].m_cols*_nNum) == false )
return false;
int ncols = _pMatrix[0].m_cols;
int nNum = 0;
for(int i=0;i<_nNum;i++)
{
for(int ii=0;ii<ncols;ii++)
{
for(int j=0;j<_nNum;j++)
{
int food = (_nNum - j + i )%_nNum;
for(int jj=0;jj<ncols;jj++)
{
m_Data[nNum] = _pMatrix[food].GetData(ii,jj);
nNum++;
}
}
}
}
return true;
}
bool CL2DMatrix::Step5(CL2DMatrix &_srcA,CL2DMatrix &_srcB)
{
if(_srcA.m_cols != _srcB.m_rows || ReSize(_srcA.m_rows,_srcB.m_cols) == false )
{
printf("error: _srcA.m_cols(%d) != _srcB.m_rows(%d) \n",_srcA.m_cols , _srcB.m_rows);
return false;
}
for(int i=0;i<_srcA.m_rows;i++)
{
for(int j=0;j<_srcB.m_cols;j++)
{
for(int k=0;k<_srcB.m_rows;k++)
{
//C[i][j] += A[i][k] * B[k][j];
m_Data[i*_srcB.m_cols + j] += _srcA.GetData(i,k) * _srcB.GetData(k,j);
}
}
}
return true;
}
bool CL2DMatrix::Step6(CL2DMatrix &_srcA,const int _nRows,const int _nClos)
{
if(_srcA.m_cols != 1 || _srcA.m_rows != _nRows*_nClos || ReSize(_nRows,_nClos) == false )
{
printf("error: _srcA.m_rows(%d) != _nRows*_nClos (%d) \n",_srcA.m_rows ,_nRows*_nClos );
return false;
}
for(int i=0;i<_srcA.m_rows;i++)
{
m_Data[i] = _srcA[i];
}
return true;
}
bool CL2DMatrix::ReSize(const int _nRows,const int _nClos)
{
m_rows = _nRows;
m_cols = _nClos;
int nSize = m_rows*m_cols;
if(nSize < 1)
return false;
if(m_Data != NULL)
delete[] m_Data;
m_Data = new int[nSize];
for(int i=0;i<nSize;i++)
m_Data[i] = 0;
return true;
}
void CL2DMatrix::Print()
{
CLConv::Print(m_Data,m_rows*m_cols,m_cols);
}
//////////////////////////
void CLConv::Conv2(CL2DMatrix &_srcF,CL2DMatrix &_kernelG,CL2DMatrix &_dstR)
{
int dst_row = _srcF.m_rows + _kernelG.m_rows - 1;
int dst_cols = _srcF.m_cols + _kernelG.m_cols - 1;
///step 1: 扩展 _srcF 与 _kernelG 为 dst_row * dst_cols 大小的矩阵
CL2DMatrix srcFp;
srcFp.Step1(_srcF,dst_row,dst_cols);
CL2DMatrix kernelGp;
kernelGp.Step1(_kernelG,dst_row,dst_cols);
//step 2: 构造 (dst_row * dst_cols ) * 1矩阵
CL2DMatrix srcFb;
srcFb.Step2(srcFp);
//step3:总共生成kernelGp.m_rows个 kernelGp.m_cols * kernelGp.m_cols 列矩阵
CL2DMatrix *pgp = new CL2DMatrix[kernelGp.m_rows];
for(int i=0;i<kernelGp.m_rows;i++)
{
pgp[i].Step3(kernelGp,i);
}
///step4: 生成N^2 * N^2 矩阵
CL2DMatrix kernelGb;
kernelGb.Step4(pgp,kernelGp.m_rows);
//step5 :求两矩阵相乘 kernelGb *srcFb
CL2DMatrix resultTmp;
resultTmp.Step5(kernelGb,srcFb);
//step6 :行列回转
_dstR.Step6(resultTmp,dst_row,dst_cols);
}
/*
假设有两个矩阵a,b,a的大小是ma行na列,b的大小是mb行nb列。
c=conv2(a,b)计算这两个矩阵的卷积,c的大小是ma+mb-1行,na+nb-1列。
计算过程如下:
1.对矩阵a进行边界填补0,填充规则是:在a的第一行之前和最后一行之后分别填充mb-1行0,并在a的第一列之前和最后一列之后分别填充nb-1列0;
2.对矩阵b进行翻转,上下换位左右换位。 rot90(b,2)或者fliplr(flipud(b))或者flipud(fliplr(b))
3.从左上角开始,按照先列后行的顺序在矩阵a上滑动矩阵b,对应的元素相乘然后求和所得的数值为相应的c中的值。
*/
bool CL2DMatrix::Step2_1(CL2DMatrix &_srcA,CL2DMatrix &_srcB)
{
if(ReSize(_srcA.m_rows+2*(_srcA.m_rows-1),_srcA.m_cols+2*(_srcB.m_cols-1)) == false )
return false;
int ii = _srcA.m_rows-1;
int jj = _srcB.m_cols-1;
for(int i=0;i<_srcA.m_rows;i++)
{
for(int j=0;j<_srcB.m_cols;j++)
{
m_Data[ii*m_cols + jj] = _srcA.GetData(i,j);
jj++;
}
jj = _srcB.m_cols-1;
ii++;
}
return true;
}
bool CL2DMatrix::Step2_2(CL2DMatrix &_srcM)
{
if(ReSize(_srcM.m_rows,_srcM.m_cols) == false )
return false;
int nSize = _srcM.m_rows * _srcM.m_cols -1;
for(int i=0;i<=nSize;i++)
{
m_Data[nSize - i] = _srcM[i];
}
return true;
}
bool CL2DMatrix::Step2_3(CL2DMatrix &_srcA,CL2DMatrix &_srcB,const int _nRows,const int _nClos)
{
if(ReSize(_nRows,_nClos) == false )
return false;
for(int i=0;i<_nRows;i++)
{
for(int j=0;j<_nClos;j++)
{
int temp=0;
for(int m=0;m<_srcB.m_rows;m++)
{
for(int n=0;n<_srcB.m_cols;n++)
{
//if((i-m)>=0&&(i-m)<_srcB.m_rows && (j-n)>=0&&(j-n)<_srcB.m_cols)
//{
//temp+=x[m][n]*y[i-m][j-n];
temp+= _srcA.GetData(i+m,j+n) * _srcB.GetData(m,n);
//}
}
}
m_Data[i*m_cols+j] = temp;
}
}
return true;
}
void CLConv::Conv2ByMove(CL2DMatrix &_srcF,CL2DMatrix &_kernelG,CL2DMatrix &_arrResult,const EShapeMode _eMode)
{
int dst_row = _srcF.m_rows + _kernelG.m_rows - 1;
int dst_cols = _srcF.m_cols + _kernelG.m_cols - 1;
//step 1;
CL2DMatrix srcFA;
srcFA.Step2_1(_srcF,_kernelG);
//step 2;
CL2DMatrix srcFB;
srcFB.Step2_2(_kernelG);
//step 2;
_arrResult.Step2_3(srcFA,srcFB,dst_row,dst_cols);
}