import os
import matplotlib.pyplot as plt
import random
from PIL import Image
def _collect_images(path):
    """Return sorted ``(label, file_path)`` pairs found under ``path``.

    Each sub-directory of ``path`` is treated as a class whose name is the
    label. Every regular file directly inside it is paired with that label.
    Non-directory entries at the top level, and sub-directories inside a
    class folder, are skipped.
    """
    pairs = []
    for label in sorted(os.listdir(path)):
        class_dir = os.path.join(path, label)
        if not os.path.isdir(class_dir):
            continue  # e.g. a README.txt sitting next to the class folders
        for name in sorted(os.listdir(class_dir)):
            img_path = os.path.join(class_dir, name)
            if os.path.isfile(img_path):
                pairs.append((label, img_path))
    return pairs


def main(path, num_images=4, save_path='show.png'):
    """Save a grid of randomly chosen, labelled images from ``path``.

    ``path`` is expected to hold one sub-directory per class (for example
    ``cat/`` and ``dog/``). The sub-directory name is used as the title of
    each image drawn from it.

    Args:
        path: Dataset root directory.
        num_images: Number of distinct images to show (default 4, laid out
            2x2). If the dataset has fewer files, all of them are shown.
        save_path: File the figure is written to (default ``'show.png'``).

    Raises:
        ValueError: If ``num_images`` is not positive, or if no files are
            found inside any class sub-directory of ``path``.
    """
    import math

    if num_images < 1:
        raise ValueError(f'num_images must be positive, got {num_images}')

    images = _collect_images(path)
    if not images:
        raise ValueError(f'no image files found under {path!r}')

    # Sample without replacement so no image is shown twice.
    picks = random.sample(images, min(num_images, len(images)))

    # Near-square grid; yields the original 2x2 layout for 4 images.
    cols = math.ceil(math.sqrt(len(picks)))
    rows = math.ceil(len(picks) / cols)

    fig = plt.figure(figsize=(12, 8))
    try:
        for idx, (label, im_path) in enumerate(picks, start=1):
            ax = fig.add_subplot(rows, cols, idx)
            ax.set_title(label)
            # imshow converts the PIL image to an array immediately,
            # so the file handle can be closed as soon as it returns.
            with Image.open(im_path) as im:
                ax.imshow(im)
        fig.savefig(save_path)
    finally:
        # Release the figure; pyplot would otherwise keep it alive.
        plt.close(fig)
if __name__ == '__main__':
    # Dataset root: one sub-folder per class.
    train_dir = './data/train'
    main(path=train_dir)