import torch
import torch.nn as nn
from matplotlib import pyplot as plt
import os
from Dataset.Dataset import csv_dataset
from module.model import MLPClassifier
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import numpy as np
import pandas as pd
import tqdm
from sklearn.model_selection import StratifiedKFold
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('[Info]: use:', device)
Data_dir = './Datas/Datas.csv'
label_dir = './Datas/labels.csv'
np_file_dir = './Datas/random_allnums.npy'
SAVEDIR = './runs'
ALL_ACC, ALL_SEN, ALL_PRE = 0, 0, 0
for K_flod in range(1, 6):
train_set = csv_dataset(Data_dir, label_dir, k=K_flod, train=True, path_np=np_file_dir)
test_set = csv_dataset(Data_dir, label_dir, k=K_flod, train=False, path_np=np_file_dir)
train_loader = DataLoader(dataset=train_set,
batch_size=8,
shuffle=True, )
test_loader = DataLoader(dataset=test_set,
batch_size=8,
shuffle=True, )
model = MLPClassifier()
model = model.to(device)
weights = torch.FloatTensor([2])
criterion = nn.BCEWithLogitsLoss(pos_weight=weights).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
loss_list = []
TP, TN, FP, FN = 0, 0, 0, 0
for epoch in range(15):
bar = tqdm.tqdm(train_loader)
for X, y in bar:
bar.set_description("epoch: %s" % str(epoch))
X = X.to(device, dtype=torch.float32)
y = y.to(device, dtype=torch.float32)
y_pred = model(X)
loss = criterion(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_list.append(loss.data.item())
bar.set_postfix(loss=loss.data.item())
torch.save(model.state_dict(), os.path.join(SAVEDIR, str(epoch)) + '.pt')
model.eval()
bar2 = tqdm.tqdm(test_loader)
for X, y in bar2:
X = X.to(device, dtype=torch.float32)
# y = y.to(device, dtype=torch.float32)
y_pred = model(X)
y_pred = F.sigmoid(y_pred)
y_pred[y_pred > 0.5] = 1
y_pred[y_pred <= 0.5] = 0
y_pred = torch.Tensor.cpu(y_pred)
sTP = np.array(((y == 1) & (y_pred == 1)).sum())
sFN = np.array(((y == 1) & (y_pred == 0)).sum())
sTN = np.array(((y == 0) & (y_pred == 0)).sum())
sFP = np.array(((y == 0) & (y_pred == 1)).sum())
TP = sTP + TP
TN = sTN + TN
FP = sFP + FP
FN = sFN + FN
print('TP:',TP)
print('TN:',TN)
print('FP:',FP)
print('FN:',FN)
acc = (TP + TN) / (TP + TN + FP + FN)
sen = (TP) / (FN + TP)
pre = (TP) / (TP + FP)
print('flod:' + str(K_flod) + ' acc:', acc)
print('flod:' + str(K_flod) + ' sen:', sen)
print('flod:' + str(K_flod) + ' pre:', pre)
ALL_ACC = ALL_ACC + acc
ALL_SEN = ALL_SEN + sen
ALL_PRE = ALL_PRE + pre
plt.plot(np.linspace(0, 100, len(loss_list)), loss_list)
plt.show()
print('Flod:',K_flod,' Finish!')
print('==========================')
print('ALL_ACC:', ALL_ACC / 5)
print('ALL_SEN:', ALL_SEN / 5)
print('ALL_PRE:', ALL_PRE / 5)
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
数据挖掘大作业-基于python的电信客户流失预测与分析源码+模型文件.zip 这个项目是一个数据挖掘项目,旨在预测和分析电信客户的流失情况。 主要功能点 数据预处理和清洗 使用深度学习模型(如FCN)进行客户流失预测 对预测结果进行分析和可视化 技术栈 Python PyTorch - 不懂运行,下载完可以私聊问,可远程教学 该资源内项目源码是个人的毕设,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! <项目介绍> 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.md文件(如有),仅供学习参考, 切勿用于商业用途。 --------
资源推荐
资源详情
资源评论
收起资源包目录
数据挖掘大作业-基于python的电信客户流失预测与分析源码+模型文件.zip (36个子文件)
train_bf.py 2KB
runs
6.pt 99KB
11.pt 99KB
14.pt 99KB
1.pt 99KB
18.pt 99KB
2.pt 99KB
4.pt 99KB
5.pt 99KB
19.pt 99KB
7.pt 99KB
8.pt 99KB
12.pt 99KB
0.pt 99KB
10.pt 99KB
15.pt 99KB
16.pt 99KB
9.pt 99KB
3.pt 99KB
13.pt 99KB
17.pt 99KB
utils
__init__.py 0B
loss_func.py 2KB
.idea
misc.xml 196B
Mining.iml 284B
inspectionProfiles
profiles_settings.xml 174B
modules.xml 264B
.gitignore 182B
encodings.xml 157B
module
__init__.py 0B
model.py 1KB
train.py 3KB
data_process
op_1to2.py 1KB
op_2.py 1KB
op_0to1.py 1KB
down_dim.py 513B
共 36 条
- 1
资源评论
Scikit-learn
- 粉丝: 4815
- 资源: 3181
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 人和箱子检测2-YOLO(v5至v11)、COCO、CreateML、Paligemma、TFRecord、VOC数据集合集.rar
- 清华大学2022年秋季学期 高等数值分析课程报告
- GEE错误集-Cannot add an object of type <Element> to the map. Might be fixable with an explicit .pdf
- 清华大学2022年秋季学期 高等数值分析课程报告
- 矩阵与线程的对应关系图
- 人体人员检测46-YOLO(v5至v9)、COCO、Darknet、TFRecord数据集合集.rar
- GEMM优化代码实现1
- java实现的堆排序 含代码说明和示例.docx
- 资料阅读器(先下载解压) 5.0.zip
- 人、垃圾、非垃圾检测18-YOLO(v5至v11)、COCO、CreateML、Paligemma、TFRecord、VOC数据集合集.rar
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功