from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from model import vgg
import tensorflow as tf
import json
import os
data_root = os.path.abspath(os.path.join(os.getcwd(), "..")) # get data root path
image_path = data_root + "/flower_data/" # flower data set path
train_dir = image_path + "train"
validation_dir = image_path + "val"
# create direction for saving weights
if not os.path.exists("save_weights"):
os.makedirs("save_weights")
im_height = 224
im_width = 224
batch_size = 32
epochs = 10
# data generator with data augmentation
train_image_generator = ImageDataGenerator(rescale=1. / 255,
horizontal_flip=True)
validation_image_generator = ImageDataGenerator(rescale=1. / 255)
train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
batch_size=batch_size,
shuffle=True,
target_size=(im_height, im_width),
class_mode='categorical')
total_train = train_data_gen.n
# get class dict
class_indices = train_data_gen.class_indices
# transform value and key of dict
inverse_dict = dict((val, key) for key, val in class_indices.items())
# write dict into json file
json_str = json.dumps(inverse_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
val_data_gen = train_image_generator.flow_from_directory(directory=validation_dir,
batch_size=batch_size,
shuffle=True,
target_size=(im_height, im_width),
class_mode='categorical')
total_val = val_data_gen.n
model = vgg("vgg16", 224, 224, 5)
model.summary()
# using keras high level api for training
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
metrics=["accuracy"])
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myAlex_{epoch}.h5',
save_best_only=True,
save_weights_only=True,
monitor='val_loss')]
# tensorflow2.1 recommend to using fit
history = model.fit(x=train_data_gen,
steps_per_epoch=total_train // batch_size,
epochs=epochs,
validation_data=val_data_gen,
validation_steps=total_val // batch_size,
callbacks=callbacks)
没有合适的资源?快使用搜索试试~ 我知道了~
Test3_vgg_VGG网络的源码_VGG源码_源码
共11个文件
xml:4个
py:3个
pyc:1个
1.该资源内容由用户上传,如若侵权请联系客服进行举报
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
版权申诉
0 下载量 61 浏览量
2021-10-02
01:39:22
上传
评论
收藏 6KB RAR 举报
温馨提示
这个是python的VGG源码,在TensorFlow下的
资源推荐
资源详情
资源评论
收起资源包目录
Test3_vgg.rar (11个子文件)
Test3_vgg
class_indices.json 108B
train.py 3KB
__pycache__
model.cpython-36.pyc 2KB
predict.py 841B
.idea
Test3_vgg.iml 338B
misc.xml 199B
modules.xml 277B
workspace.xml 5KB
.gitignore 50B
inspectionProfiles
profiles_settings.xml 174B
model.py 2KB
共 11 条
- 1
资源评论
浊池
- 粉丝: 48
- 资源: 4783
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- Pytorch-pytorch深度学习教程之逻辑回归.zip
- Pytorch-pytorch深度学习教程之双向循环网络.zip
- Pytorch-pytorch深度学习教程之卷积神经网络.zip
- Pytorch-pytorch深度学习教程之前馈神经网络.zip
- Pytorch-pytorch深度学习教程之线性回归.zip
- Pytorch-pytorch深度学习教程之基本操作.zip
- 基于QT的地图可视化桌面系统后台数据库为MySQL5.7源码.zip
- 基于simulink的PLL锁相环系统仿真【包括模型,文档,参考文献,操作步骤】
- 基于EM-GMM模型的目标跟踪和异常行为检测matlab仿真【包括程序,注释,参考文献,操作步骤,说明文档】
- 2109010044_胡晨燕_选课管理数据库设计与实现.prj
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功