import os
import sys
import json
import time
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet

def main():
    # Use the first GPU if one is available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Data preprocessing
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),   # random crop to 224 x 224
                                     transforms.RandomHorizontalFlip(),   # random horizontal flip
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # must be (224, 224), not 224
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
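
    # Note: ToTensor scales pixel values into [0, 1]; Normalize with mean = std = 0.5 then
    # maps them to roughly [-1, 1] via (x - 0.5) / 0.5, which is why the commented-out
    # imshow helper further below un-normalizes with img / 2 + 0.5.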

    # Load data
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "SZ-EEG", "0-Dataset", "data_preprocess")  # preprocessed EEG data path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    # train dataset
    # dataset = datasets.DatasetFolder(root=os.path.join(image_path, "EEG_Dataset.npy"))
    # Label = datasets.DatasetFolder(root=os.path.join(image_path, 'EEG_Labels.npy'))
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),  # load dataset
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
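    # datasets.ImageFolder expects one subdirectory per class under the root, e.g.
    # data_preprocess/train/NC/... and data_preprocess/train/SZ/..., and assigns
    # integer labels from the alphabetically sorted folder names.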

    # {'NC': 0, 'SZ': 1}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())  # swap keys and values
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
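    # With the NC/SZ folder layout above, class_to_idx is {'NC': 0, 'SZ': 1}, so
    # class_indices.json contains {"0": "NC", "1": "SZ"} (JSON object keys are strings).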

    batch_size = 32
    # number of dataloader worker processes; use 0 on Windows, non-zero values work on Linux
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers per process'.format(nw))
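    # The worker count is bounded by the CPU count and the batch size, with a hard cap of 8;
    # with batch_size = 32 this usually resolves to min(os.cpu_count(), 8).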
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    # validate dataset
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("Using {} images for training, {} images for validation.".format(train_num,
                                                                            val_num))

    # To inspect the validation data, set batch_size=4 and shuffle=True in validate_loader
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=2, init_weights=True)
    net.to(device)  # move the network onto the GPU/CPU

    # define the loss function and the optimizer
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())  # for debugging: inspect the model parameters
    optimizer = optim.Adam(net.parameters(), lr=0.0002)
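    # nn.CrossEntropyLoss applies log-softmax internally, so the network is expected to
    # output raw (unnormalized) logits for the two classes (NC vs. SZ).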

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)

    for epoch in range(epochs):
        # train
        net.train()  # enable dropout: it is applied only during training
        running_loss = 0.0
        t1 = time.perf_counter()  # start time, used to measure how long one training epoch takes
        train_bar = tqdm(train_loader, file=sys.stdout)  # tqdm progress bar
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            # # manual progress bar printed with '*' and '.'
            # rate = (step + 1) / len(train_loader)  # len(train_loader): steps per epoch; rate: current progress
            # a = "*" * int(rate * 50)
            # b = "." * int((1 - rate) * 50)
            # print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
            # print()
            train_bar.desc = "train epoch[{}/{}] loss: {:.3f}".format(epoch + 1, epochs, loss)
        print(time.perf_counter() - t1)

        # validate
        net.eval()  # disable dropout during validation
        acc = 0.0  # accumulate the number of correct predictions per epoch
        with torch.no_grad():  # disable gradient tracking: no gradients are computed during validation
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]  # index of the largest output, i.e. the most likely class
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                # predict_y == val_labels is True (1) for a correct prediction, so acc
                # accumulates the number of correctly classified validation samples

        val_accurate = acc / val_num  # validation accuracy
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  # save the weights of the best model so far

    print('Finished Training')


if __name__ == '__main__':
    main()
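
A minimal inference sketch is given below for orientation. The package also ships a predict.py whose actual contents may differ; the file name test_image.png, the predict() helper, and the assumption that the AlexNet constructor can be called with num_classes alone are all illustrative, not taken from the package.

import json

import torch
from PIL import Image
from torchvision import transforms

from model import AlexNet


def predict(image_file="test_image.png", weights_path="./AlexNet.pth"):
    # hypothetical helper mirroring the roles of predict.py; names are assumptions
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # same preprocessing as the "val" transform used during training
    transform = transforms.Compose([transforms.Resize((224, 224)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # {"0": "NC", "1": "SZ"}, written by train.py
    with open("class_indices.json", "r") as f:
        cla_dict = json.load(f)

    # load and preprocess a single image, adding a batch dimension
    img = transform(Image.open(image_file).convert("RGB")).unsqueeze(0)

    # rebuild the network and load the trained weights saved by train.py
    net = AlexNet(num_classes=2)
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.to(device)
    net.eval()
    with torch.no_grad():
        output = net(img.to(device)).squeeze(0)
        probs = torch.softmax(output, dim=0)
        idx = torch.argmax(probs).item()
    print("class: {}  prob: {:.3f}".format(cla_dict[str(idx)], probs[idx].item()))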
Test2_AlexNet.zip contains 4 files: read_pth.py, predict.py, model.py, train.py.