# coding: utf-8
# # MNIST handwritten-digit recognition
# ### 1 Train / validation / test split
# In[1]:
from keras.datasets import mnist
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns

# Load the raw MNIST arrays: 60k training and 10k test 28x28 grayscale images.
(train_images0, train_labels0), (test_images, test_labels) = mnist.load_data()

# Add an explicit channel axis -> (samples, height, width, channels) tensor as
# expected by Conv2D, and rescale pixel values from [0, 255] to [0, 1].
# -1 lets NumPy infer the sample count instead of hard-coding 60000/10000.
train_images0 = train_images0.reshape((-1, 28, 28, 1))
train_images0 = train_images0.astype('float32') / 255
test_images = test_images.reshape((-1, 28, 28, 1))
test_images = test_images.astype('float32') / 255

# Hold out 15% of the training data for validation.  stratify keeps the digit
# distribution identical in both splits; random_state makes the split
# reproducible across runs.
train_images, val_images, train_labels, val_labels = train_test_split(
    train_images0, train_labels0, test_size=0.15,
    stratify=train_labels0, random_state=42)

# Visualize the class balance of the test labels.
#sns.countplot(x=train_labels0)
sns.countplot(x=test_labels)
# In[2]:
from collections import Counter
# Frequency of each digit class in the full training set.  A bare expression
# like this is displayed by the notebook; it has no effect as a plain script.
Counter(train_labels0)
# In[3]:
import tensorflow as tf
#from tensorflow import keras
import keras
from keras import layers
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D, BatchNormalization
from keras import models

# CNN with three stages.  No MaxPooling: downsampling is done by strided 5x5
# convolutions (a learnable alternative to pooling).  BatchNormalization
# stabilizes training; Dropout reduces overfitting (acts like regularization).
model = models.Sequential()

# Stage 1: 32 filters.  Conv2D(filters, kernel_size): filters = depth of the
# output feature map (number of convolution filters), kernel_size = spatial
# extent of the patches extracted from the input (3x3).
model.add(Conv2D(32, kernel_size=3, activation='relu', input_shape=(28, 28, 1)))
model.add(BatchNormalization())
model.add(Conv2D(32, kernel_size=3, activation='relu'))
model.add(BatchNormalization())
# Strided 'same' convolution halves the spatial resolution.
model.add(Conv2D(32, kernel_size=5, strides=2, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))

# Stage 2: same pattern with 64 filters.
model.add(Conv2D(64, kernel_size=3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, kernel_size=3, activation='relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, kernel_size=5, strides=2, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))

# Classifier head: flatten the (4, 4, 64) feature maps into a (1024,) vector,
# then a densely connected classifier ending in a 10-way softmax that returns
# an array of 10 class probabilities.
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))
model.summary()
# In[4]:
# Frequency of each digit class in the test set (notebook display only).
Counter(test_labels)
# In[ ]:
# CREATE MORE IMAGES VIA DATA AUGMENTATION
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ReduceLROnPlateau

# Small random rotations, zooms and shifts; digits must not be flipped.
datagen = ImageDataGenerator(
rotation_range=10,
zoom_range = 0.10,
width_shift_range=0.1,
height_shift_range=0.1)

# Callback: reduce the learning rate when the validation loss stops improving.
learning_rate_reduction = ReduceLROnPlateau(monitor='val_loss',
patience=4,       # epochs without improvement before reducing the LR
verbose=2,        # 0 = silent, 1 = progress bar, 2 = one line per epoch.
factor=0.2,       # new_lr = lr * factor
min_lr=0.00001)   # lower bound on the learning rate

# Compute any statistics the generator needs from the training data.
datagen.fit(train_images)
# ### 2 CNN construction — no pooling layers; BatchNormalization + Dropout instead
# In[ ]:
### 3 Data augmentation and callback usage
# ### 4 Model training and parameter tuning
# In[ ]:
# Compile: rmsprop optimizer (reduces the loss via gradient descent);
# categorical cross-entropy is the feedback signal used to learn the weights.
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# One-hot encode the integer labels to match categorical_crossentropy.
# NOTE(review): re-running this cell would double-encode the labels and fail.
train_labels = to_categorical(train_labels)
val_labels = to_categorical(val_labels)
test_labels = to_categorical(test_labels)

# Train for 50 epochs on augmented mini-batches of 64 samples.  model.fit
# accepts a generator directly; fit_generator is deprecated and removed in
# modern Keras.  One epoch = train_images.shape[0] // 64 gradient updates
# (796 for the 51000-sample training split).
history = model.fit(datagen.flow(train_images, train_labels, batch_size=64),
                    steps_per_epoch=train_images.shape[0] // 64,
                    epochs=50,
                    callbacks=[learning_rate_reduction],
                    verbose=2,
                    validation_data=(val_images, val_labels))
# ### 5 Experimental results
# In[ ]:
from sklearn.metrics import confusion_matrix
import numpy as np
from sklearn.utils.multiclass import unique_labels

#---------------------------- Training-history curves ----------------------------#
# TF2 Keras records 'accuracy'/'val_accuracy'; TF1 used 'acc'/'val_acc'.
# Pick whichever key is present so the plot works on either version.
acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'
fig, ax = plt.subplots(2, 1)
ax[0].plot(history.history['loss'], color='b', label="Training loss")
ax[0].plot(history.history['val_loss'], color='r', label="Validation loss")
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
legend = ax[0].legend(loc='best', shadow=True)
ax[1].plot(history.history[acc_key], color='b', label="Training accuracy")
ax[1].plot(history.history['val_' + acc_key], color='r', label="Validation accuracy")
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
legend = ax[1].legend(loc='best', shadow=True)

#---------------------------- Test-set evaluation ----------------------------#
test_loss, test_accuracy = model.evaluate(test_images, test_labels)
print("Test Loss = {0:.6f}, Test Accuracy = {1:0.6f}".format(test_loss, test_accuracy))
#---------------------------- Test-set confusion matrix ----------------------------#
def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.Blues):
    """Draw a labeled, annotated confusion-matrix heat map.

    Parameters
    ----------
    y_true, y_pred : 1-D arrays of integer class labels.
    classes : sequence used for the axis tick labels (e.g. range(10)).
    normalize : if True, show each row (true class) as fractions.
    title : optional plot title; a sensible default is chosen if omitted.
    cmap : matplotlib colormap for the heat map.

    Returns
    -------
    The matplotlib Axes the matrix was drawn on.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'
    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        # Row-normalize; clamp empty rows to 1 to avoid division by zero.
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        cm = cm.astype('float') / np.maximum(row_sums, 1)
    #print(cm)
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # Show all ticks and label them with the class names.
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")
    # Annotate every cell with its count (or fraction), using a text color
    # that contrasts with the cell's background shade.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax
# Predicted class = index of the largest softmax probability in each row.
probs = model.predict(test_images)
predict_labels = probs.argmax(axis=1)
# Recover integer labels from the one-hot encoded test targets.
test_labels1 = test_labels.argmax(axis=1)
plot_confusion_matrix(test_labels1, predict_labels, classes=range(10))
# #### Tally correct and incorrect classification results
# In[ ]:
# Indices where the prediction agrees / disagrees with the true label.
correct = np.flatnonzero(predict_labels == test_labels1)
incorrect = np.flatnonzero(predict_labels != test_labels1)
print("Correct predicted classes:", len(correct))
print("Incorrect predicted classes:", len(incorrect))
# In[ ]:
plt.figure(figsize=(15,8))
for i in range(30):
plt.subplot(4, 8, i+1)
index = incorrect[i]
plt.imshow(test_images[index].res
评论0