import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import os
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow import keras
import glob
from keras.layers import Dropout
from keras.layers.normalization import BatchNormalization
# Use the SimHei font so the Chinese titles/axis labels below can render,
# and draw minus signs with the ASCII hyphen instead of the Unicode minus
# glyph (which that font may not provide).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
# Root directory of the dataset: one sub-folder of .jpg images per class.
path = 'photo/'
# 读取图像
def read_img(path):
cate = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
imgs = []
labels = []
fpath = []
for idx, folder in enumerate(cate):
# 遍历整个目录判断每个文件是不是符合
for im in glob.glob(folder + '/*.jpg'):
# print('reading the images:%s' % (im))
img = cv2.imread(im) # 调用opencv库读取像素点
img = cv2.resize(img, (32, 32)) # 图像像素大小一致
imgs.append(img) # 图像数据
labels.append(idx) # 图像类标
fpath.append(path + im) # 图像路径名
# print(path+im, idx)
return np.asarray(fpath, np.string_), np.asarray(imgs, np.float32), np.asarray(labels, np.int32)
# 读取图像
fpaths, data, label = read_img(path)
# print(data.shape) # (1000, 256, 256, 3)
# 计算有多少类图片
num_classes = len(set(label))
# 定义随机种子,将数据进行打乱处理
np.random.seed(116)
np.random.shuffle(data)
np.random.seed(116)
np.random.shuffle(label)
x_train, x_test, y_train, y_test = train_test_split(data, label, test_size=0.2, random_state=1,stratify=label)
x_train,x_valid,y_train,y_valid= train_test_split(x_train, y_train, test_size=0.1, random_state=1,stratify=y_train)
#定义0-9所对应的实际类别
target_names=['people','beaches','buildings','trucks','dinosaurs',
'elephants','flowers','horses','mountains','food']
#数据归一化处理
from sklearn.preprocessing import StandardScaler
scaler=StandardScaler()
x_train_scaler=scaler.fit_transform(
x_train.astype(np.float32).reshape(-1,1)).reshape(-1,32,32,3)
x_valid_scaler=scaler.fit_transform(
x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,32,32,3)
x_test_scaler=scaler.fit_transform(
x_test.astype(np.float32).reshape(-1,1)).reshape(-1,32,32,3)
#模型构建
model=keras.models.Sequential()
model.add(keras.layers.Conv2D(filters=32,kernel_size=3,padding='same',activation='relu',input_shape=(32,32,3)))
model.add(keras.layers.Conv2D(filters=32,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))
# model.add(keras.layers.Conv2D(filters=256,kernel_size=3,padding='same',activation='relu'))
# model.add(keras.layers.Conv2D(filters=256,kernel_size=3,padding='same',activation='relu'))
# model.add(keras.layers.MaxPool2D(pool_size=2))
# model.add(keras.layers.Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
# model.add(keras.layers.Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
# model.add(keras.layers.MaxPool2D(pool_size=2))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(32,activation='relu'))
# model.add(Dropout(0.25)) #随机退出率
model.add(keras.layers.Dense(10,activation="softmax"))
model.compile(loss="sparse_categorical_crossentropy",
optimizer="Adam",metrics=["accuracy"],
dropout=0.5)
model.summary()
#模型训练
history=model.fit(x_train_scaler,y_train,epochs=8,
validation_data=(x_valid_scaler,y_valid))
#画出模型在训练过程中的损失函数和准确率的变化
def plot_learning(history):
pd.DataFrame(history.history).plot(figsize=(8,5))
plt.grid(True)
plt.gca().set_ylim(0,1)
plt.title('训练过程损失函数和准确率的变化图')
plt.show()
plot_learning(history)
#模型预测
predict=model.predict(x_test_scaler)
# predict=model.evaluate(x_val_scaler,y_val)
predict=np.argmax(predict,axis=1)
#输出测试集的精确率、召回率和F1测度等指标
print (classification_report(y_test, predict,target_names=target_names))
#输出混淆矩阵
def plot_confusion_matrix(cm,title='混淆矩阵热力图',cmap=plt.cm.binary):
# plt.figure(figsize=(11, 11))
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
xlocations = np.array(range(len(target_names)))
plt.xticks(xlocations, target_names, rotation=0)
plt.yticks(xlocations, target_names)
plt.ylabel('正确标签')
plt.xlabel('预测标签')
plt.show()
conmatrix=confusion_matrix(y_true=y_test, y_pred=predict)
plot_confusion_matrix(conmatrix)
print(conmatrix)
#绘制混淆矩阵
import seaborn as sn
df_cm=pd.DataFrame(conmatrix,
index=target_names,
columns=target_names)
plt.figure(figsize=(11,11))
# plt.yticks(rotation=90)
plt.title('混淆矩阵对应正误个数图')
sn.heatmap(df_cm,annot=True,cmap="BuPu")
plt.show()