import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import os
import torch.optim as optim
from model import MobileNetV2
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
image_path = "./data_set/" # flower data set path
train_dataset = datasets.ImageFolder(root=image_path+"train",
transform=data_transform["train"])
train_num = len(train_dataset)
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=6)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=0)
validate_dataset = datasets.ImageFolder(root=image_path + "val",
transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=batch_size, shuffle=False,
num_workers=0)
net = MobileNetV2(num_classes=7)
# load pretrain weights
model_weight_path = "./mobilenet_v2-b0353104.pth"
pre_weights = torch.load(model_weight_path)
# delete classifier weights
pre_dict = {k: v for k, v in pre_weights.items() if "classifier" not in k}
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)
# freeze features weights
for param in net.features.parameters():
param.requires_grad = False
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.00001)
best_acc = 0.0
save_path = './bestmodel.pth'
# save_path = './save/mobilenet_v2_1.4_224_'
for epoch in range(50):
# train
net.train()
running_loss = 0.0
for step, data in enumerate(train_loader, start=0):
images, labels = data
optimizer.zero_grad()
logits = net(images.to(device))
loss = loss_function(logits, labels.to(device))
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
# print train process
rate = (step+1)/len(train_loader)
a = "*" * int(rate * 50)
b = "." * int((1 - rate) * 50)
print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100), a, b, loss), end="")
print()
# validate
net.eval()
acc = 0.0 # accumulate accurate number / epoch
with torch.no_grad():
for val_data in validate_loader:
val_images, val_labels = val_data
outputs = net(val_images.to(device)) # eval model only have last output layer
# loss = loss_function(outputs, test_labels)
predict_y = torch.max(outputs, dim=1)[1]
acc += (predict_y == val_labels.to(device)).sum().item()
val_accurate = acc / val_num
if val_accurate > best_acc:
best_acc = val_accurate
torch.save(net.state_dict(), save_path)
print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
(epoch + 1, running_loss / step, val_accurate))
print('Finished Training')
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于深度学习卷积神经网络实现垃圾识别分类系统源码(含数据集+模型).zip 资源介绍: 1、建议安装anaconda和pycharm,项目在pycharm中运行。 2、在anaconda中安装必要的软件包。 3、程序源码可训练和测试。 【备注】 该项目是个人毕设项目源码,评审分达到95分,都经过严格调试,确保可以运行!放心下载使用。 该项目资源主要针对计算机、自动化等相关专业的学生或从业者下载使用,也可作为期末课程设计、课程大作业、毕业设计等。 具有较高的学习借鉴价值!基础能力强的可以在此基础上修改调整,以实现类似其他功能。
资源推荐
资源详情
资源评论
收起资源包目录
基于深度学习卷积神经网络实现垃圾识别分类系统源码(含数据集+模型).zip (192个子文件)
垃圾识别.iml 284B
vegetable (4).jpg 15KB
cigarette (5).jpg 14KB
cigarette (20).jpg 14KB
mask (15).JPG 13KB
vegetable (1).jpg 13KB
vegetable (1).jpg 12KB
vegetable (2).jpg 12KB
vegetable (7).jpg 12KB
battery (3).jpg 12KB
mask (10).JPG 12KB
vegetable (10).jpg 12KB
mask (18).JPG 12KB
vegetable (5).jpg 12KB
vegetable (9).jpg 12KB
mask (12).JPG 12KB
battery (11).jpg 12KB
vegetable (17).jpg 12KB
mask (20).JPG 12KB
vegetable (2).jpg 12KB
mask (7).JPG 11KB
vegetable (4).jpg 11KB
mask (5).JPG 11KB
mask (19).JPG 11KB
mask (16).JPG 11KB
vegetable (16).jpg 11KB
vegetable (3).jpg 11KB
vegetable (5).jpg 11KB
mask (4).JPG 11KB
vegetable (3).jpg 11KB
vegetable (19).jpg 11KB
vegetable (18).jpg 11KB
mask (6).JPG 11KB
vegetable (20).jpg 11KB
mask (11).JPG 10KB
mask (3).JPG 10KB
mask (2).JPG 10KB
vegetable (8).jpg 10KB
mask (9).JPG 10KB
mask (17).JPG 10KB
mask (1).JPG 10KB
vegetable (15).jpg 10KB
mask (14).JPG 10KB
vegetable (14).jpg 9KB
mask (1).JPG 9KB
vegetable (11).jpg 9KB
vegetable (12).jpg 9KB
vegetable (13).jpg 9KB
mask (5).JPG 9KB
mask (2).JPG 9KB
mask (3).JPG 9KB
vegetable (6).jpg 9KB
cigarette (19).jpg 9KB
injector (4).JPG 9KB
mask (13).JPG 8KB
battery (1).jpg 8KB
injector (16).JPG 8KB
battery (15).jpg 8KB
battery (16).jpg 8KB
injector (2).JPG 8KB
battery (14).jpg 8KB
injector (2).JPG 7KB
injector (20).JPG 7KB
injector (15).JPG 7KB
injector (17).JPG 7KB
injector (14).JPG 7KB
injector (18).JPG 7KB
battery (1).jpg 7KB
injector (1).JPG 7KB
battery (2).JPG 7KB
injector (3).JPG 7KB
injector (19).JPG 7KB
battery (5).jpg 7KB
battery (3).JPG 7KB
bottle (15).JPG 6KB
injector (7).JPG 6KB
battery (13).jpg 6KB
bottle (3).JPG 6KB
injector (11).JPG 6KB
mask (8).jpg 6KB
battery (5).JPG 6KB
swab (10).JPG 6KB
bottle (20).JPG 6KB
bottle (19).JPG 6KB
injector (8).JPG 6KB
battery (7).JPG 6KB
battery (8).JPG 6KB
battery (9).JPG 6KB
bottle (7).JPG 6KB
bottle (2).JPG 6KB
bottle (3).JPG 6KB
bottle (16).JPG 6KB
bottle (8).JPG 6KB
bottle (14).JPG 6KB
injector (9).JPG 6KB
bottle (4).JPG 6KB
swab (9).JPG 6KB
battery (2).jpg 6KB
battery (4).jpg 6KB
battery (10).JPG 6KB
共 192 条
- 1
- 2
资源评论
- dawdjisuagfh2024-03-28假的,内容假的z同学的编程之路2024-04-18什么假的?什么问题,你有问题嘛?上来就差评,这么无耻嘛
- m0_464017242024-04-25资源内容详细全面,与描述一致,对我很有用,有一定的使用价值。
- 2301_767428182024-02-27资源不错,内容挺好的,有一定的使用价值,值得借鉴,感谢分享。z同学的编程之路2024-04-18感谢认可
- stellajj_662023-10-29资源很受用,资源主总结的很全面,内容与描述一致,解决了我当下的问题。z同学的编程之路2024-04-18谢谢你的支持和认可~
z同学的编程之路
- 粉丝: 2356
- 资源: 2134
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功