from __future__ import print_function
import glob
from itertools import chain
import os
import random
import zipfile
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from vit_pytorch.vit import ViT
# Training settings: mini-batch size, number of epochs, initial learning
# rate, decay factor handed to StepLR, and the global RNG seed.
batch_size, epochs = 64, 10
lr, gamma = 3e-5, 0.7
seed = 42

# Log the PyTorch build (and the CUDA version it was compiled against)
# together with whether a CUDA device is visible at runtime.
print(torch.version.cuda)
print(torch.__version__)
print("cuda:", torch.cuda.is_available())
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED``, and configures cuDNN so that kernel
    selection is deterministic between runs.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python interpreters started after this point; the hash
    # seed of the already-running process is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Both calls are silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not sufficient: with benchmark=True cuDNN
    # auto-tunes and may choose different (non-deterministic) algorithms.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything(seed)
# Fall back to the CPU so the script still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
os.makedirs('data', exist_ok=True)

# Images are globbed from <split>/cats/*.jpg and <split>/dogs/*.jpg.
train_dir = './data/train'
test_dir = './data/test'
# glob() returns files in filesystem-dependent order; sorting makes the
# seeded train/validation split below reproducible across machines.
train_list1 = sorted(glob.glob(os.path.join(train_dir, 'cats', '*.jpg')))
train_list2 = sorted(glob.glob(os.path.join(train_dir, 'dogs', '*.jpg')))
train_list = train_list1 + train_list2
test_list1 = sorted(glob.glob(os.path.join(test_dir, 'cats', '*.jpg')))
test_list2 = sorted(glob.glob(os.path.join(test_dir, 'dogs', '*.jpg')))
test_list = test_list1 + test_list2
print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

# The class name is the file-name prefix before the first '.', e.g. "dog"
# for "dog.123.jpg". os.path.basename (unlike split('/')) also understands
# the separator os.path.join inserts on Windows.
labels = [os.path.basename(path).split('.')[0] for path in train_list]
# Hold out 20% of the training images for validation, stratified so both
# parts keep the same cat/dog ratio.
train_list, valid_list = train_test_split(train_list,
                                          test_size=0.2,
                                          stratify=labels,
                                          random_state=seed)
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")
# Training-time augmentation: resize to 224x224, take a random resized crop
# back to 224, randomly mirror horizontally, then convert to a tensor.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Validation and test share the same deterministic recipe (resize the
# shorter side to 256, keep the central 224x224 patch, convert to a tensor);
# each still gets its own independent Compose object.
val_transforms, test_transforms = (
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    for _ in range(2)
)
class CatsDogsDataset(Dataset):
    """Map-style dataset of cat/dog images labelled from their file names.

    The label is derived from the file-name prefix before the first '.':
    ``"dog"`` maps to 1, every other prefix maps to 0.

    Args:
        file_list: Indexable sequence of image paths.
        transform: Optional callable applied to the loaded RGB PIL image.
            When ``None`` the PIL image itself is returned.
    """

    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        # ``filelength`` is refreshed here, as before, for any caller that
        # reads the attribute directly.
        self.filelength = len(self.file_list)
        return self.filelength

    @staticmethod
    def label_from_path(img_path):
        """Return 1 if the file-name prefix before the first '.' is "dog", else 0."""
        # basename() handles the platform's path separator, unlike split('/').
        prefix = os.path.basename(img_path).split(".")[0]
        return 1 if prefix == "dog" else 0

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``."""
        img_path = self.file_list[idx]
        # convert() loads the pixel data into a new image, so the file handle
        # can be closed right away. Forcing RGB gives every sample the same
        # 3-channel layout even if a JPEG is grayscale or CMYK.
        with Image.open(img_path) as img:
            img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label_from_path(img_path)
# Each split uses its own pipeline (the validation split previously reused
# test_transforms, leaving val_transforms unused).
train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=val_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)

# Only training needs shuffling. Evaluation loaders keep a fixed order so
# validation/test metrics are reproducible from run to run.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)
print(len(train_data), len(train_loader))
print(len(valid_data), len(valid_loader))
# NOTE(review): image_size=256, but every transform in this script emits
# 224x224 images. The forward pass only works if the installed vit_pytorch
# slices its positional embedding to the actual patch count - confirm for
# your version. Switching to 224 would likely change the positional-embedding
# shape and break loading of existing 'best.mdl' checkpoints, so it is left
# unchanged here.
model = ViT(
    image_size=256,
    patch_size=16,
    num_classes=2,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1,
).to(device)
print(model)

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler: multiplies the learning rate by `gamma` on every call to
# scheduler.step() (step_size=1).
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

best_accuracy = 0
# Warm-start from a previous run's best weights if present. map_location
# lets a checkpoint saved on one device (e.g. GPU) load on the current one.
# NOTE(review): best_accuracy restarts at 0, so the first epoch of a resumed
# run always overwrites 'best.mdl', even if it is worse than the saved one.
if os.path.exists('best.mdl'):
    model.load_state_dict(torch.load('best.mdl', map_location=device))
for epoch in range(epochs):
    # ---- train ----
    # Training mode enables the model's dropout layers.
    model.train()
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain float. Summing the loss tensor itself would
        # chain every batch's autograd graph into the running total.
        batch_n = label.size(0)
        train_loss_sum += loss.item() * batch_n
        train_correct += (output.argmax(dim=1) == label).sum().item()
        train_seen += batch_n
    # Per-sample averages, so a smaller final batch is not over-weighted.
    epoch_loss = train_loss_sum / max(train_seen, 1)
    epoch_accuracy = train_correct / max(train_seen, 1)

    # ---- validation ----
    # Eval mode disables dropout so validation scores are not perturbed.
    model.eval()
    with torch.no_grad():
        val_loss_sum = 0.0
        val_correct = 0
        val_seen = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            batch_n = label.size(0)
            val_loss_sum += val_loss.item() * batch_n
            val_correct += (val_output.argmax(dim=1) == label).sum().item()
            val_seen += batch_n
        epoch_val_loss = val_loss_sum / max(val_seen, 1)
        epoch_val_accuracy = val_correct / max(val_seen, 1)

    # Keep only the weights with the best validation accuracy seen so far.
    if epoch_val_accuracy > best_accuracy:
        best_accuracy = epoch_val_accuracy
        torch.save(model.state_dict(), 'best.mdl')

    # Apply the StepLR decay once per epoch (after the optimizer steps).
    # Without this call the scheduler and `gamma` have no effect.
    scheduler.step()

    print(
        f"Epoch : {epoch + 1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
# ---- test ----
# Evaluate the weights with the best validation accuracy, if any were saved.
if os.path.exists('best.mdl'):
    model.load_state_dict(torch.load('best.mdl', map_location=device))
else:
    print("best.mdl not found; evaluating the current model weights")
print('best_acc:', best_accuracy)
# Eval mode disables dropout for the test pass.
model.eval()
with torch.no_grad():
    test_loss_sum = 0.0
    test_correct = 0
    test_seen = 0
    for data, label in test_loader:
        data = data.to(device)
        label = label.to(device)
        test_output = model(data)
        # Separate name for the per-batch loss, so the running total is
        # not overwritten on every iteration.
        batch_loss = criterion(test_output, label)
        batch_n = label.size(0)
        test_loss_sum += batch_loss.item() * batch_n
        test_correct += (test_output.argmax(dim=1) == label).sum().item()
        test_seen += batch_n
    test_loss = test_loss_sum / max(test_seen, 1)
    test_accuracy = test_correct / max(test_seen, 1)
print('test_acc:', test_accuracy)
print('test_loss:', test_loss)
没有合适的资源?快使用搜索试试~ 我知道了~
Vision Transformer(ViT)实践项目,图像分类任务,“猫狗大战”(猫狗分类)
共2000个文件
jpg:1998个
py:2个
需积分: 0 22 下载量 71 浏览量
2024-02-25
17:38:22
上传
评论 4
收藏 218.41MB RAR 举报
温馨提示
利用ViT模型实现图像分类,本项目具有强大的泛化能力,可以实现任何图像分类任务,只需要修改数据集和类别数目参数。这里采用的是开源的“猫狗大战”数据集,实现猫狗分类。 本项目适用于Transformer初学者,通过该实践项目可以对ViT模型的原理和结构有清晰的认识,并且可以学会如何在具体项目中运用ViT模型。本项目代码逻辑结构清晰,通俗易懂,适用于各种基础的学习者,是入门深度学习和了解Transformer注意力机制在计算机视觉中的运用的绝佳项目。
资源推荐
资源详情
资源评论
收起资源包目录
Vision Transformer(ViT)实践项目,图像分类任务,“猫狗大战”(猫狗分类) (2000个子文件)
cat_or_dog_2.jpg 602KB
cat_or_dog_1.jpg 134KB
dog.3085.jpg 85KB
dog.2274.jpg 79KB
dog.1870.jpg 75KB
dog.3076.jpg 73KB
dog.157.jpg 71KB
dog.1039.jpg 68KB
dog.1201.jpg 65KB
dog.1132.jpg 65KB
dog.1408.jpg 64KB
dog.3782.jpg 62KB
dog.1976.jpg 62KB
dog.55.jpg 62KB
dog.3292.jpg 62KB
dog.3551.jpg 61KB
dog.3071.jpg 61KB
dog.735.jpg 61KB
dog.355.jpg 61KB
dog.1817.jpg 61KB
dog.2649.jpg 60KB
dog.1560.jpg 60KB
dog.811.jpg 60KB
dog.1073.jpg 59KB
dog.1860.jpg 59KB
dog.2112.jpg 57KB
dog.890.jpg 57KB
dog.265.jpg 56KB
dog.3444.jpg 55KB
dog.1460.jpg 55KB
dog.3950.jpg 55KB
dog.886.jpg 55KB
dog.769.jpg 55KB
dog.2763.jpg 54KB
dog.3152.jpg 54KB
dog.2249.jpg 54KB
dog.1948.jpg 54KB
dog.2317.jpg 54KB
dog.2100.jpg 53KB
dog.823.jpg 53KB
dog.898.jpg 53KB
dog.2152.jpg 53KB
dog.1471.jpg 53KB
dog.1769.jpg 52KB
dog.3095.jpg 52KB
dog.641.jpg 52KB
dog.2355.jpg 52KB
dog.2571.jpg 51KB
dog.2728.jpg 51KB
dog.3172.jpg 51KB
dog.491.jpg 51KB
dog.437.jpg 51KB
dog.2698.jpg 51KB
dog.2197.jpg 51KB
dog.1101.jpg 51KB
dog.2499.jpg 51KB
dog.3829.jpg 51KB
dog.2634.jpg 50KB
dog.223.jpg 50KB
dog.2856.jpg 50KB
dog.1636.jpg 50KB
dog.3814.jpg 50KB
dog.896.jpg 50KB
dog.2026.jpg 50KB
dog.744.jpg 50KB
dog.80.jpg 50KB
dog.3080.jpg 50KB
dog.3175.jpg 49KB
dog.3006.jpg 49KB
dog.98.jpg 49KB
dog.952.jpg 49KB
dog.3762.jpg 49KB
dog.2273.jpg 49KB
dog.602.jpg 49KB
dog.902.jpg 49KB
dog.736.jpg 48KB
dog.1146.jpg 48KB
dog.3933.jpg 48KB
dog.3363.jpg 48KB
dog.837.jpg 48KB
dog.1678.jpg 48KB
dog.2055.jpg 47KB
dog.2935.jpg 47KB
dog.2061.jpg 47KB
dog.1466.jpg 47KB
dog.1222.jpg 47KB
dog.8.jpg 47KB
dog.1594.jpg 47KB
dog.2507.jpg 47KB
dog.3470.jpg 46KB
dog.2122.jpg 46KB
dog.3629.jpg 46KB
dog.1405.jpg 46KB
dog.1981.jpg 46KB
dog.286.jpg 46KB
dog.682.jpg 46KB
dog.3312.jpg 46KB
dog.546.jpg 46KB
dog.3064.jpg 45KB
dog.3420.jpg 45KB
共 2000 条
- 1
- 2
- 3
- 4
- 5
- 6
- 20
资源评论
拉普拉斯妖的宇宙
- 粉丝: 89
- 资源: 1
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- hdmi in视频采集,输出到hdmi out, 支持HDMI指令控制,支持TFTP远程下载图片
- 批量word文件内容替换工具1.0 (批量实现多个 Word 文档文件文字替换利器).exe
- Cartoon GUI Pack 1.2.zip
- 【数据集和代码】基于加速度传感器的步态识别行人分类实验(可做步态识别)
- 我分享个魔兽内存修改器
- Python毕业设计基于Django的网易云数据分析可视化大屏系统的设计与实现+使用说明+全部资料(优秀项目).zip
- mp3 idv2,idv1,frame分析工具
- 鹈鹕优化算法POA MATLAB源码, 应用案例为函数极值求解以及优化svm进行分类,代码注释详细,可结合自身需求进行应用
- Python毕业设计基于Django的网易云数据分析可视化大屏系统的设计与实现+使用说明+全部资料(高分项目).zip
- 蛇优化算法SO MATLAB源码, 应用案例为函数极值求解以及优化svm进行分类,代码注释详细,可结合自身需求进行应用
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功