pytorch 实现实现yolo3详细理解(三)详细理解(三) 数据集处理数据集处理
本章详细讲解数据的处理问题,将coco数据集读取,以及之后自定义数据集的处理,
数据预处理思想数据预处理思想
yolo3的数据集处理也是一大亮点,由于yolo3对数据集的输入有要求,指定的照片输入大小必须是416,所有对于不满足照片
的大小有一系列的操作,如果直接resize操作,将直接损失照片信息,网络在学习分类的过程还要适应照片尺寸的问题,导致
训练效果不佳,在yolo3中是先进行高和宽的调整一样大,在进行上采样的resize,同时要修改label的坐标位置,随机水平翻
转,再一次随机变化大小,之后再变化到416的大小尺寸作为输入。
代码代码
class ListDataset(Dataset): #继承Dataset
def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
with open(list_path, "r") as file:
self.img_files = file.readlines()
self.label_files = [
path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt") #这一步是生成labels的位置
for path in self.img_files
] self.img_size = img_size
self.max_objects = 100
self.augment = augment
self.multiscale = multiscale
self.normalized_labels = normalized_labels
self.min_size = self.img_size - 3 * 32
self.max_size = self.img_size + 3 * 32
self.batch_count = 0
def __getitem__(self, index):
# ---------
# Image
# ---------
img_path = self.img_files[index % len(self.img_files)].rstrip() #按照索引的方式找到对应的路径
# Extract image as PyTorch tensor
img = transforms.ToTensor()(Image.open(img_path).convert('RGB')) #读取照片
# Handle images with less than three channels
if len(img.shape) != 3:
img = img.unsqueeze(0)
img = img.expand((3, img.shape[1:]))
_, h, w = img.shape
h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1) #直接理解为照片的宽度和高度
# Pad to square resolution
img, pad = pad_to_square(img, 0) #这一步就是将高和宽变成一样大小
_, padded_h, padded_w = img.shape
# ---------
# Label
# ---------
label_path = self.label_files[index % len(self.img_files)].rstrip() #照片对应的label路径
targets = None
if os.path.exists(label_path):
boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
# Extract coordinates for unpadded + unscaled image
x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2) #label的坐标点位置是xywh所以先进行转化
y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
# Adjust for added padding
x1 += pad[0] #照片大小变化了所以框的坐标点需要修改
y1 += pad[2] x2 += pad[1] y2 += pad[3] # Returns (x, y, w, h)
boxes[:, 1] = ((x1 + x2) / 2) / padded_w #在次重新转化xywh形式
boxes[:, 2] = ((y1 + y2) / 2) / padded_h