package kmeans;
import java.io.*;
import jxl.Cell;
import jxl.CellType;
import jxl.NumberCell;
import jxl.Sheet;
import jxl.Workbook;
public class K_Means2 {
int Flags = 0;
int rows, columns, k;
int[] logo; // 标志位
double[][] train;// 将对象实体化
double[][] centers;// k个中心点数
double[][] tempCenters;// k个新中心点
public K_Means2(int pointNum, int pointLength, int centerNum) {
rows = pointNum;
columns = pointLength;
k = centerNum;
logo=new int[rows];
train=new double[rows][columns];
centers=new double[k][columns];
tempCenters=new double[k][columns];
// 初始化标志位
for (int i = 0; i < rows; i++) {
logo[i] = 0;
}
//初始化所有点
for(int i=0;i<rows;i++){
for(int j=0;j<columns;j++){
train[i][j]=0.0;
}
}
}
/***************************************************************************
* 读文件data0.txt中数据,保存至train[][]二维数组中
**************************************************************************/
public void readtxt(String filename) throws Exception {
try {
InputStream is = new FileInputStream(new File(filename));
jxl.Workbook rwb = Workbook.getWorkbook(is);
Sheet rs = rwb.getSheet(0);
for (int i = 0; i < rs.getRows(); i++) {
for (int j = 0; j <rs.getColumns(); j++) {
Cell cell=rs.getCell(j, i);
if(cell.getType()==CellType.NUMBER){
NumberCell numCell=(NumberCell)cell;
double cellValue=numCell.getValue();
train[i][j]=cellValue;
}
}
}
rwb.close();
is.close();
} catch (Exception e) {
e.printStackTrace();
}
}
/***************************************************************************
* 初始化新旧中心点
**************************************************************************/
public void initial()
{
//初始化旧中心点,不妨以前k个为中心点
for (int i = 0; i < k; i++){
for(int j=0;j<columns;j++){
centers[i][j]=train[i][j];
}
}
//初始化新中心点
for(int i=0;i<k;i++){
for(int j=0;j<columns;j++){
tempCenters[i][j]=0.0;
}
}
}
/***************************************************************************
* 比较差值,确定所属类
**************************************************************************/
public void plus()// 求方差
{
double dis[] = new double[k];
for (int i = 0; i < rows; i++) {
for(int m = 0;m < k; m++){
dis[m]=0.0;
for(int j=0;j<columns;j++){
dis[m]+=Math.pow((train[i][j]-centers[m][j]), 2);
}
}
token(dis,i);// 标记logo[i] ,以logo=1,2,3表示所属类
}
}
/***************************************************************************
* 求数组中最小值的索引
**************************************************************************/
public int minIndex(double[] dis)
{
int index=0;
double tempMin=dis[0];
for(int i=0;i<dis.length;i++)
{
if(tempMin>dis[i]){
tempMin=dis[i];
index=i;
}
}
return index;
}
/***************************************************************************
* 更改标志位(重新分配,看点属于那个簇),用1--k标记,分别表示k个类别
**************************************************************************/
public void token(double[] dis,int i) {
int index=minIndex(dis);
logo[i]=index;//将第i个元素归为第index类
}
/***************************************************************************
* 每聚类一次,求一次新的中心点
**************************************************************************/
public void center() {
int num[]=new int[k];
for(int i=0;i<k;i++)
num[i]=0;
for(int m=0;m<k;m++){
for(int i=0;i<rows;i++){
if(logo[i]==m){
arraryAdd(tempCenters[m],train[i],columns);
num[m]++;
}
}
}
for(int i=0;i<k;i++){
for(int j=0;j<columns;j++){
tempCenters[i][j]=tempCenters[i][j]/(double)num[i];
}
}
}
/***************************************************************************
* 求两个数组(向量)的和
**************************************************************************/
public void arraryAdd(double[]a, double[] b, int length){
for(int i=0;i<length;i++){
a[i]=a[i]+b[i];
}
}
/***************************************************************************
* 判断聚类运算是否结束
**************************************************************************/
public void compare()// 新旧中心点
{
double[] errors;
errors=new double[k];
for(int i=0;i<k;i++){
errors[i]=0.0;
}
for (int i = 0; i < k; i++) {
for(int j=0;j<columns;j++){
errors[i]+=Math.pow((tempCenters[i][j]-centers[i][j] ),2);
}
}
// 符合要求
for(int i=0;i<k;i++){
if(errors[i]>0.01){
Flags=0;
break;
}
else
Flags=1;
}
}
/***************************************************************************
* 用新的簇中心替换旧的簇中心
**************************************************************************/
public void DuplicateCenter() {
if(Flags==0){
for(int i=0;i<k;i++){
for(int j=0;j<columns;j++){
centers[i][j]=tempCenters[i][j];
}
}
}
}
/***************************************************************************
* 写文件到cluster_result.txt中
**************************************************************************/
public void WrietTxt(String writeFilename) throws Exception {
BufferedWriter bw = new BufferedWriter(new FileWriter(
writeFilename));
int i = 0;
while (i < train.length) {
String s = "";
// for (int j = 0; j < columns; j++) {
// s += " " + String.valueOf(train[i][j]) + " ";// 获得train[][]每一行数据,转化为字符串类型
// }
s+=" "+String.valueOf(i)+" ";
bw.write(s + String.valueOf(logo[i]));// 加入标志位
bw.newLine();
i++;
}
bw.close();
}
public static void main(String args[]) throws Exception {
K_Means2 app = new K_Means2(239,64,8);// 创建一个类实例
app.readtxt("D:\\239xls.xls");
//app.set();
app.initial();
app.plus();
app.center();
// 不停的循环
while (app.Flags == 0) {
app.DuplicateCenter();
app.plus();
app.center();
app.compare();
}
app.WrietTxt("D:\\239cluster.txt");
//
// for(int i=0;i<app.train.length;i++)//显示结果-测试用
// {
// // System.out.println(i+"----"+ app.logo[i]);
// // };
// // new MainFrame(app);
// }
}
}
- 1
- 2
- 3
- 4
- 5
- 6
前往页