"""
不使用数据增强的快速特征提取
"""
import os
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.applications import VGG16
conv_base = VGG16(weights='imagenet',include_top=False,input_shape=(150, 150, 3))
base_dir='D:/pic/small_dogvscat'
train_dir=os.path.join(base_dir,'train')
validation_dir=os.path.join(base_dir,'validation')
test_dir=os.path.join(base_dir,'test')
data=ImageDataGenerator(rescale=1./255)
batch_size=20
def extra_features(dic,sample_count):
features=np.zeros(shape=(sample_count,4,4,512))
labels=np.zeros(shape=(sample_count))
generator=data.flow_from_directory(
dic,
target_size=(150,150),
batch_size=batch_size,
class_mode='binary'
)
i=0
for input_batch,labels_batch in generator:
features_batch=conv_base.predict(input_batch)
features[i*batch_size:(i+1)*batch_size]=features_batch
labels[i*batch_size:(i+1)*batch_size]=labels_batch
i=i+1
if i*batch_size>=sample_count:
break
return features,labels
train_features,train_labels=extra_features(train_dir,2000)
validation_features,validation_labels=extra_features(validation_dir,1000)
test_features,test_labels=extra_features(test_dir,1000)
train_features=np.reshape(train_features,(2000,4*4*512))
validation_features=np.reshape(validation_features,(1000,4*4*512))
test_features=np.reshape(test_features,(1000,4*4*512))
from keras import models
from keras import layers
from keras import optimizers
model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
loss='binary_crossentropy',
metrics=['acc'])
history = model.fit(train_features, train_labels,
epochs=30,
batch_size=20,
validation_data=(validation_features, validation_labels))
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

博士僧小星
- 粉丝: 2560
最新资源
- 关于C语言跟踪调试方法.doc
- 基于PLC的转速测量.doc
- 财务管理:会计实务:Excel建立采购成本的分析表.pdf
- 工程特点、难点与项目管理重点.doc
- 区块链技术的实际应用场景.ppt
- 大数据环境下商业银行客户标签体系构建.doc
- 中国电子商务协会职业经理人认证机构合作协议(范本).doc
- 操作系统实验报告--实验一--进程管理.doc
- 编程题复习要点.doc
- 综合信息化业务合作协议.doc
- 财务管理:财务会计网络化的实施步骤.pdf
- 项目管理九大管理的输入、输出.doc
- 通信资源互换合作协议.doc
- 信息化建设项目管理办法.doc
- 工具变量-市场准入负面清单数据集(DID).xlsx
- PLC课后习题.doc
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈


