import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
import argparse,os
from create_h5py import create_file
from model import build_model
from dataset import load_datasets_h5py,load_datasets_readImage
device = "cuda" if torch.cuda.is_available() else "cpu"
def default_argument_parser():
parser = argparse.ArgumentParser(description="pytorch-dataset-study")
parser.add_argument('--test',action="store_true",help="only test the model")
parser.add_argument('-M','--model-type', default=0,type=int)
parser.add_argument('-L','--load-type', default=0,type=int)
return parser
def change_dim(pic):
'''change dimension from [C H W] to [H W C]'''
return pic.permute(1,2,0)
def main(args):
model = build_model(args.model_type)
if args.load_type == 0:
if not os.path.exists('train.hdf5') or not os.path.exists('test.hdf5'):
create_file()
training_dataloader ,test_dataloader = load_datasets_h5py()
print('using h5py method to load')
else:
training_dataloader ,test_dataloader = load_datasets_readImage()
print('using readImage method to load')
optimizer = optim.SGD(model.parameters(),lr= 1e-4,momentum=0.5)
loss_fn = nn.CrossEntropyLoss()
EPOCH = 100
loss_all = []
if not args.test:
print('start training')
for epoch in range(EPOCH):
print(f'\n-----------epoch {epoch}-----------')
loss = train(model,training_dataloader,optimizer,loss_fn,epoch=epoch)
loss_all.append(loss)
test(model,test_dataloader)
plt.plot(loss_all)
plt.savefig(f"model_weights/{model.__class__.__name__}.png")
plt.show()
plt.close()
torch.save(model.state_dict(), f"model_weights/{model.__class__.__name__}.pth")
print("Saved PyTorch Model State to model.pth")
model = build_model(args.model_type)
model.load_state_dict(torch.load(f"model_weights/{model.__class__.__name__}.pth"))
labels = {0:'bird',1:'flower'}
model.eval()
plt.figure(figsize=(8, 4))
for id,data in enumerate(test_dataloader):
if isinstance(data,list):
image = data[0].type(torch.FloatTensor).to(device)
#target = data[1].to(device)
elif isinstance(data,dict):
image = data['image'].type(torch.FloatTensor).to(device)
#target = data['target'].to(device)
else :
raise TypeError
plt.title("image-show")
with torch.no_grad():
output =nn.Softmax(dim=1)(model(image))
pred = output.argmax(dim = 1).cpu().numpy()
plt.ion()
for i in range(1,5):
plt.subplot(1,4,i)
plt.title(labels[pred[i-1]])
plt.imshow(change_dim(image[i-1].cpu()))
plt.pause(3)
plt.show()
def train(model,train_dataloader,optimizer,loss_fn,epoch):
model.train()
loss_total = 0
for _, data in enumerate(train_dataloader):
if isinstance(data,list):
image = data[0].type(torch.FloatTensor).to(device)
target = data[1].to(device)
elif isinstance(data,dict):
image = data['image'].type(torch.FloatTensor).to(device)
target = data['target'].to(device)
else :
print(type(data))
raise TypeError
#print(target)
optimizer.zero_grad()
output = model(image)
#print(output)
loss = loss_fn(output, target)
loss_total+=loss.item()
loss.backward()
optimizer.step()
#exit(0)
print(f'{round(loss_total,2)} in epoch {epoch}')
return loss_total
def test(model,test_dataloader):
model.eval()
correct = 0
for _ , data in enumerate(test_dataloader):
if isinstance(data,list):
image = data[0].type(torch.FloatTensor).to(device)
target = data[1].to(device)
elif isinstance(data,dict):
image = data['image'].type(torch.FloatTensor).to(device)
target = data['target'].to(device)
else :
raise TypeError
with torch.no_grad():
output = model(image)
pred = nn.Softmax(dim=1)(output)
correct += (pred.argmax(1) == target).type(torch.float).sum().item()
print(f'accurency = {correct}/{len(test_dataloader)*4} = {correct/len(test_dataloader)/4}')
if __name__ == "__main__":
args = default_argument_parser().parse_args()
main(args)
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于pytorch使用神经网络完成 鸟花 二分类问题(包含模型训练、样本数据等全套内容).zip 1、该资源内项目代码经过严格调试,下载即用确保可以运行! 2、该资源适合计算机相关专业(如计科、人工智能、大数据、数学、电子信息等)正在做课程设计、期末大作业和毕设项目的学生、或者相关技术学习者作为学习资料参考使用。 3、该资源包括全部源码,需要具备一定基础才能看懂并调试代码。 基于pytorch使用神经网络完成 鸟花 二分类问题(包含模型训练、样本数据等全套内容).zip 1、该资源内项目代码经过严格调试,下载即用确保可以运行! 2、该资源适合计算机相关专业(如计科、人工智能、大数据、数学、电子信息等)正在做课程设计、期末大作业和毕设项目的学生、或者相关技术学习者作为学习资料参考使用。 3、该资源包括全部源码,需要具备一定基础才能看懂并调试代码。 基于pytorch使用神经网络完成 鸟花 二分类问题(包含模型训练、样本数据等全套内容).zip 1、该资源内项目代码经过严格调试,下载即用确保可以运行! 2、该资源适合计算机相关专业(如计科、人工智能、大数据、数学、电子信息等)正在做课程
资源推荐
资源详情
资源评论
收起资源包目录
基于pytorch使用神经网络完成 鸟花 二分类问题(包含模型训练、样本数据等全套内容).zip (67个子文件)
code_111230
create_h5py.py 3KB
dif-transform.py 1KB
model_weights
resnet18-5c106cde.pth 44.66MB
main.py 5KB
pytorch-tutorial.py 3KB
dataset.py 3KB
model.py 2KB
training-set
bird
6368df61bc6be51ed776b7faa6b82570.jpeg 105KB
47ef371e4226ec6bba5b8092926a9bf0.jpeg 61KB
71d9a4c1bf90904cceb0edc98c53b8a8.jpeg 14KB
73ee32e532f932a4bc34a16684df6f1c.jpeg 86KB
7f6879f668629a75457dd93feefc6209.jpeg 24KB
58f42d7c5abc1eedf16e61c21772ab67.jpeg 28KB
46fd4de58c2266223e83a21370027180.jpeg 71KB
2db22000281fc3cf7f87bb6ac6c0ceed.jpeg 26KB
515c9cfdb994892d0d55e5518cb4a22d.jpeg 58KB
34c6580b9b0c38854480e547f01d0795.jpeg 57KB
29e71ce6aaebc9f06beffa9d9f01a21b.jpeg 34KB
717ca03970624b0c459426bc188de4cb.jpeg 95KB
01791f41a341270bbdb1b4c65eabfc41.jpeg 287KB
3e621180c51d70fa9dad8c427655367a.jpeg 99KB
0a78cb587f941d59489befa89285bb63.jpeg 24KB
6da78b30fc5a1dc9dc3cc0d6e451baee.jpeg 130KB
8bbde8be67c46cb07eab262224afeabe.jpeg 107KB
4f23da4fbf1f79be772ed4057938cbfd.jpeg 38KB
246f43a069b6fef6577129bc6fa2e9cb.jpeg 31KB
130b9bfb1cd594ba7b12dae9df90f8a2.jpeg 87KB
flower
62ef9e022166f33bfad7401d6d9c8dbb.jpeg 32KB
0ecee7747d3388bd929b19df526fe7ab.jpeg 44KB
9c3b079c3318c70031ba67381d00fc4a.jpeg 34KB
9892491da39cd63bb0c8444871e4da18.jpeg 20KB
bff530f2deef67d18b34db7ddd4f61ee.jpeg 88KB
bca392d862b6c4dde9ba1b7ae73a1b6a.jpeg 56KB
17e2ad2fed14d6d0b8d6999245c9ba4a.jpeg 53KB
8fb8d5004c26bd8edd2646f0ee8b90bf.jpeg 37KB
bd9e1b9c4df5993554a61fb611c7df94.jpeg 43KB
821c4074003640f103794a462c5abe8a.jpeg 62KB
a0a369ae79108320343e81b8cffebd0f.jpeg 63KB
0754a441ffed0d494036e9f82ccc883c.jpeg 65KB
71ebd1a8848170a22abea3d7c5d9e5b3.jpeg 43KB
0e49e6273e41648fb7a4f1b804cc63a2.jpeg 111KB
bd4dcd9b6e5a891c8a7c16a678de3851.jpeg 41KB
400fa3440ecbb83f4a2f9e0330904570.jpeg 75KB
bcb3596678a42d2058f4d634ec756ac1.jpeg 74KB
aa8e9760b2267d8ead1a4fae74a22e04.jpeg 49KB
38eb0ef9ef0b8289b3da07b5aecaa761.jpeg 501KB
0bd85256a37087c086813bc7f5cd4398.jpeg 53KB
test-set
bird
a87ff6bb8fec05c9527676a844dec788.jpeg 371KB
e439e2d54c2b6911c769beb5cfcf28a8.jpeg 71KB
d9e2a7551da3dac10066d42140d7f571.jpeg 26KB
c6977f982380fb4f86ba730102c75406.jpeg 63KB
ddd3b055c5d3077e6ade6046b42c57bf.jpeg 23KB
612706af4fb052215447c466472e1394.jpeg 29KB
ba86e9c05865bde7014db50b9c22e48a.jpeg 63KB
66138d1941439c5bc1cbf1c94a19aa0f.jpeg 93KB
b295ba3fbc48e225353780b07adf0802.jpeg 403KB
f295daf76a073e41284276fcaf03c334.jpeg 245KB
flower
d0c88e3de35faead0d28aeb8e6cc8e27.jpeg 40KB
cb5ee06378c5ca07be5cf85f3575c704.jpeg 54KB
e3413f158c38ff457a6484cadfcd0c1e.jpeg 46KB
ed1c7f9559cb2ee623f9a53d2d9f9238.jpeg 77KB
c86f4be37f0096c53096db9d23e45863.jpeg 231KB
cb0b9d22bc770958731afa3170f45d95.jpeg 33KB
d0432f1c83ad65afd7b3a4d55d850d92.jpeg 18KB
c7b0ba6f0649fc55132b25d34e50f65f.jpeg 44KB
f13d062ad33b794f9a14c0279f0a295d.jpeg 53KB
c6bfa10e741759b76bbae64a77dbbf83.jpeg 53KB
共 67 条
- 1
资源评论
辣椒种子
- 粉丝: 4063
- 资源: 5733
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功