#!/usr/bin/python
# -*- coding:utf-8 -*-
import logging
import os
import time
import warnings
import torch
from torch import nn
from torch import optim
import models
import AE_Datasets
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
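
# Training utilities for autoencoder-based fault diagnosis: setup() builds the
# datasets, the encoder/decoder/classifier models, the losses, and the
# optimizers; train() runs the training loop, logging metrics to TensorBoard.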
def SAEloss(recon_x, x, z):
    """
    Sparse autoencoder loss: reconstruction error plus a KL-divergence
    sparsity penalty on the mean latent activation.
    recon_x: reconstructed signals
    x: original signals
    z: latent code (pre-activation)
    """
    reconstruction_function = nn.MSELoss()
    recon_loss = reconstruction_function(recon_x, x)
    pmean = 0.5  # target mean activation
    p = torch.sigmoid(z)  # F.sigmoid is deprecated in favor of torch.sigmoid
    p = torch.mean(p, 1)  # mean activation of each sample's latent units
    # KL divergence between Bernoulli(pmean) and Bernoulli(p), summed over the batch
    KLD = pmean * torch.log(pmean / p) + (1 - pmean) * torch.log((1 - pmean) / (1 - p))
    KLD = torch.sum(KLD, 0)
    return recon_loss + KLD
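
# Example use of SAEloss (shapes are illustrative assumptions, not taken from
# the datasets in this repository): for a batch of 64 signals of length 1024
# encoded into a 64-dimensional latent code,
#   recon_x, x: tensors of shape (64, 1024); z: tensor of shape (64, 64)
#   loss = SAEloss(recon_x, x, z)
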
class train_utils(object):
def __init__(self, args, save_dir):
self.args = args
self.save_dir = save_dir
    def setup(self):
        """
        Initialize the datasets, models, losses, and optimizers.
        """
args = self.args
# Consider the gpu or cpu condition
if torch.cuda.is_available():
self.device = torch.device("cuda")
self.device_count = torch.cuda.device_count()
logging.info('using {} gpus'.format(self.device_count))
            assert args.batch_size % self.device_count == 0, "batch size should be divisible by the device count"
else:
warnings.warn("gpu is not available")
self.device = torch.device("cpu")
self.device_count = 1
logging.info('using {} cpu'.format(self.device_count))
# Load the datasets
if args.processing_type == 'O_A':
from AE_Datasets.O_A import datasets
Dataset = getattr(datasets, args.data_name)
elif args.processing_type == 'R_A':
from AE_Datasets.R_A import datasets
Dataset = getattr(datasets, args.data_name)
elif args.processing_type == 'R_NA':
from AE_Datasets.R_NA import datasets
Dataset = getattr(datasets, args.data_name)
else:
raise Exception("processing type not implement")
self.datasets = {}
self.datasets['train'], self.datasets['val'] = Dataset(args.data_dir, args.normlizetype).data_preprare()
        self.dataloaders = {x: torch.utils.data.DataLoader(self.datasets[x], batch_size=args.batch_size,
                                                           shuffle=(x == 'train'),
                                                           num_workers=args.num_workers,
                                                           # self.device is a torch.device, so compare its
                                                           # .type, not the device object, with 'cuda'
                                                           pin_memory=(self.device.type == 'cuda'))
                            for x in ['train', 'val']}
# Define the model
        fmodel = getattr(models, args.model_name)
self.encoder = getattr(fmodel, 'encoder')(in_channel=Dataset.inputchannel, out_channel=Dataset.num_classes)
self.decoder = getattr(fmodel, 'decoder')(in_channel=Dataset.inputchannel,
out_channel=Dataset.num_classes)
self.classifier = getattr(fmodel, 'classifier')(in_channel=Dataset.inputchannel,
out_channel=Dataset.num_classes)
        # Define the optimizer for the autoencoder stage (encoder + decoder)
if args.opt == 'sgd':
self.optimizer = optim.SGD([{'params': self.encoder.parameters()}, {'params': self.decoder.parameters()}],
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif args.opt == 'adam':
self.optimizer = optim.Adam([{'params': self.encoder.parameters()}, {'params': self.decoder.parameters()}],
lr=args.lr, weight_decay=args.weight_decay)
else:
raise Exception("optimizer not implement")
# Define the learning rate decay
if args.lr_scheduler == 'step':
steps = [int(step) for step in args.steps.split(',')]
self.lr_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, steps, gamma=args.gamma)
elif args.lr_scheduler == 'exp':
self.lr_scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, args.gamma)
elif args.lr_scheduler == 'stepLR':
steps = int(args.steps)
self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, steps, args.gamma)
elif args.lr_scheduler == 'fix':
self.lr_scheduler = None
else:
raise Exception("lr schedule not implement")
        # Define the optimizer for the classification stage (encoder + classifier)
if args.opt == 'sgd':
self.optimizer1 = optim.SGD([{'params': self.encoder.parameters()}, {'params': self.classifier.parameters()}],
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif args.opt == 'adam':
self.optimizer1 = optim.Adam([{'params': self.encoder.parameters()}, {'params': self.classifier.parameters()}],
lr=args.lr, weight_decay=args.weight_decay)
else:
raise Exception("optimizer not implement")
# Define the learning rate decay
if args.lr_scheduler == 'step':
steps1 = [int(step) for step in args.steps1.split(',')]
self.lr_scheduler1 = optim.lr_scheduler.MultiStepLR(self.optimizer1, steps1, gamma=args.gamma)
elif args.lr_scheduler == 'exp':
self.lr_scheduler1 = optim.lr_scheduler.ExponentialLR(self.optimizer1, args.gamma)
elif args.lr_scheduler == 'stepLR':
steps1 = int(args.steps1)
self.lr_scheduler1 = optim.lr_scheduler.StepLR(self.optimizer1, steps1, args.gamma)
elif args.lr_scheduler == 'fix':
self.lr_scheduler1 = None
else:
raise Exception("lr schedule not implement")
self.start_epoch = 0
        # Move the models to the device and define the losses
self.encoder.to(self.device)
self.decoder.to(self.device)
self.classifier.to(self.device)
self.criterion = nn.CrossEntropyLoss()
self.criterion1 = nn.MSELoss()
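
    # Typical driver (a minimal sketch; the entry script and its argparse
    # flags live outside this file, so these names are assumptions based on
    # how the attributes are read above):
    #   trainer = train_utils(args, save_dir)
    #   trainer.setup()
    #   trainer.train(SummaryWriter(save_dir))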
    def train(self, writer):
        """
        Training process.
        writer: a SummaryWriter used to log per-epoch metrics
        """
args = self.args
step = 0
best_acc = 0.0
batch_count = 0
batch_loss = 0.0
batch_acc = 0
step_start = time.time()
        training_acc = []
        testing_acc = []
        training_loss = []
        testing_loss = []
print("Training Autoencoder with minimum loss")
for epoch in range(args.middle_epoch):
logging.info('-'*5 + 'Epoch {}/{}'.format(epoch, args.middle_epoch - 1) + '-'*5)
            # Log the current learning rate; the scheduler itself is stepped
            # once per epoch, after the train/val phases below
            if self.lr_scheduler is not None:
                logging.info('current lr: {}'.format(self.lr_scheduler.get_last_lr()))
            else:
                logging.info('current lr: {}'.format(args.lr))
# Each epoch has a training and val phase
for phase in ['train', 'val']:
# Define the temp variable
epoch_start = time.time()
epoch_loss = 0.0
# Set model to train mode or test mode
if phase == 'train':
self.encoder.train()
self.decoder.train()
else:
self.encoder.eval()
self.decoder.eval()
for batch_idx, (inputs, labels) in enumerate(self.dataloaders[phase]):
inputs = inputs.to(self.device)
                # Forward pass; gradients are disabled during the val phase to save memory
                with torch.set_grad_enabled(phase == 'train'):
                    # forward
if args.model_name in ["Vae1d", "Vae2d"]:
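                        # Variational autoencoder branch. Assumption: the exact
                        # encoder/decoder signatures live in the models package
                        # (not shown here); this sketch assumes the encoder
                        # returns a mean and a log-variance and samples the
                        # latent code with the reparameterization trick.
                        mu, logvar = self.encoder(inputs)
                        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
                        recon = self.decoder(z)
                        # ELBO: reconstruction error plus KL(q(z|x) || N(0, I))
                        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                        loss = self.criterion1(recon, inputs) + kld
                    elif args.model_name in ["Sae1d", "Sae2d"]:
                        # Sparse autoencoder branch; the "Sae1d"/"Sae2d" names
                        # are assumptions that mirror the VAE naming above.
                        z = self.encoder(inputs)
                        recon = self.decoder(z)
                        loss = SAEloss(recon, inputs, z)
                    else:
                        # Plain autoencoder: reconstruction loss only
                        z = self.encoder(inputs)
                        recon = self.decoder(z)
                        loss = self.criterion1(recon, inputs)

                    epoch_loss += loss.item() * inputs.size(0)

                    # Backward pass and parameter update in the training phase only
                    if phase == 'train':
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

            epoch_loss = epoch_loss / len(self.dataloaders[phase].dataset)
            logging.info('Epoch {} {}-Loss: {:.4f}, Cost {:.1f} sec'.format(
                epoch, phase, epoch_loss, time.time() - epoch_start))
            writer.add_scalar('ae-{}-loss'.format(phase), epoch_loss, epoch)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        # The classification stage (encoder + classifier trained with
        # self.optimizer1 / self.criterion, tracking best_acc and the
        # accuracy/loss lists above) follows the same loop structure; this
        # sketch covers only the autoencoder pretraining stage.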