from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save,image_read_cv2
import warnings
import logging
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
import cv2
# CDDFuse medical-image-fusion test script (colour-output variant).
#
# For every dataset "<A>_<B>" (A = MRI, B = CT/PET/SPECT) and every
# checkpoint: fuse each image pair, save the result under
# test_result/<dataset>/, then print the mean of eight fusion metrics.
# Modality B may be a colour image. Its luminance (Y) is fused with the
# grayscale modality-A image, and its chroma (Cr/Cb) is re-attached to the
# fused Y before saving, which is what gives the saved results their colour.

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

CDDFuse_path = r"models/CDDFuse_IVF.pth"
CDDFuse_MIF_path = r"models/CDDFuse_MIF.pth"

DATASET_NAMES = ["MRI_CT", "MRI_PET", "MRI_SPECT"]
# The IVF checkpoint's decoder receives the sum of both inputs as its first
# argument; every other checkpoint's decoder receives None there instead.
CKPT_PATHS = [CDDFuse_path, CDDFuse_MIF_path]
TEST_ROOT = "test_img"
RESULT_ROOT = "test_result"

# True: save fused Y + modality-B chroma as a colour image.
# False: save only the fused grayscale image (the former single-channel mode).
SAVE_COLOR = True


def _stem(file_name):
    """Return *file_name* without its final extension ("a.b.png" -> "a.b")."""
    return os.path.splitext(file_name)[0]


def _list_images(folder):
    """Return the sorted names of the regular files in *folder*.

    Sorting makes the processing order deterministic; sub-directories are
    skipped because they cannot be read as images.
    """
    return sorted(
        name for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )


def _load_models(ckpt_path, device):
    """Build the four CDDFuse sub-networks and load their weights.

    The checkpoint file is read once and mapped onto *device*, so a
    checkpoint saved on GPU can also be loaded on a CPU-only machine.

    Args:
        ckpt_path: path of a checkpoint dict holding the keys
            'DIDF_Encoder', 'DIDF_Decoder', 'BaseFuseLayer', 'DetailFuseLayer'.
        device: 'cuda' or 'cpu'.

    Returns:
        (encoder, decoder, base_fuse, detail_fuse), all switched to eval mode.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    # Weights are loaded into the DataParallel wrappers, so the stored keys
    # are expected to carry DataParallel's "module." prefix.
    encoder = nn.DataParallel(Restormer_Encoder()).to(device)
    decoder = nn.DataParallel(Restormer_Decoder()).to(device)
    base_fuse = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
    detail_fuse = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
    encoder.load_state_dict(checkpoint['DIDF_Encoder'])
    decoder.load_state_dict(checkpoint['DIDF_Decoder'])
    base_fuse.load_state_dict(checkpoint['BaseFuseLayer'])
    detail_fuse.load_state_dict(checkpoint['DetailFuseLayer'])
    for module in (encoder, decoder, base_fuse, detail_fuse):
        module.eval()
    return encoder, decoder, base_fuse, detail_fuse


def _to_input(image_2d, device):
    """Wrap a 2-D image as a 1x1xHxW float tensor scaled by 1/255 on *device*.

    Presumably *image_2d* holds values in [0, 255] (as read by image_read_cv2),
    giving a tensor in [0, 1]. TODO confirm against image_read_cv2.
    """
    return torch.FloatTensor(image_2d[np.newaxis, np.newaxis, ...] / 255.0).to(device)


def _fuse(models, data_a, data_b, add_residual):
    """Fuse two single-channel input tensors with CDDFuse.

    Args:
        models: tuple returned by _load_models.
        data_a, data_b: 1x1xHxW input tensors.
        add_residual: pass data_a + data_b to the decoder (IVF checkpoint)
            instead of None.

    Returns:
        The fused image as a squeezed uint8 array, min-max stretched to 0..255.
    """
    encoder, decoder, base_fuse, detail_fuse = models
    base_a, detail_a, _ = encoder(data_a)
    base_b, detail_b, _ = encoder(data_b)
    fused_base = base_fuse(base_a + base_b)
    fused_detail = detail_fuse(detail_a + detail_b)
    residual = data_a + data_b if add_residual else None
    fused, _ = decoder(residual, fused_base, fused_detail)
    # Min-max stretch to [0, 1]; the clamped denominator turns a constant
    # output into zeros instead of NaN.
    low, high = torch.min(fused), torch.max(fused)
    fused = (fused - low) / torch.clamp(high - low, min=1e-8)
    fused = np.squeeze((fused * 255).cpu().numpy())
    # Round rather than truncate so the 8-bit image is not biased downward.
    return np.clip(np.round(fused), 0, 255).astype(np.uint8)


def _read_chroma(path):
    """Return the (Cr, Cb) uint8 planes of the image at *path*.

    cv2.imread loads the file as 3-channel BGR, so a grayscale source
    (e.g. CT) still yields valid, neutral chroma planes.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"cannot read image: {path}")
    _, cr, cb = cv2.split(cv2.cvtColor(bgr, cv2.COLOR_BGR2YCrCb))
    return cr, cb


def _fuse_dataset(models, dataset_name, test_folder, out_folder, device, add_residual):
    """Fuse every image pair of one dataset and save the results.

    Images are enumerated from the modality-A folder; the same file name is
    expected to exist in the modality-B folder. Each result is saved via
    img_save under its source file's stem.
    """
    parts = dataset_name.split('_')
    dir_a = os.path.join(test_folder, parts[0])
    dir_b = os.path.join(test_folder, parts[1])
    with torch.no_grad():
        for img_name in _list_images(dir_a):
            path_b = os.path.join(dir_b, img_name)
            gray_a = image_read_cv2(os.path.join(dir_a, img_name), mode='GRAY')
            # Only the luminance of modality B goes through the network.
            y_b = cv2.split(image_read_cv2(path_b, mode='YCrCb'))[0]
            fused_y = _fuse(models, _to_input(gray_a, device), _to_input(y_b, device), add_residual)
            if SAVE_COLOR:
                cr, cb = _read_chroma(path_b)
                # NOTE(review): RGB order assumes img_save expects RGB input
                # (not OpenCV's BGR) — confirm against utils.img_read_save.
                output = cv2.cvtColor(np.dstack((fused_y, cr, cb)), cv2.COLOR_YCrCb2RGB)
            else:
                output = fused_y
            img_save(output, _stem(img_name), out_folder)


def _evaluate(dataset_name, test_folder, fused_folder):
    """Return the per-image mean of the eight metrics for one dataset.

    All three images are read in GRAY mode. The metric order matches the
    printed header: EN, SD, SF, MI, SCD, VIF, Qabf, SSIM. The mean is taken
    over the images actually evaluated; an empty dataset yields NaNs.
    """
    parts = dataset_name.split('_')
    img_names = _list_images(os.path.join(test_folder, parts[0]))
    if not img_names:
        return np.full(8, np.nan)
    totals = np.zeros(8)
    for img_name in img_names:
        # Existing argument order kept: modality B fills the "ir" slot and
        # modality A the "vi" slot of the Evaluator calls.
        ir = image_read_cv2(os.path.join(test_folder, parts[1], img_name), 'GRAY')
        vi = image_read_cv2(os.path.join(test_folder, parts[0], img_name), 'GRAY')
        fi = image_read_cv2(os.path.join(fused_folder, _stem(img_name) + ".png"), 'GRAY')
        totals += np.array([
            Evaluator.EN(fi), Evaluator.SD(fi),
            Evaluator.SF(fi), Evaluator.MI(fi, ir, vi),
            Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi),
            Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi),
        ])
    return totals / len(img_names)


def main():
    """Fuse and evaluate every dataset with every checkpoint, printing a table."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Load each checkpoint once instead of once per dataset.
    models_by_ckpt = {path: _load_models(path, device) for path in CKPT_PATHS}
    for dataset_name in DATASET_NAMES:
        print("\n" * 2 + "=" * 80)
        print("The test result of " + dataset_name + " :")
        print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
        test_folder = os.path.join(TEST_ROOT, dataset_name)
        # Both checkpoints write into the same folder, so it ends up holding
        # the last checkpoint's images. Each checkpoint is evaluated right
        # after it writes, so every printed row belongs to its own checkpoint.
        test_out_folder = os.path.join(RESULT_ROOT, dataset_name)
        for ckpt_path in CKPT_PATHS:
            model_name = _stem(os.path.basename(ckpt_path))
            _fuse_dataset(
                models_by_ckpt[ckpt_path], dataset_name, test_folder,
                test_out_folder, device, add_residual=(ckpt_path == CDDFuse_path),
            )
            metric_result = _evaluate(dataset_name, test_folder, test_out_folder)
            print(model_name + '\t' + '\t'.join(str(np.round(value, 2)) for value in metric_result))
        print("=" * 80)


if __name__ == "__main__":
    main()
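# --- The lines below were pasted from the CSDN download page that hosted this
# --- script. They are not code and are kept only as commented-out provenance.
# 没有合适的资源?快使用搜索试试~ 我知道了~  (site banner: "No suitable resource? Try searching. Got it.")
# CDDFuse无颜色问题(已解决)  (title: "CDDFuse missing-colour output problem (solved)")
# 共2个文件  (2 files in total)
# py:2个  (py: 2)
# 1 下载量 128 浏览量  (1 download, 128 views)
# 2024-05-14
# 17:04:27
# 上传  (uploaded)
# 评论  (comments)
# 收藏 3KB ZIP 举报  (favourite · 3KB ZIP · report)
# 温馨提示  (note)
# CDDFuse无颜色问题(已解决)  (CDDFuse missing-colour output problem (solved))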
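# 资源推荐  (recommended resources)
# 资源详情  (resource details)
# 资源评论  (resource comments)
# 收起资源包目录  (collapse package contents)
# CDDFuse无颜色问题解决文件.zip (2个子文件)  (archive "CDDFuse missing-colour fix files.zip", 2 entries)
# CDDFuse无颜色问题解决文件  (folder: CDDFuse missing-colour fix files)
# test_IVF.py 6KB
# test_MIF.py 6KB
# 共 2 条  (2 items in total)
# - 1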
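# 资源评论  (resource comments)
# 林浩杨_  (user profile card)
# - 粉丝: 93  (followers: 93)
# - 资源: 1  (resources: 1)
# 上传资源 快速赚钱  (site menu: upload resources, earn money)
# - 我的内容管理 展开  (my content management — expand)
# - 我的资源 快来上传第一个资源  (my resources — upload your first one)
# - 我的收益 登录查看自己的收益  (my earnings — log in to view)
# - 我的积分 登录查看自己的积分  (my points — log in to view)
# - 我的C币 登录后查看C币余额  (my C-coins — log in to view the balance)
# - 我的收藏  (my favourites)
# - 我的下载  (my downloads)
# - 下载帮助  (download help)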
# 最新资源  (latest resources on the site — unrelated to this script)
# - 基于Javascript和微信小程序的Anna设计源码
# - 基于Java的仿制品设计源码 - bilibili
# - 基于Javascript的影视动画设计源码 - cad
# - 基于Java和深度学习的瓦斯浓度预测系统后端设计源码 - 瓦斯浓度预测后端
# - Screenshot_20240528_103010.jpg
# - 基于Python的新能源承载力计算及界面设计源码 - HAINING-DG
# - 基于Java的本科探索学习项目设计源码 - 本科探索
# - 基于Javascript和Python的微商城项目设计源码 - MicroMall
# - 基于Java的网上订餐系统设计源码 - online ordering system
# - 基于Javascript的超级美眉网络资源管理应用模块设计源码
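# 资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!  (site feedback prompt: "Questions or suggestions about uploads, downloads or courses are welcome; we will respond promptly.")
# 点击此处反馈  (click here to give feedback)
# 安全验证  (security verification)
# 文档复制为VIP权益,开通VIP直接复制  (copying documents is a VIP feature; subscribe to copy directly)
# 信息提交成功  (information submitted successfully)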