import torch
import torch.nn as nn
from matplotlib import pyplot as plt
import os
from Dataset.Dataset import csv_dataset
from module.model import MLPClassifier
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import numpy as np
import pandas as pd
import tqdm
from sklearn.model_selection import StratifiedKFold
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('[Info]: use:', device)
Data_dir = './Datas/Datas.csv'
label_dir = './Datas/labels.csv'
np_file_dir = './Datas/random_allnums.npy'
SAVEDIR = './runs'
ALL_ACC, ALL_SEN, ALL_PRE = 0, 0, 0
for K_flod in range(1, 6):
train_set = csv_dataset(Data_dir, label_dir, k=K_flod, train=True, path_np=np_file_dir)
test_set = csv_dataset(Data_dir, label_dir, k=K_flod, train=False, path_np=np_file_dir)
train_loader = DataLoader(dataset=train_set,
batch_size=8,
shuffle=True, )
test_loader = DataLoader(dataset=test_set,
batch_size=8,
shuffle=True, )
model = MLPClassifier()
model = model.to(device)
weights = torch.FloatTensor([2])
criterion = nn.BCEWithLogitsLoss(pos_weight=weights).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
loss_list = []
TP, TN, FP, FN = 0, 0, 0, 0
for epoch in range(15):
bar = tqdm.tqdm(train_loader)
for X, y in bar:
bar.set_description("epoch: %s" % str(epoch))
X = X.to(device, dtype=torch.float32)
y = y.to(device, dtype=torch.float32)
y_pred = model(X)
loss = criterion(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_list.append(loss.data.item())
bar.set_postfix(loss=loss.data.item())
torch.save(model.state_dict(), os.path.join(SAVEDIR, str(epoch)) + '.pt')
model.eval()
bar2 = tqdm.tqdm(test_loader)
for X, y in bar2:
X = X.to(device, dtype=torch.float32)
# y = y.to(device, dtype=torch.float32)
y_pred = model(X)
y_pred = F.sigmoid(y_pred)
y_pred[y_pred > 0.5] = 1
y_pred[y_pred <= 0.5] = 0
y_pred = torch.Tensor.cpu(y_pred)
sTP = np.array(((y == 1) & (y_pred == 1)).sum())
sFN = np.array(((y == 1) & (y_pred == 0)).sum())
sTN = np.array(((y == 0) & (y_pred == 0)).sum())
sFP = np.array(((y == 0) & (y_pred == 1)).sum())
TP = sTP + TP
TN = sTN + TN
FP = sFP + FP
FN = sFN + FN
print('TP:',TP)
print('TN:',TN)
print('FP:',FP)
print('FN:',FN)
acc = (TP + TN) / (TP + TN + FP + FN)
sen = (TP) / (FN + TP)
pre = (TP) / (TP + FP)
print('flod:' + str(K_flod) + ' acc:', acc)
print('flod:' + str(K_flod) + ' sen:', sen)
print('flod:' + str(K_flod) + ' pre:', pre)
ALL_ACC = ALL_ACC + acc
ALL_SEN = ALL_SEN + sen
ALL_PRE = ALL_PRE + pre
plt.plot(np.linspace(0, 100, len(loss_list)), loss_list)
plt.show()
print('Flod:',K_flod,' Finish!')
print('==========================')
print('ALL_ACC:', ALL_ACC / 5)
print('ALL_SEN:', ALL_SEN / 5)
print('ALL_PRE:', ALL_PRE / 5)
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip毕业设计基于python实现电信客户流失预测与分析源码.zip 【备注】 项目多为高分毕设,评审平均分达到95分以上,都经过本地验证,运行OK后上传,可直接运行起来。 主要针对计算机相关专业的正在做毕设的学生和需要项目实战的Java、JavaScript、c#、游戏开发、小程序开发学习者、深度学习等专业方向。 也可作为课程设计、期末大作业。包含:项目源码、数据库、项目说明等,该项目可以直接作为毕设、课程设计使用。 也可以用来学习参考借鉴!
资源推荐
资源详情
资源评论
收起资源包目录
毕业设计基于python实现电信客户流失预测与分析源码.zip (37个子文件)
train_bf.py 2KB
runs
6.pt 99KB
11.pt 99KB
14.pt 99KB
1.pt 99KB
18.pt 99KB
2.pt 99KB
4.pt 99KB
5.pt 99KB
19.pt 99KB
7.pt 99KB
8.pt 99KB
12.pt 99KB
0.pt 99KB
10.pt 99KB
15.pt 99KB
16.pt 99KB
9.pt 99KB
3.pt 99KB
13.pt 99KB
17.pt 99KB
utils
__init__.py 0B
loss_func.py 2KB
.idea
misc.xml 196B
Mining.iml 284B
inspectionProfiles
profiles_settings.xml 174B
modules.xml 264B
.gitignore 182B
encodings.xml 157B
module
__init__.py 0B
model.py 1KB
data_process.zip 2KB
train.py 3KB
data_process
op_1to2.py 1KB
op_2.py 1KB
op_0to1.py 1KB
down_dim.py 513B
共 37 条
- 1
资源评论
- 2401_848371042024-05-11有没有分析结果呀onnx2024-05-11有的呢
- 2301_775139922024-01-07资源不错,很实用,内容全面,介绍详细,很好用,谢谢分享。onnx2024-05-11不客气,对你有用就好
- 2301_801361292023-12-27资源不错,对我启发很大,获得了新的灵感,受益匪浅。onnx2024-05-11谢谢支持和对项目的认可~
onnx
- 粉丝: 9971
- 资源: 5626
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 自动化应用驱动的容器弹性管理平台解决方案
- 各种排序算法 Python 实现的源代码
- BlurAdmin 是一款使用 AngularJs + Bootstrap实现的单页管理端模版,视觉冲击极强的管理后台,各种动画效果
- 基于JSP+Servlet的网上书店系统源代码项目包含全套技术资料.zip
- GGJGJGJGGDGGDGG
- 基于SpringBoot的毕业设计选题系统源代码项目包含全套技术资料.zip
- Springboot + mybatis-plus + layui 实现的博客系统源代码全套技术资料.zip
- 智慧农场小程序源代码全套技术资料.zip
- 大数据技术毕业设计源代码全套技术资料.zip
- renren-ui-nodejs安装及环境配置
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功