# 根据用户的反馈,我将更新 infer-pt.py 脚本,使其在命令行参数中接受模型的路径。
# 这样用户就不需要每次都修改脚本来指定模型路径。
# 更新后的 infer-pt.py 脚本
# 需要注意的是,以下代码是一个示例,需要在用户的本地环境中运行。
# 导入所需库
import argparse
import os
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from efficientnet_pytorch import EfficientNet
from tqdm import tqdm
# Command-line interface: one positional input path plus optional
# checkpoint location and compute-device selection.
parser = argparse.ArgumentParser(
    description='Perform inference on a single image or a directory of images',
)
# Either a single image file or a directory containing images.
parser.add_argument(
    'input',
    help='Path to an image or directory of images to perform inference on',
    type=str,
)
parser.add_argument(
    '--model_path',
    help='Path to the model checkpoint',
    type=str,
    default='F:/skin_cancer/code/ISIC_HAM10000/zcyarchive/model.pth',
)
parser.add_argument(
    '--device',
    help='Compute device to use',
    type=str,
    default='cuda',
    choices=['cpu', 'cuda'],
)
args = parser.parse_args()
# Compute device: CUDA only when it was requested AND is available;
# every other combination falls back to the CPU.
device = torch.device("cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu")
# Image preprocessing for the EfficientNet classifier:
#   1. Resize(256)      - scale the image up/down so its shorter side is 256
#   2. CenterCrop(224)  - keep the central 224x224 region
#   3. ToTensor()       - convert the PIL image to a float tensor
#   4. Normalize(...)   - per-channel standardization with the widely used
#                         ImageNet mean/std values
transform = Compose(
    [
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Build the EfficientNet-B0 architecture with a 7-class head and load the
# trained weights from the checkpoint given by --model_path.
# from_name() only constructs the network. from_pretrained() would first
# download ImageNet weights over the network, and load_state_dict() (strict
# by default) would then overwrite every one of them anyway.
# NOTE(review): passing num_classes to from_name() requires a recent
# efficientnet_pytorch release (0.7+); confirm the installed version.
model = EfficientNet.from_name('efficientnet-b0', num_classes=7)
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints; if this is a plain state_dict, consider weights_only=True
# (torch >= 1.13).
model.load_state_dict(torch.load(args.model_path, map_location=device))
model.to(device)
# Inference mode: disables dropout and uses the stored batch-norm statistics.
model.eval()
# Inference on a single image
def infer_image(image_path):
    """Classify one image with the module-level ``model``.

    The image is converted to RGB, preprocessed with the module-level
    ``transform`` and run on the module-level ``device``.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A ``(predicted_class, class_probabilities)`` tuple:
        ``predicted_class`` is the index of the most probable class, and
        ``class_probabilities`` is the softmax probability of every class
        as a list of floats.

    Raises:
        Propagates whatever ``Image.open`` raises for a missing or
        unreadable file.
    """
    # Close the file handle as soon as the pixels are decoded. Without the
    # context manager every processed image stays open until garbage
    # collection, which adds up when a whole directory is processed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    # Add a batch dimension of size 1, then move to the compute device.
    batch = transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
        # Take row 0 of the (1, num_classes) result explicitly. squeeze()
        # would also collapse the class axis for a single-output head.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
    predicted_class = probabilities.argmax().item()
    class_probabilities = probabilities.tolist()
    return predicted_class, class_probabilities
# Inference on every image in a directory
def process_directory(directory_path, extensions=('.png', '.jpg', '.jpeg')):
    """Classify every image file directly inside ``directory_path``.

    Files are selected by extension (case-insensitive) and processed in
    sorted name order, so repeated runs report results in the same order.
    Subdirectories are not descended into. Each result is printed as it is
    produced.

    Args:
        directory_path: Directory containing the images.
        extensions: Lower-case filename suffixes treated as images.

    Returns:
        A list of ``(image_name, predicted_class, class_probabilities)``
        tuples, one per image that was classified successfully. Files that
        cannot be opened are reported and skipped.
    """
    suffixes = tuple(extensions)
    # os.listdir returns entries in arbitrary order, so sort them for
    # reproducible output. Also skip directories whose names happen to end
    # in an image extension, since they cannot be opened as images.
    images = sorted(
        name for name in os.listdir(directory_path)
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(directory_path, name))
    )
    predictions = []
    for img_name in tqdm(images, desc="Inferencing"):
        img_path = os.path.join(directory_path, img_name)
        try:
            predicted_class, class_probabilities = infer_image(img_path)
        except OSError as err:
            # One corrupt or unreadable file should not abort the whole batch.
            tqdm.write(f'Skipping {img_name}: {err}')
            continue
        predictions.append((img_name, predicted_class, class_probabilities))
        # tqdm.write prints above the progress bar. A plain print() inside
        # the loop garbles the bar's output.
        tqdm.write(f'Image: {img_name}, Class: {predicted_class}, Probabilities: {class_probabilities}')
    return predictions
# Entry point: dispatch on whether the input is a directory or a single file.
if __name__ == '__main__':
    if os.path.isdir(args.input):
        # process_directory prints one line per image itself.
        process_directory(args.input)
    elif os.path.isfile(args.input):
        predicted_class, class_probabilities = infer_image(args.input)
        print(f'Image: {args.input}, Class: {predicted_class}, Probabilities: {class_probabilities}')
    else:
        # parser.error writes the usage text and this message to stderr and
        # exits with status 2. Callers and shell scripts can therefore
        # detect the failure, which a print() followed by a status-0 exit
        # did not allow.
        parser.error("The input path provided does not exist.")