import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from model import vgg
def main():
    """Train a VGG16 classifier on CIFAR-10 and checkpoint the best weights.

    Side effects:
      - writes ``class_indices.json`` (index -> class-name mapping)
      - saves the best model state dict to ``./vgg16Net.pth``

    Requires the CIFAR-10 archive to already exist under ``./data``
    (``download=False``) and the project-local ``model.vgg`` factory.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Convert images to tensors and normalize each RGB channel to ~[-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # download=False: assumes the dataset is already present under ./data.
    trainset = datasets.CIFAR10(root="./data", train=True, download=False,
                                transform=transform)
    train_num = len(trainset)

    # Invert {class_name: index} -> {index: class_name} and persist it so a
    # prediction script can map model outputs back to readable labels.
    class_to_idx = trainset.class_to_idx
    cla_dict = {idx: name for name, idx in class_to_idx.items()}
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # Worker count: bounded by CPU count, batch size, and a hard cap of 8.
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    # Fix: the computed worker count was previously ignored (hard-coded
    # num_workers=2); use it for both loaders so the printed value is honest.
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)

    testset = datasets.CIFAR10(root="./data", train=False, download=False,
                               transform=transform)
    val_num = len(testset)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    model_name = "vgg16"
    net = vgg(model_name=model_name, num_classes=10, init_weights=True)
    net.to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # ---- train one epoch ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # Accumulate the scalar loss for the per-epoch average.
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # ---- validate ----
        net.eval()
        acc = 0.0  # count of correctly classified validation images
        with torch.no_grad():
            val_bar = tqdm(testloader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Keep only the best-performing weights seen so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')
# Run training only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
Test3_vggnet.zip (21个子文件)
class_indices.json 182B
data
cifar-10-batches-py
data_batch_4 29.6MB
data_batch_2 29.6MB
data_batch_1 29.6MB
batches.meta 158B
data_batch_5 29.6MB
test_batch 29.6MB
readme.html 88B
data_batch_3 29.6MB
vgg16Net.pth 128.33MB
train.py 4KB
__pycache__
model.cpython-39.pyc 3KB
model.cpython-37.pyc 2KB
predict.py 2KB
.idea
Test3_vggnet.iml 328B
misc.xml 196B
modules.xml 276B
workspace.xml 4KB
.gitignore 47B
inspectionProfiles
profiles_settings.xml 174B
model.py 3KB
共 21 条
- 1
资源评论
板砖大师
- 粉丝: 0
- 资源: 1
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功