import torch
import pandas as pd
from torch import nn, optim
import numpy as np
import torch.utils.data.dataloader as DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
from main import IrisDataSet
from main.Net import ClasifyNet
trainPath = "./datasets/iris.csv"
testPath = "./datasets/iris_test.csv"
epoches = 200
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
torch.nn.init.normal_(m.weight.data, mean=0, std=1)
def compute_accuracy(x, y):
z = x == y
length = z.shape[0]
length = torch.tensor(length).float()
print("length = {}".format(length))
summ = z.sum(dim=0).float()
print("summ = {}".format(summ))
return summ / length
if __name__ == '__main__':
# net = ClasifyNet()
# net.apply(weights_init)
net = torch.load("./test.pth")
criterion = nn.CrossEntropyLoss(size_average = True)
optimizer = optim.Adam(net.parameters(), lr=0.0001)
# 数据准备
train_dataset = IrisDataSet(trainPath)
train_data_loader = DataLoader.DataLoader(train_dataset, batch_size=150, shuffle=True, num_workers=4)
index = 1
num_x = []
num_y = []
for epoch in range(epoches):
for i, item in enumerate(train_data_loader):
optimizer.zero_grad()
data, label = item
data = Variable(data.float())
label = Variable(label.long())
y = net(data)
y_temp = y.argmax(dim=1)
accuracy = compute_accuracy(label, y_temp)
print("y = {}".format(accuracy))
loss = criterion(y, label)
loss.backward()
optimizer.step()
if i == 0:
print("losss = {}".format(loss.item()))
num_x.append(index)
index += 1
num_y.append(loss)
plt.plot(num_x, num_y)
plt.show()
torch.save(net, f='./test.pth')
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
在机器学习领域中,“鸢尾花”是指一个经典的多类分类问题的数据集,称为“Iris dataset”或“安德森鸢尾花卉数据集”。该数据集最早由英国统计学家兼生物学家罗纳德·费雪(Ronald Fisher)于1936年收集并整理发表,包含了150个样本观测值,对应三种不同类型的鸢尾花(Setosa、Versicolor、Virginica),每种类型各50个样本。 每个样本有四个特征: 萼片长度(Sepal Length) 萼片宽度(Sepal Width) 花瓣长度(Petal Length) 花瓣宽度(Petal Width) 这些特征都是连续数值型变量,而目标变量则是鸢尾花所属的类别。鸢尾花数据集常被用作新手入门机器学习算法时的第一个实践项目,因为它数据量适中且易于理解,同时适用于多种监督学习算法,如逻辑回归、K近邻(KNN)、支持向量机(SVM)、决策树以及各种集成方法等。
资源推荐
资源详情
资源评论
收起资源包目录
鸢尾花识别.zip (13个子文件)
content
.idea
misc.xml 288B
inspectionProfiles
profiles_settings.xml 174B
modules.xml 280B
deployment.xml 359B
.gitignore 38B
flower_clafier.iml 453B
main
__init__.py 1KB
eval.py 981B
Net.py 636B
datasets
iris_test.csv 1014B
iris.csv 4KB
iris_clasifier.py 2KB
test.pth 28KB
共 13 条
- 1
资源评论
生瓜蛋子
- 粉丝: 3794
- 资源: 4173
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- ItemApplicationTest.java
- 个人发卡源码,发卡系统,二次元发卡系统,二次元发卡源码,发卡程序,动漫发卡,PHP发卡源码,异次元发卡
- 基于matlab 决策树分类器的应用研究-乳腺癌诊断源代码+详细教程
- 2008全国电子设计竞赛优秀作品报告doc文档.zip
- 课程智能组卷系统 JAVA+Spring+SpringMVC+MyBatis
- 基于matlab LVQ神经网络的预测-人脸朝向识别源代码+详细教程
- Controlnet敏神大佬IC-Light的AI智能打光 AI这次真的大地震了
- 医院电子病历管理系统 JAVA+Spring+SpringMVC+MyBatis
- 基于matlab LVQ神经网络的分类-乳腺肿瘤诊断源代码+详细教程
- 【C#/.NET/.NET Core学习、工作、面试指南】记录、收集和总结
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功