# flower_photos: 5种花朵原始图片集(雏菊、蒲公英、玫瑰、向日葵、郁金香)
# config.py 配置文件将存储驱动程序脚本中使用的重要变量和参数。与其在每个脚本中重新定义它们只需在这里定义一次(从而使代码更干净、更容易阅读)
# create_dataloader.py help函数,Dataloader加载flower_photos
# output/ 存放训练损失图及模型
# build_dataset.py 根据flower_photos目录构建数据集目录,将创建特殊的子目录来存储训练和验证拆分,允许PyTorch的ImageFolder脚本来解析目录并训练模型
# train_feature_extraction.py 执行特征提取的迁移学习,并把模型存储磁盘
# fine_tune.py 执行基于微调的迁移学习,并把模型存储磁盘
# inference.py 接受经过训练的PyTorch模型,并使用它对输入的花朵图像进行预测
# 要实现的第一种迁移学习方法是特征提取
# 通过特征提取进行迁移学习的工作原理如下:
# 采用预先训练的CNN(通常在ImageNet数据集上),从CNN上卸下FC(Fully Connection)层头,将网络主体的输出视为空间维度为M×N×C的任意特征提取器
# 分类器有俩个选择:
# 采用标准的逻辑回归分类器(如scikit学习库中的分类器),并根据每个图像中提取的特征对其进行训练。或者,更简单地说,将softmax分类器放在网络主体的顶部,
# 任何一种选择都是可行的,而且或多或少与另一种“相同”。
# 当提取的特征数据集适合机器的RAM时,第一个选项非常有效。这样可以加载整个数据集,实例化逻辑回归分类器模型的一个实例,然后对其进行训练。
# 当数据集太大而无法放入机器内存时,就会出现问题。当这种情况发生时,你可以使用类似在线学习的方法来训练你的逻辑回归分类器,但这只是引入了另一组库和依赖项。
# 相反,更容易的是利用PyTorch的强大功能,在提取的特征之上创建一个类似逻辑回归的分类器,然后使用PyTorch函数对其进行训练。
# 训练特征提取模型,执行该脚本后,将在输出目录中找到一个名为warmup_model.pth的文件——该文件是序列化PyTorch模型,然后可以用于在inference.py脚本中进行预测。
# 总的训练时间只有5分钟多一点,获得了84.26%的训练准确率和87.74%的验证准确率。
# USAGE
# python train_feature_extraction.py
# 导入必要的包
from pyimagesearch import config
from pyimagesearch import create_dataloaders # 从输入数据集目录创建PyTorch DataLoader的实例
from imutils import paths
from torchvision.models import resnet50 # 要使用的ImageNet的预训练模型
from torchvision import transforms # 允许定义一组预处理和/或数据增强,将依次应用于输入图像
from tqdm import tqdm # 用于创建格式良好的进度条的Python库
from torch import nn # 包含PyTorch的神经网络类和函数
import matplotlib.pyplot as plt
import numpy as np
import torch # 包含PyTorch的神经网络类和函数
import time
# 定义增强管道(使用Compose函数构建数据处理/扩充步骤,该函数位于PyTorch的transforms子模块中。
# 首先创建一个trainTransform,在给定输入图像的情况下,它将:
# 随机调整图像大小并将其裁剪为image_SIZE尺寸
# 随机执行水平翻转
# 在[-90,90]范围内随机执行旋转
# 将生成的图像转换为PyTorch张量
# 执行平均值减法和缩放,同样的用于验证数据集的 valTransform
# 请注意,我们不在验证转换器中执行数据扩充——没有必要对验证数据执行数据扩充。
# 创建了训练和验证Compose对象后,让我们应用get_datalader函数:)
trainTansform = transforms.Compose([
transforms.RandomResizedCrop(config.IMAGE_SIZE),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(90),
transforms.ToTensor(),
transforms.Normalize(mean=config.MEAN, std=config.STD)
])
valTransform = transforms.Compose([
transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=config.MEAN, std=config.STD)
])
# 创建DataLoader
(trainDS, trainLoader) = create_dataloaders.get_dataloader(config.TRAIN,
transforms=trainTansform,
batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE)
(valDS, valLoader) = create_dataloaders.get_dataloader(config.VAL,
transforms=valTransform,
batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)
# 通过特征提取为迁移学习准备ResNet50模型
# 加载预训练的ImageNet ResNet50 model
model = resnet50(pretrained=True)
# 由于使用ResNet50模型作为特征提取器,设置其参数为不可训练(默认情况下是可训练的)
for param in model.parameters():
param.requires_grad = False
# 将一个新的分类顶部附加到我们的特征提取器并弹出它,连接到当前设备
# 创建一个由单个FC层组成的新FC层头。实际上当使用分类交叉熵损失进行训练时,这一层将作为代理softmax分类器。
# 然后,这个新层被附加到网络主体,模型本身被移动到设备(CPU或GPU)。
modelOutputFeats = model.fc.in_features
model.fc = nn.Linear(modelOutputFeats, len(trainDS.classes))
model = model.to(config.DEVICE)
# 接下来,初始化损失函数和优化方法(注意只是向优化器提供分类顶部的参数)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.fc.parameters(), lr=config.LR)
# 计算训练/验证集的每一个纪元步数
trainSteps = len(trainDS) // config.FEATURE_EXTRACTION_BATCH_SIZE
valSteps = len(valDS) // config.FEATURE_EXTRACTION_BATCH_SIZE
# 初始化字典以存储训练历史
H = {"train_loss": [], "train_acc": [], "val_loss": [],
"val_acc": []}
# 遍历纪元
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.EPOCHS)):
# 设置模型训练模式
model.train()
# 初始化训练/验证损失
totalTrainLoss = 0
totalValLoss = 0
# 初始化训练/验证集中的预测正确个数
trainCorrect = 0
valCorrect = 0
# 遍历训练集
# 对于trainLoader中的每一批数据,将图像和类标签移动到CPU/GPU、对数据进行预测、计算损失,计算梯度,更新模型权重,并将梯度归零
# 累积在该时期的总训练损失、计算正确预测的总数
for (i, (x, y)) in enumerate(trainLoader):
# 传递输入到设备
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
# 向前传递并计算训练损失
pred = model(x)
loss = lossFunc(pred, y)
# 计算损失梯度
loss.backward()
# 检查是否正在更新模型参数,如果是 更新它们,并将之前累积的梯度清零
if (i + 2) % 2 == 0:
opt.step()
opt.zero_grad()
# 将损失加上迄今为止的总训练损失,同样累加正确预测的数量
totalTrainLoss += loss
trainCorrect += (pred.argmax(1) == y).type(
torch.float).sum().item()
# 关闭autograd并将模型置于评估模式中——这是使用PyTorch进行评估时的要求
# switch off autograd
with torch.no_grad():
# 设置模型为评估模式
model.eval()
# 在valLoader中循环所有数据点,对它们进行预测,并计算总损失和正确验证预测的数量。
# 遍历验证集
for (x, y) in valLoader:
# �
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
flower_photos: 5种花朵原始图片集(雏菊、蒲公英、玫瑰、向日葵、郁金香) config.py 配置文件将存储驱动程序脚本中使用的重要变量和参数。与其在每个脚本中重新定义它们只需在这里定义一次(从而使代码更干净、更容易阅读) create_dataloader.py help函数,Dataloader加载flower_photos output/ 存放训练损失图及模型 build_dataset.py 根据flower_photos目录构建数据集目录,将创建特殊的子目录来存储训练和验证拆分,允许PyTorch的ImageFolder脚本来解析目录并训练模型 train_feature_extraction.py 执行特征提取的迁移学习,并把模型存储磁盘 fine_tune.py 执行基于微调的迁移学习,并把模型存储磁盘 inference.py 接受经过训练的PyTorch模型,并使用它对输入的花朵图像进行预测
资源推荐
资源详情
资源评论
收起资源包目录
使用PyTorch执行特征提取和微调的迁移学习来进行图像分类 (749个子文件)
2431737309_1468526f8b.jpg 275KB
2431737309_1468526f8b.jpg 275KB
8717900362_2aa508e9e5.jpg 260KB
8717900362_2aa508e9e5.jpg 260KB
244074259_47ce6d3ef9.jpg 237KB
244074259_47ce6d3ef9.jpg 237KB
3711892138_b8c953fdc1_z.jpg 227KB
3711892138_b8c953fdc1_z.jpg 227KB
5874818796_3efbb8769d.jpg 224KB
5874818796_3efbb8769d.jpg 224KB
2816503473_580306e772.jpg 223KB
2816503473_580306e772.jpg 223KB
3254533919_cb0b8af26c.jpg 223KB
3254533919_cb0b8af26c.jpg 223KB
10094731133_94a942463c.jpg 220KB
10094731133_94a942463c.jpg 220KB
8695372372_302135aeb2.jpg 215KB
8695372372_302135aeb2.jpg 215KB
3568925290_faf7aec3a0.jpg 215KB
3568925290_faf7aec3a0.jpg 215KB
2256230386_08b54ca760.jpg 211KB
2256230386_08b54ca760.jpg 211KB
466486216_ab13b55763.jpg 210KB
466486216_ab13b55763.jpg 210KB
155097272_70feb13184.jpg 208KB
155097272_70feb13184.jpg 208KB
13471273823_4800ca8eec.jpg 202KB
13471273823_4800ca8eec.jpg 202KB
530738000_4df7e4786b.jpg 201KB
530738000_4df7e4786b.jpg 201KB
310380634_60e6c79989.jpg 191KB
310380634_60e6c79989.jpg 191KB
7176729016_d73ff2211e.jpg 187KB
7176729016_d73ff2211e.jpg 187KB
6116210027_61923f4b64.jpg 184KB
6116210027_61923f4b64.jpg 184KB
2256214682_130c01d9d9.jpg 180KB
2256214682_130c01d9d9.jpg 180KB
8713396140_5af8136136.jpg 179KB
8713396140_5af8136136.jpg 179KB
175638423_058c07afb9.jpg 179KB
175638423_058c07afb9.jpg 179KB
5674125303_953b0ecf38.jpg 178KB
5674125303_953b0ecf38.jpg 178KB
4263272885_1a49ea5209.jpg 178KB
4263272885_1a49ea5209.jpg 178KB
4932144003_cbffc89bf0.jpg 173KB
4932144003_cbffc89bf0.jpg 173KB
155646858_9a8b5e8fc8.jpg 172KB
155646858_9a8b5e8fc8.jpg 172KB
8174935013_b16626b49b.jpg 172KB
8174935013_b16626b49b.jpg 172KB
11614202956_1dcf1c96a1.jpg 170KB
11614202956_1dcf1c96a1.jpg 170KB
6207492986_0ff91f3296.jpg 169KB
6207492986_0ff91f3296.jpg 169KB
4414084638_03d2db38ae.jpg 166KB
4414084638_03d2db38ae.jpg 166KB
15333843782_060cef3030.jpg 162KB
15333843782_060cef3030.jpg 162KB
3858508462_db2b9692d1.jpg 162KB
3858508462_db2b9692d1.jpg 162KB
864957037_c75373d1c5.jpg 159KB
864957037_c75373d1c5.jpg 159KB
184682506_8a9b8c662d.jpg 158KB
184682506_8a9b8c662d.jpg 158KB
4496277750_8c34256e28.jpg 147KB
4496277750_8c34256e28.jpg 147KB
7166646966_41d83cd703.jpg 146KB
7166646966_41d83cd703.jpg 146KB
9535500195_543d0b729b.jpg 144KB
9535500195_543d0b729b.jpg 144KB
13513851673_9d813dc7b0.jpg 142KB
13513851673_9d813dc7b0.jpg 142KB
3231873181_faf2da6382.jpg 139KB
3231873181_faf2da6382.jpg 139KB
4666648087_b10f376f19.jpg 138KB
4666648087_b10f376f19.jpg 138KB
14044685976_0064faed21.jpg 138KB
14044685976_0064faed21.jpg 138KB
4586018734_6de9c513c2.jpg 137KB
4586018734_6de9c513c2.jpg 137KB
138166590_47c6cb9dd0.jpg 135KB
138166590_47c6cb9dd0.jpg 135KB
4933822272_79af205b94.jpg 135KB
4933822272_79af205b94.jpg 135KB
12163418275_bd6a1edd61.jpg 134KB
12163418275_bd6a1edd61.jpg 134KB
9300335851_cdf1cef7a9.jpg 131KB
9300335851_cdf1cef7a9.jpg 131KB
13561966423_e5c641fe11.jpg 131KB
13561966423_e5c641fe11.jpg 131KB
251811158_75fa3034ff.jpg 130KB
251811158_75fa3034ff.jpg 130KB
3496258301_ca5f168306.jpg 130KB
3496258301_ca5f168306.jpg 130KB
10466558316_a7198b87e2.jpg 127KB
10466558316_a7198b87e2.jpg 127KB
5598845098_13e8e9460f.jpg 125KB
5598845098_13e8e9460f.jpg 125KB
共 749 条
- 1
- 2
- 3
- 4
- 5
- 6
- 8
资源评论
- YJJSg2024-01-15资源中能够借鉴的内容很多,值得学习的地方也很多,大家一起进步!
- null_code2023-10-12资源有一定的参考价值,与资源描述一致,很实用,能够借鉴的部分挺多的,值得下载。
- m0_689114842024-03-20资源不错,对我启发很大,获得了新的灵感,受益匪浅。
- hzxiaohong282024-05-01资源内容详实,描述详尽,解决了我的问题,受益匪浅,学到了。
程序媛一枚~
- 粉丝: 4w+
- 资源: 30
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功